@@ -54,7 +54,7 @@ class CTCLoss(nn.Layer):
             self.norm_by_total_logits_len = True
         else:
             raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
-        self.kwargs = {
+        kwargs = {
            "norm_by_times": self.norm_by_times,
            "norm_by_batchsize": self.norm_by_batchsize,
            "norm_by_total_logits_len": self.norm_by_total_logits_len,
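The only change in the first hunk is a rename: the dict is consumed solely inside `__init__`, so it becomes a local `kwargs` rather than the attribute `self.kwargs`. For readers without the full file, a minimal sketch of the flag selection these context lines belong to is below; the `'instance'`/`'batch'`/`'frame'` key names are assumptions inferred from the three `norm_by_*` flags, since the diff shows only the final `elif`/`else` branches.

# Hypothetical standalone sketch (not code from the PR) of the grad-norm
# flag selection; the mapping keys are assumed from the flag names.
def grad_norm_flags(grad_norm_type=None):
    flags = {
        "norm_by_times": False,
        "norm_by_batchsize": False,
        "norm_by_total_logits_len": False,
    }
    mapping = {
        "instance": "norm_by_times",          # assumed key name
        "batch": "norm_by_batchsize",         # assumed key name
        "frame": "norm_by_total_logits_len",  # assumed key name
    }
    if grad_norm_type is None:
        return flags
    if grad_norm_type not in mapping:
        raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
    flags[mapping[grad_norm_type]] = True
    return flags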
@@ -66,10 +66,9 @@ class CTCLoss(nn.Layer):
         except ValueError:
             # Some function, e.g. built-in function, are failed
             param = {}
-        self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
-        _notin = {k: v for k, v in self.kwargs.items() if k not in param}
+        self._kwargs = {k: v for k, v in kwargs.items() if k in param}
+        _notin = {k: v for k, v in kwargs.items() if k not in param}
         logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
-        #self.loss_fn = partial(self.loss.forward, **_kwargs)

     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
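The second hunk follows the rename through the two filtering comprehensions and drops the commented-out `partial` line. The underlying pattern — passing a callable only the keyword arguments its signature declares — is general enough to isolate; `filter_supported_kwargs` below is a hypothetical helper written for illustration, not code from the PR:

import inspect

def filter_supported_kwargs(func, kwargs):
    """Split kwargs into those declared by func's signature and the rest."""
    try:
        param = inspect.signature(func).parameters
    except ValueError:
        # inspect.signature raises ValueError for some built-in /
        # C-implemented callables that expose no signature.
        param = {}
    supported = {k: v for k, v in kwargs.items() if k in param}
    unsupported = {k: v for k, v in kwargs.items() if k not in param}
    return supported, unsupported

The payoff is robustness across Paddle versions: flags that a given `nn.CTCLoss.forward` does not declare are logged as unsupported instead of crashing the later call with a TypeError.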
@@ -89,8 +88,7 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         ys_pad = ys_pad.astype(paddle.int32)
-        #loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
-        loss = self.loss(logits, ys_pad, hlens, ys_lens)
+        loss = self.loss(logits, ys_pad, hlens, ys_lens, **self._kwargs)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
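The third hunk carries the functional fix: `self._kwargs` was computed in `__init__` but never forwarded (the `loss_fn` partial was commented out), so the grad-norm flags silently had no effect; now the filtered kwargs reach the loss call. A hedged usage sketch follows, assuming the PaddleSpeech import path and constructor arguments implied by the context lines (`blank`, `reduction`, `batch_average`, `grad_norm_type` are not shown in the diff and may differ):

import paddle
# Import path assumed from the class shown in the hunk headers; adjust to
# where CTCLoss actually lives in your checkout.
from paddlespeech.s2t.modules.loss import CTCLoss

B, Tmax, D = 4, 50, 42  # batch size, max frames, vocab size incl. blank
criterion = CTCLoss(blank=0, reduction='sum', batch_average=True,
                    grad_norm_type='instance')  # argument names assumed

logits = paddle.randn([B, Tmax, D])            # (B, L, D), transposed inside forward
ys_pad = paddle.randint(1, D, shape=[B, 10])   # padded labels, cast to int32 inside
hlens = paddle.full([B], Tmax, dtype='int64')  # encoder output lengths
ys_lens = paddle.full([B], 10, dtype='int64')  # label lengths

loss = criterion(logits, ys_pad, hlens, ys_lens)  # scalar; divided by B when batch_average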