|
|
@ -128,7 +128,7 @@ class Wav2vec2ASR(nn.Layer):
|
|
|
|
# with other batch decoding mode
|
|
|
|
# with other batch decoding mode
|
|
|
|
elif decoding_method == 'ctc_prefix_beam_search':
|
|
|
|
elif decoding_method == 'ctc_prefix_beam_search':
|
|
|
|
assert feats.shape[0] == 1
|
|
|
|
assert feats.shape[0] == 1
|
|
|
|
if tokenizer is None:
|
|
|
|
if tokenizer is None and sb_pipeline is False:
|
|
|
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
|
|
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
|
|
|
res = [text_feature.defeaturize(hyp)]
|
|
|
|
res = [text_feature.defeaturize(hyp)]
|
|
|
|
res_tokenids = [hyp]
|
|
|
|
res_tokenids = [hyp]
|
|
|
|