diff --git a/deploy/demo_server.py b/deploy/demo_server.py
index 53be16f7..eca13dce 100644
--- a/deploy/demo_server.py
+++ b/deploy/demo_server.py
@@ -171,11 +171,11 @@ def start_server():
             feeding_dict=data_generator.feeding)
 
         if args.decoding_method == "ctc_greedy":
-            result_transcript = ds2_model.infer_batch_greedy(
+            result_transcript = ds2_model.decode_batch_greedy(
                 probs_split=probs_split,
                 vocab_list=vocab_list)
         else:
-            result_transcript = ds2_model.infer_batch_beam_search(
+            result_transcript = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=args.alpha,
                 beam_beta=args.beta,
diff --git a/infer.py b/infer.py
index 5dd9b406..ff45a5dc 100644
--- a/infer.py
+++ b/infer.py
@@ -98,11 +98,11 @@ def infer():
     probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
                                               feeding_dict=data_generator.feeding)
     if args.decoding_method == "ctc_greedy":
-        result_transcripts = ds2_model.infer_batch_greedy(
+        result_transcripts = ds2_model.decode_batch_greedy(
             probs_split=probs_split,
             vocab_list=vocab_list)
     else:
-        result_transcripts = ds2_model.infer_batch_beam_search(
+        result_transcripts = ds2_model.decode_batch_beam_search(
             probs_split=probs_split,
             beam_alpha=args.alpha,
             beam_beta=args.beta,
diff --git a/model_utils/model.py b/model_utils/model.py
index 70ba7bb9..a8283fae 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -205,8 +205,9 @@ class DeepSpeech2Model(object):
         ]
         return probs_split
 
-    def infer_batch_greedy(self, probs_split, vocab_list):
-        """
+    def decode_batch_greedy(self, probs_split, vocab_list):
+        """Decode by best path for a batch of probs matrix input.
+
         :param probs_split: List of 2-D probability matrix, and each consists
                             of prob vectors for one speech utterance.
         :type probs_split: List of matrix
@@ -256,11 +257,10 @@ class DeepSpeech2Model(object):
             self.logger.info("no language model provided, "
                              "decoding by pure beam search without scorer.")
 
-    def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
-                                beam_size, cutoff_prob, cutoff_top_n,
-                                vocab_list, num_processes):
-        """Model inference. Infer the transcription for a batch of speech
-        utterances.
+    def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                 beam_size, cutoff_prob, cutoff_top_n,
+                                 vocab_list, num_processes):
+        """Decode by beam search for a batch of probs matrix input.
 
         :param probs_split: List of 2-D probability matrix, and each consists
                             of prob vectors for one speech utterance.
diff --git a/test.py b/test.py
index 24ce54a2..a82893c0 100644
--- a/test.py
+++ b/test.py
@@ -102,11 +102,11 @@ def evaluate():
             feeding_dict=data_generator.feeding)
 
         if args.decoding_method == "ctc_greedy":
-            result_transcripts = ds2_model.infer_batch_greedy(
+            result_transcripts = ds2_model.decode_batch_greedy(
                 probs_split=probs_split,
                 vocab_list=vocab_list)
         else:
-            result_transcripts = ds2_model.infer_batch_beam_search(
+            result_transcripts = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=args.alpha,
                 beam_beta=args.beta,
diff --git a/tools/tune.py b/tools/tune.py
index 923e6c3c..d8e28c58 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -128,7 +128,7 @@ def tune():
         num_ins += len(target_transcripts)
         # grid search
         for index, (alpha, beta) in enumerate(params_grid):
-            result_transcripts = ds2_model.infer_batch_beam_search(
+            result_transcripts = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=alpha,
                 beam_beta=beta,
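
For reference, the caller-side pattern after the rename looks like the sketch below, mirroring the call sites touched in this patch. It assumes a constructed `ds2_model` (a `DeepSpeech2Model`), plus `data_generator`, `infer_data`, `vocab_list`, and an `args` namespace set up as in infer.py; the keyword parameters come from the `decode_batch_beam_search` signature above, but the `args` attribute names beyond `args.beta` (e.g. `args.beam_size`) are illustrative assumptions, since the hunks are truncated before those arguments.

# Sketch only: ds2_model, data_generator, infer_data, vocab_list, and args
# are assumed to exist as in infer.py; args.beam_size, args.cutoff_prob,
# args.cutoff_top_n, and args.num_processes are illustrative names.
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
                                          feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
    # best-path decoding: pick the most likely token per frame, then
    # collapse repeats and remove blanks
    result_transcripts = ds2_model.decode_batch_greedy(
        probs_split=probs_split,
        vocab_list=vocab_list)
else:
    # prefix beam search, optionally rescored by an external language model
    result_transcripts = ds2_model.decode_batch_beam_search(
        probs_split=probs_split,
        beam_alpha=args.alpha,      # language model weight
        beam_beta=args.beta,        # word insertion weight
        beam_size=args.beam_size,
        cutoff_prob=args.cutoff_prob,
        cutoff_top_n=args.cutoff_top_n,
        vocab_list=vocab_list,
        num_processes=args.num_processes)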