diff --git a/decoders/swig_wrapper.py b/decoders/swig_wrapper.py index 0a0579ad0..3051f4e82 100644 --- a/decoders/swig_wrapper.py +++ b/decoders/swig_wrapper.py @@ -46,7 +46,7 @@ def ctc_greedy_decoder(probs_seq, vocabulary): :rtype: str """ result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) - return result.decode('utf-8') + return result def ctc_beam_search_decoder(probs_seq, diff --git a/model_utils/network.py b/model_utils/network.py index c8fa95dc4..1e7545ee6 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -686,8 +686,9 @@ class DeepSpeech2(nn.Layer): def decode(self, audio, audio_len, vocab_list, decoding_method, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes): - _, probs, _ = self.predict(audio, audio_len) - return self.decode_probs(probs.numpy(), vocab_list, decoding_method, + _, probs, audio_lens = self.predict(audio, audio_len) + probs_split = [probs[i, :l, :] for i, l in enumerate(audio_lens)] + return self.decode_probs(probs_split, vocab_list, decoding_method, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes)