|
|
|
@ -67,10 +67,10 @@ class CTCLoss(nn.Layer):
|
|
|
|
|
except ValueError:
|
|
|
|
|
# Some function, e.g. built-in function, are failed
|
|
|
|
|
param = {}
|
|
|
|
|
_kwargs = {k: v for k, v in self.kwargs.items() if k in param}
|
|
|
|
|
self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
|
|
|
|
|
_notin = {k: v for k, v in self.kwargs.items() if k not in param}
|
|
|
|
|
logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}")
|
|
|
|
|
self.loss_fn = partial(self.loss.forward, **_kwargs)
|
|
|
|
|
logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
|
|
|
|
|
#self.loss_fn = partial(self.loss.forward, **_kwargs)
|
|
|
|
|
|
|
|
|
|
def forward(self, logits, ys_pad, hlens, ys_lens):
|
|
|
|
|
"""Compute CTC loss.
|
|
|
|
@ -90,7 +90,8 @@ class CTCLoss(nn.Layer):
|
|
|
|
|
# logits: (B, L, D) -> (L, B, D)
|
|
|
|
|
logits = logits.transpose([1, 0, 2])
|
|
|
|
|
ys_pad = ys_pad.astype(paddle.int32)
|
|
|
|
|
loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
|
|
|
|
|
#loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
|
|
|
|
|
loss = self.loss(logits, ys_pad, hlens, ys_lens)
|
|
|
|
|
if self.batch_average:
|
|
|
|
|
# Batch-size average
|
|
|
|
|
loss = loss / B
|
|
|
|
|