|
|
|
@@ -81,10 +81,10 @@ class CTCDecoderBase(nn.Layer):
|
|
|
|
|
Args:
|
|
|
|
|
hs_pad (Tensor): batch of padded hidden state sequences (B, Tmax, D)
|
|
|
|
|
hlens (Tensor): batch of lengths of hidden state sequences (B)
|
|
|
|
|
ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax)
|
|
|
|
|
ys_pad (Tensor): batch of padded character id sequence tensor (B, Lmax)
|
|
|
|
|
ys_lens (Tensor): batch of lengths of character sequence (B)
|
|
|
|
|
Returns:
|
|
|
|
|
loss (Tenosr): ctc loss value, scalar.
|
|
|
|
|
loss (Tensor): ctc loss value, scalar.
|
|
|
|
|
"""
|
|
|
|
|
logits = self.ctc_lo(self.dropout(hs_pad))
|
|
|
|
|
loss = self.criterion(logits, ys_pad, hlens, ys_lens)
|
|
|
|
@@ -252,8 +252,8 @@ class CTCDecoder(CTCDecoderBase):
|
|
|
|
|
"""ctc decoding with probs.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
probs (Tenosr): activation after softmax
|
|
|
|
|
logits_lens (Tenosr): audio output lens
|
|
|
|
|
probs (Tensor): activation after softmax
|
|
|
|
|
logits_lens (Tensor): audio output lens
|
|
|
|
|
vocab_list ([type]): [description]
|
|
|
|
|
decoding_method ([type]): [description]
|
|
|
|
|
lang_model_path ([type]): [description]
|
|
|
|
|