adapt view behavior change, fix KeyError. (#3794)

* adapt view behavior change, fix KeyError.

* fix readme demo run error.

* pin opencc version (opencc==1.1.6)
pull/3807/head
zxcd 4 months ago committed by GitHub
parent e8018a11ce
commit 91170bd260
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -274,7 +274,7 @@ class ASRExecutor(BaseExecutor):
# fbank # fbank
audio = preprocessing(audio, **preprocess_args) audio = preprocessing(audio, **preprocess_args)
audio_len = paddle.to_tensor([audio.shape[0]]).unsqueeze(axis=0) audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0)
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
self._inputs["audio"] = audio self._inputs["audio"] = audio

@@ -188,7 +188,7 @@ class Wav2vec2ASR(nn.Layer):
x_lens = x.shape[1] x_lens = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size) ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view([batch_size, x_lens]) # (B, maxlen) topk_index = topk_index.reshape([batch_size, x_lens]) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index] hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]

@@ -48,7 +48,7 @@ base = [
"matplotlib", "matplotlib",
"nara_wpe", "nara_wpe",
"onnxruntime>=1.11.0", "onnxruntime>=1.11.0",
"opencc", "opencc==1.1.6",
"opencc-python-reimplemented", "opencc-python-reimplemented",
"pandas", "pandas",
"paddleaudio>=1.1.0", "paddleaudio>=1.1.0",

Loading…
Cancel
Save