Rename prefix 'infer_batch' to 'decode_batch'

pull/122/head
Yibing Liu 7 years ago
parent 66a3908818
commit 6c2cf40ce1

@ -171,11 +171,11 @@ def start_server():
feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
result_transcript = ds2_model.infer_batch_greedy(
result_transcript = ds2_model.decode_batch_greedy(
probs_split=probs_split,
vocab_list=vocab_list)
else:
result_transcript = ds2_model.infer_batch_beam_search(
result_transcript = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
beam_alpha=args.alpha,
beam_beta=args.beta,

@ -98,11 +98,11 @@ def infer():
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
result_transcripts = ds2_model.infer_batch_greedy(
result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split,
vocab_list=vocab_list)
else:
result_transcripts = ds2_model.infer_batch_beam_search(
result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
beam_alpha=args.alpha,
beam_beta=args.beta,

@ -205,8 +205,9 @@ class DeepSpeech2Model(object):
]
return probs_split
def infer_batch_greedy(self, probs_split, vocab_list):
"""
def decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce.
:param probs_split: List of matrix
@ -256,11 +257,10 @@ class DeepSpeech2Model(object):
self.logger.info("no language model provided, "
"decoding by pure beam search without scorer.")
def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes):
"""Model inference. Infer the transcription for a batch of speech
utterances.
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes):
"""Decode by beam search for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce.

@ -102,11 +102,11 @@ def evaluate():
feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
result_transcripts = ds2_model.infer_batch_greedy(
result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split,
vocab_list=vocab_list)
else:
result_transcripts = ds2_model.infer_batch_beam_search(
result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
beam_alpha=args.alpha,
beam_beta=args.beta,

@ -128,7 +128,7 @@ def tune():
num_ins += len(target_transcripts)
# grid search
for index, (alpha, beta) in enumerate(params_grid):
result_transcripts = ds2_model.infer_batch_beam_search(
result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
beam_alpha=alpha,
beam_beta=beta,

Loading…
Cancel
Save