|
|
|
@ -11,6 +11,9 @@
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import inspect
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle import nn
|
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
@ -32,18 +35,19 @@ class CTCLoss(nn.Layer):
|
|
|
|
|
# last token id as blank id
|
|
|
|
|
self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
|
|
|
|
|
self.batch_average = batch_average
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
|
|
|
|
|
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
|
|
|
|
|
|
|
|
|
|
# instance for norm_by_times
|
|
|
|
|
# batch for norm_by_batchsize
|
|
|
|
|
# frame for norm_by_total_logits_len
|
|
|
|
|
assert grad_norm_type in ('instance', 'batch', 'frame', None)
|
|
|
|
|
self.norm_by_times = False
|
|
|
|
|
self.norm_by_batchsize = False
|
|
|
|
|
self.norm_by_total_logits_len = False
|
|
|
|
|
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
|
|
|
|
|
if grad_norm_type == 'instance':
|
|
|
|
|
if grad_norm_type is None:
|
|
|
|
|
# no grad norm
|
|
|
|
|
pass
|
|
|
|
|
elif grad_norm_type == 'instance':
|
|
|
|
|
self.norm_by_times = True
|
|
|
|
|
elif grad_norm_type == 'batch':
|
|
|
|
|
self.norm_by_batchsize = True
|
|
|
|
@ -51,6 +55,22 @@ class CTCLoss(nn.Layer):
|
|
|
|
|
self.norm_by_total_logits_len = True
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
|
|
|
|
|
self.kwargs = {
|
|
|
|
|
"norm_by_times": self.norm_by_times,
|
|
|
|
|
"norm_by_batchsize": self.norm_by_batchsize,
|
|
|
|
|
"norm_by_total_logits_len": self.norm_by_total_logits_len,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Derive only the args which the func has
|
|
|
|
|
try:
|
|
|
|
|
param = inspect.signature(self.loss.forward).parameters
|
|
|
|
|
except ValueError:
|
|
|
|
|
# Some function, e.g. built-in function, are failed
|
|
|
|
|
param = {}
|
|
|
|
|
_kwargs = {k: v for k, v in self.kwargs.items() if k in param}
|
|
|
|
|
_notin = {k: v for k, v in self.kwargs.items() if k not in param}
|
|
|
|
|
logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}")
|
|
|
|
|
self.loss_fn = partial(self.loss.forward, **_kwargs)
|
|
|
|
|
|
|
|
|
|
def forward(self, logits, ys_pad, hlens, ys_lens):
|
|
|
|
|
"""Compute CTC loss.
|
|
|
|
@ -70,14 +90,7 @@ class CTCLoss(nn.Layer):
|
|
|
|
|
# logits: (B, L, D) -> (L, B, D)
|
|
|
|
|
logits = logits.transpose([1, 0, 2])
|
|
|
|
|
ys_pad = ys_pad.astype(paddle.int32)
|
|
|
|
|
loss = self.loss(
|
|
|
|
|
logits,
|
|
|
|
|
ys_pad,
|
|
|
|
|
hlens,
|
|
|
|
|
ys_lens,
|
|
|
|
|
norm_by_times=self.norm_by_times,
|
|
|
|
|
norm_by_batchsize=self.norm_by_batchsize,
|
|
|
|
|
norm_by_total_logits_len=self.norm_by_total_logits_len)
|
|
|
|
|
loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
|
|
|
|
|
if self.batch_average:
|
|
|
|
|
# Batch-size average
|
|
|
|
|
loss = loss / B
|
|
|
|
@ -152,7 +165,7 @@ class LabelSmoothingLoss(nn.Layer):
|
|
|
|
|
# use zeros_like instead of torch.no_grad() for true_dist,
|
|
|
|
|
# since no_grad() can not be exported by JIT
|
|
|
|
|
true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
|
|
|
|
|
ignore = (target == self.padding_idx) # (B,)
|
|
|
|
|
ignore = target == self.padding_idx # (B,)
|
|
|
|
|
|
|
|
|
|
#TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index
|
|
|
|
|
target = target.masked_fill(ignore, 0) # avoid -1 index
|
|
|
|
@ -163,8 +176,10 @@ class LabelSmoothingLoss(nn.Layer):
|
|
|
|
|
|
|
|
|
|
kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
|
|
|
|
|
|
|
|
|
|
total = len(target) - int(ignore.sum())
|
|
|
|
|
#TODO(Hui Zhang): sum not support bool type
|
|
|
|
|
#total = len(target) - int(ignore.sum())
|
|
|
|
|
total = len(target) - int(ignore.type_as(target).sum())
|
|
|
|
|
denom = total if self.normalize_length else B
|
|
|
|
|
#TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum()
|
|
|
|
|
#numer = (kl * (1 - ignore)).sum()
|
|
|
|
|
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
|
|
|
|
|
return numer / denom
|
|
|
|
|