@@ -62,14 +62,15 @@ class CTCLoss(nn.Layer):
         """Compute CTC loss.
 
         Args:
-            logits ([paddle.Tensor]): [description]
-            ys_pad ([paddle.Tensor]): [description]
-            hlens ([paddle.Tensor]): [description]
-            ys_lens ([paddle.Tensor]): [description]
+            logits ([paddle.Tensor]): [B, Tmax, D]
+            ys_pad ([paddle.Tensor]): [B, Tmax]
+            hlens ([paddle.Tensor]): [B]
+            ys_lens ([paddle.Tensor]): [B]
 
         Returns:
             [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}.
         """
+        B = paddle.shape(logits)[0]
         # warp-ctc need logits, and do softmax on logits by itself
         # warp-ctc need activation with shape [T, B, V + 1]
         # logits: (B, L, D) -> (L, B, D)
@@ -78,5 +79,5 @@ class CTCLoss(nn.Layer):
         # wenet do batch-size average, deepspeech2 not do this
         # Batch-size average
-        # loss = loss / paddle.shape(logits)[1]
+        # loss = loss / B
 
         return loss
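
Below is a minimal, self-contained sketch (not part of the patch) of how the shapes documented in the updated docstring line up. It calls paddle.nn.CTCLoss directly rather than the CTCLoss wrapper touched by this diff, and the batch size, lengths, and vocabulary size are illustrative assumptions; the transpose and the `loss / B` average mirror the comments shown above.

# Illustrative sketch only; tensor sizes and values are made up.
import paddle

B, Tmax, D = 4, 50, 30        # batch size, max input length, vocab size (incl. blank)
Lmax = 10                     # max label length

logits = paddle.randn([B, Tmax, D])                            # [B, Tmax, D]
ys_pad = paddle.randint(1, D, shape=[B, Lmax], dtype='int32')  # padded labels; blank id 0 excluded
hlens = paddle.full([B], Tmax, dtype='int64')                  # [B] input lengths
ys_lens = paddle.full([B], Lmax, dtype='int64')                # [B] label lengths

ctc = paddle.nn.CTCLoss(blank=0, reduction='sum')
# (B, L, D) -> (L, B, D), as in the wrapper; per its comment, warp-ctc
# applies softmax internally, so raw logits are passed.
loss = ctc(logits.transpose([1, 0, 2]), ys_pad, hlens, ys_lens)
loss = loss / B               # the batch-size average the diff refers to
print(float(loss))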