|
|
|
@ -188,7 +188,7 @@ class Wav2vec2ASR(nn.Layer):
|
|
|
|
|
x_lens = x.shape[1]
|
|
|
|
|
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
|
|
|
|
|
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
|
|
|
|
|
topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
|
|
|
|
|
topk_index = topk_index.view([batch_size, x_lens]) # (B, maxlen)
|
|
|
|
|
|
|
|
|
|
hyps = [hyp.tolist() for hyp in topk_index]
|
|
|
|
|
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
|
|
|
|