From 1f206a69e66c4409cd67d3bdf23a8dafb93932ab Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 26 Feb 2021 03:29:51 +0000 Subject: [PATCH] fix decoding probs bugs --- model_utils/model.py | 1 + model_utils/network.py | 26 ++++++++++++++------------ tune.py | 5 +++-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/model_utils/model.py b/model_utils/model.py index 6520d94a3..174410a91 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -373,6 +373,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): target_transcripts = self.ordid2token(texts, texts_len) result_transcripts = self.model.decode_probs( probs.numpy(), + logits_len, vocab_list, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, diff --git a/model_utils/network.py b/model_utils/network.py index 1e7545ee6..0c2f4dbf2 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -661,16 +661,19 @@ class DeepSpeech2(nn.Layer): self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, vocab_list) - def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path, - beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - """ probs: activation after softmax """ + def decode_probs(self, probs, logits_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, + cutoff_prob, cutoff_top_n, num_processes): + """ probs: activation after softmax + logits_len: audio output lens + """ + probs_split = [probs[i, :l, :] for i, l in enumerate(logits_len)] if decoding_method == "ctc_greedy": result_transcripts = self._decode_batch_greedy( - probs_split=probs, vocab_list=vocab_list) + probs_split=probs_split, vocab_list=vocab_list) elif decoding_method == "ctc_beam_search": result_transcripts = self._decode_batch_beam_search( - probs_split=probs, + probs_split=probs_split, beam_alpha=beam_alpha, beam_beta=beam_beta, beam_size=beam_size, @@ -686,12 +689,11 @@ class DeepSpeech2(nn.Layer): def decode(self, audio, audio_len, vocab_list, decoding_method, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes): - _, probs, audio_lens = self.predict(audio, audio_len) - probs_split = [probs[i, :l, :] for i, l in enumerate(audio_lens)] - return self.decode_probs(probs_split, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, - num_processes) + _, probs, logits_lens = self.predict(audio, audio_len) + return self.decode_probs(probs.numpy(), logits_lens, vocab_list, + decoding_method, lang_model_path, beam_alpha, + beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) def from_pretrained(self, checkpoint_path): """Build a model from a pretrained model. diff --git a/tune.py b/tune.py index b269265ae..6e25af154 100644 --- a/tune.py +++ b/tune.py @@ -114,7 +114,7 @@ def tune(config, args): return trans audio, text, audio_len, text_len = infer_data - _, probs, _ = model.predict(audio, audio_len) + _, probs, logits_lens = model.predict(audio, audio_len) target_transcripts = ordid2token(text, text_len) num_ins += audio.shape[0] @@ -122,7 +122,8 @@ def tune(config, args): for index, (alpha, beta) in enumerate(params_grid): print(f"tuneing: alpha={alpha} beta={beta}") result_transcripts = model.decode_probs( - probs.numpy(), vocab_list, config.decoding.decoding_method, + probs.numpy(), logits_lens, vocab_list, + config.decoding.decoding_method, config.decoding.lang_model_path, alpha, beta, config.decoding.beam_size, config.decoding.cutoff_prob, config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)