fix decode_probs bugs: pass per-utterance output lengths (logits_len) so padded frames are not decoded

pull/522/head
Hui Zhang 5 years ago
parent 49d55a865c
commit 1f206a69e6

@@ -373,6 +373,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         target_transcripts = self.ordid2token(texts, texts_len)
         result_transcripts = self.model.decode_probs(
             probs.numpy(),
+            logits_len,
             vocab_list,
             decoding_method=cfg.decoding_method,
             lang_model_path=cfg.lang_model_path,

@@ -661,16 +661,19 @@ class DeepSpeech2(nn.Layer):
         self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
                               vocab_list)
 
-    def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path,
-                     beam_alpha, beam_beta, beam_size, cutoff_prob,
-                     cutoff_top_n, num_processes):
-        """ probs: activation after softmax """
+    def decode_probs(self, probs, logits_len, vocab_list, decoding_method,
+                     lang_model_path, beam_alpha, beam_beta, beam_size,
+                     cutoff_prob, cutoff_top_n, num_processes):
+        """ probs: activation after softmax
+        logits_len: audio output lens
+        """
+        probs_split = [probs[i, :l, :] for i, l in enumerate(logits_len)]
         if decoding_method == "ctc_greedy":
             result_transcripts = self._decode_batch_greedy(
-                probs_split=probs, vocab_list=vocab_list)
+                probs_split=probs_split, vocab_list=vocab_list)
         elif decoding_method == "ctc_beam_search":
             result_transcripts = self._decode_batch_beam_search(
-                probs_split=probs,
+                probs_split=probs_split,
                 beam_alpha=beam_alpha,
                 beam_beta=beam_beta,
                 beam_size=beam_size,
@@ -686,12 +689,11 @@ class DeepSpeech2(nn.Layer):
     def decode(self, audio, audio_len, vocab_list, decoding_method,
                lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
                cutoff_top_n, num_processes):
-        _, probs, audio_lens = self.predict(audio, audio_len)
-        probs_split = [probs[i, :l, :] for i, l in enumerate(audio_lens)]
-        return self.decode_probs(probs_split, vocab_list, decoding_method,
-                                 lang_model_path, beam_alpha, beam_beta,
-                                 beam_size, cutoff_prob, cutoff_top_n,
-                                 num_processes)
+        _, probs, logits_lens = self.predict(audio, audio_len)
+        return self.decode_probs(probs.numpy(), logits_lens, vocab_list,
+                                 decoding_method, lang_model_path, beam_alpha,
+                                 beam_beta, beam_size, cutoff_prob,
+                                 cutoff_top_n, num_processes)
 
     def from_pretrained(self, checkpoint_path):
         """Build a model from a pretrained model.

@@ -114,7 +114,7 @@ def tune(config, args):
             return trans
 
         audio, text, audio_len, text_len = infer_data
-        _, probs, _ = model.predict(audio, audio_len)
+        _, probs, logits_lens = model.predict(audio, audio_len)
         target_transcripts = ordid2token(text, text_len)
         num_ins += audio.shape[0]
 
@@ -122,7 +122,8 @@ def tune(config, args):
         for index, (alpha, beta) in enumerate(params_grid):
             print(f"tuning: alpha={alpha} beta={beta}")
             result_transcripts = model.decode_probs(
-                probs.numpy(), vocab_list, config.decoding.decoding_method,
+                probs.numpy(), logits_lens, vocab_list,
+                config.decoding.decoding_method,
                 config.decoding.lang_model_path, alpha, beta,
                 config.decoding.beam_size, config.decoding.cutoff_prob,
                 config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)
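Taken together, the hunks move the padding-aware split inside decode_probs: probs is padded to the longest utterance in the batch, so decoding the whole tensor feeds padding frames into the CTC decoder, and before this commit only decode() did the split, leaving direct decode_probs callers (the tester and the tuner) decoding padded frames. A minimal, self-contained sketch of the failure mode and the fix; greedy_ctc_decode below is a toy stand-in for the repo's _decode_batch_greedy, not its actual implementation:

    import numpy as np

    def greedy_ctc_decode(probs, blank=0):
        """Toy CTC greedy decode: argmax per frame, collapse repeats, drop blanks."""
        best = probs.argmax(axis=-1)
        out, prev = [], blank
        for t in best:
            if t != prev and t != blank:
                out.append(int(t))
            prev = t
        return out

    rng = np.random.default_rng(0)
    vocab, t_max = 5, 6          # toy vocab; index 0 is the CTC blank
    logits_len = [6, 3]          # true output lengths; utterance 1 is padded
    probs = rng.random((2, t_max, vocab)).astype("float32")
    probs /= probs.sum(-1, keepdims=True)   # rows sum to 1, like softmax
    probs[1, 3:, :] = 0.0                   # padding frames of utterance 1
    probs[1, 3:, 2] = 1.0                   # padding happens to favor token 2

    # Buggy path: decode the padded tensor directly; padding frames leak in.
    buggy = [greedy_ctc_decode(probs[i]) for i in range(2)]

    # Fixed path (what the commit does inside decode_probs): slice first.
    probs_split = [probs[i, :l, :] for i, l in enumerate(logits_len)]
    fixed = [greedy_ctc_decode(p) for p in probs_split]

    print(buggy[1])   # may contain a spurious token 2 from the padded frames
    print(fixed[1])   # decodes only the 3 real frames

Centralizing the slice inside decode_probs means every entry point (test, decode, and tune) gets the per-utterance truncation for free, which is why each call-site hunk now passes logits_len instead of pre-split probabilities.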
