|
|
|
@ -661,16 +661,19 @@ class DeepSpeech2(nn.Layer):
|
|
|
|
|
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
|
|
|
|
|
vocab_list)
|
|
|
|
|
|
|
|
|
|
def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path,
|
|
|
|
|
beam_alpha, beam_beta, beam_size, cutoff_prob,
|
|
|
|
|
cutoff_top_n, num_processes):
|
|
|
|
|
""" probs: activation after softmax """
|
|
|
|
|
def decode_probs(self, probs, logits_len, vocab_list, decoding_method,
|
|
|
|
|
lang_model_path, beam_alpha, beam_beta, beam_size,
|
|
|
|
|
cutoff_prob, cutoff_top_n, num_processes):
|
|
|
|
|
""" probs: activation after softmax
|
|
|
|
|
logits_len: audio output lens
|
|
|
|
|
"""
|
|
|
|
|
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_len)]
|
|
|
|
|
if decoding_method == "ctc_greedy":
|
|
|
|
|
result_transcripts = self._decode_batch_greedy(
|
|
|
|
|
probs_split=probs, vocab_list=vocab_list)
|
|
|
|
|
probs_split=probs_split, vocab_list=vocab_list)
|
|
|
|
|
elif decoding_method == "ctc_beam_search":
|
|
|
|
|
result_transcripts = self._decode_batch_beam_search(
|
|
|
|
|
probs_split=probs,
|
|
|
|
|
probs_split=probs_split,
|
|
|
|
|
beam_alpha=beam_alpha,
|
|
|
|
|
beam_beta=beam_beta,
|
|
|
|
|
beam_size=beam_size,
|
|
|
|
@ -686,12 +689,11 @@ class DeepSpeech2(nn.Layer):
|
|
|
|
|
def decode(self, audio, audio_len, vocab_list, decoding_method,
|
|
|
|
|
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
|
|
|
|
|
cutoff_top_n, num_processes):
|
|
|
|
|
_, probs, audio_lens = self.predict(audio, audio_len)
|
|
|
|
|
probs_split = [probs[i, :l, :] for i, l in enumerate(audio_lens)]
|
|
|
|
|
return self.decode_probs(probs_split, vocab_list, decoding_method,
|
|
|
|
|
lang_model_path, beam_alpha, beam_beta,
|
|
|
|
|
beam_size, cutoff_prob, cutoff_top_n,
|
|
|
|
|
num_processes)
|
|
|
|
|
_, probs, logits_lens = self.predict(audio, audio_len)
|
|
|
|
|
return self.decode_probs(probs.numpy(), logits_lens, vocab_list,
|
|
|
|
|
decoding_method, lang_model_path, beam_alpha,
|
|
|
|
|
beam_beta, beam_size, cutoff_prob,
|
|
|
|
|
cutoff_top_n, num_processes)
|
|
|
|
|
|
|
|
|
|
def from_pretrained(self, checkpoint_path):
|
|
|
|
|
"""Build a model from a pretrained model.
|
|
|
|
|