fix decoding probs bugs

pull/522/head
Hui Zhang 5 years ago
parent 49d55a865c
commit 1f206a69e6

@ -373,6 +373,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode_probs( result_transcripts = self.model.decode_probs(
probs.numpy(), probs.numpy(),
logits_len,
vocab_list, vocab_list,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path, lang_model_path=cfg.lang_model_path,

@ -661,16 +661,19 @@ class DeepSpeech2(nn.Layer):
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list) vocab_list)
def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path, def decode_probs(self, probs, logits_len, vocab_list, decoding_method,
beam_alpha, beam_beta, beam_size, cutoff_prob, lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_top_n, num_processes): cutoff_prob, cutoff_top_n, num_processes):
""" probs: activation after softmax """ """ probs: activation after softmax
logits_len: audio output lens
"""
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_len)]
if decoding_method == "ctc_greedy": if decoding_method == "ctc_greedy":
result_transcripts = self._decode_batch_greedy( result_transcripts = self._decode_batch_greedy(
probs_split=probs, vocab_list=vocab_list) probs_split=probs_split, vocab_list=vocab_list)
elif decoding_method == "ctc_beam_search": elif decoding_method == "ctc_beam_search":
result_transcripts = self._decode_batch_beam_search( result_transcripts = self._decode_batch_beam_search(
probs_split=probs, probs_split=probs_split,
beam_alpha=beam_alpha, beam_alpha=beam_alpha,
beam_beta=beam_beta, beam_beta=beam_beta,
beam_size=beam_size, beam_size=beam_size,
@ -686,12 +689,11 @@ class DeepSpeech2(nn.Layer):
def decode(self, audio, audio_len, vocab_list, decoding_method,
           lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
           cutoff_top_n, num_processes):
    """Run the acoustic model on a batch and decode to transcripts.

    Forwards the per-frame softmax probabilities together with their
    valid output lengths to ``decode_probs``, which performs either
    greedy or beam-search CTC decoding depending on
    ``decoding_method``.

    Args:
        audio: batched input features.
        audio_len: valid length of each utterance in the batch.
        vocab_list: token vocabulary used by the CTC decoder.
        decoding_method: "ctc_greedy" or "ctc_beam_search".
        lang_model_path: path to the external language model (beam search).
        beam_alpha / beam_beta: LM weight and word-insertion weight.
        beam_size / cutoff_prob / cutoff_top_n / num_processes:
            beam-search pruning and parallelism parameters.

    Returns:
        Decoded transcripts, one per utterance.
    """
    # predict() returns (logits, softmax probs, valid output lengths);
    # only probs and the lengths are needed for decoding.
    _, probs, logits_lens = self.predict(audio, audio_len)
    return self.decode_probs(probs.numpy(), logits_lens, vocab_list,
                             decoding_method, lang_model_path, beam_alpha,
                             beam_beta, beam_size, cutoff_prob,
                             cutoff_top_n, num_processes)
def from_pretrained(self, checkpoint_path): def from_pretrained(self, checkpoint_path):
"""Build a model from a pretrained model. """Build a model from a pretrained model.

@ -114,7 +114,7 @@ def tune(config, args):
return trans return trans
audio, text, audio_len, text_len = infer_data audio, text, audio_len, text_len = infer_data
_, probs, _ = model.predict(audio, audio_len) _, probs, logits_lens = model.predict(audio, audio_len)
target_transcripts = ordid2token(text, text_len) target_transcripts = ordid2token(text, text_len)
num_ins += audio.shape[0] num_ins += audio.shape[0]
@ -122,7 +122,8 @@ def tune(config, args):
for index, (alpha, beta) in enumerate(params_grid): for index, (alpha, beta) in enumerate(params_grid):
print(f"tuneing: alpha={alpha} beta={beta}") print(f"tuneing: alpha={alpha} beta={beta}")
result_transcripts = model.decode_probs( result_transcripts = model.decode_probs(
probs.numpy(), vocab_list, config.decoding.decoding_method, probs.numpy(), logits_lens, vocab_list,
config.decoding.decoding_method,
config.decoding.lang_model_path, alpha, beta, config.decoding.lang_model_path, alpha, beta,
config.decoding.beam_size, config.decoding.cutoff_prob, config.decoding.beam_size, config.decoding.cutoff_prob,
config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch) config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)

Loading…
Cancel
Save