diff --git a/deploy/demo_server.py b/deploy/demo_server.py
index eca13dcea..1cafb7a58 100644
--- a/deploy/demo_server.py
+++ b/deploy/demo_server.py
@@ -166,7 +166,7 @@ def start_server():
     # prepare ASR inference handler
     def file_to_transcript(filename):
         feature = data_generator.process_utterance(filename, "")
-        probs_split = ds2_model.infer_probs_batch(
+        probs_split = ds2_model.infer_batch_probs(
             infer_data=[feature],
             feeding_dict=data_generator.feeding)
 
diff --git a/infer.py b/infer.py
index 4a5f8cb05..f4d75685b 100644
--- a/infer.py
+++ b/infer.py
@@ -92,7 +92,7 @@ def infer():
 
     if args.decoding_method == "ctc_greedy":
         ds2_model.logger.info("start inference ...")
-        probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
+        probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
                                                   feeding_dict=data_generator.feeding)
         result_transcripts = ds2_model.decode_batch_greedy(
             probs_split=probs_split,
@@ -101,7 +101,7 @@ def infer():
         ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                   vocab_list)
         ds2_model.logger.info("start inference ...")
-        probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
+        probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
                                                   feeding_dict=data_generator.feeding)
         result_transcripts = ds2_model.decode_batch_beam_search(
             probs_split=probs_split,
diff --git a/model_utils/model.py b/model_utils/model.py
index a8283fae4..4b3764bf2 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -173,7 +173,7 @@ class DeepSpeech2Model(object):
         # run inference
         return self._loss_inferer.infer(input=infer_data)
 
-    def infer_probs_batch(self, infer_data, feeding_dict):
+    def infer_batch_probs(self, infer_data, feeding_dict):
         """Infer the prob matrices for a batch of speech utterances.
 
         :param infer_data: List of utterances to infer, with each utterance
diff --git a/test.py b/test.py
index a82893c03..e5a3346a0 100644
--- a/test.py
+++ b/test.py
@@ -97,7 +97,7 @@ def evaluate():
     errors_sum, len_refs, num_ins = 0.0, 0, 0
     ds2_model.logger.info("start evaluation ...")
     for infer_data in batch_reader():
-        probs_split = ds2_model.infer_probs_batch(
+        probs_split = ds2_model.infer_batch_probs(
             infer_data=infer_data,
             feeding_dict=data_generator.feeding)
 
diff --git a/tools/tune.py b/tools/tune.py
index d8e28c58a..da785189f 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -120,7 +120,7 @@ def tune():
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
             break
-        probs_split = ds2_model.infer_probs_batch(
+        probs_split = ds2_model.infer_batch_probs(
             infer_data=infer_data,
             feeding_dict=data_generator.feeding)
         target_transcripts = [ data[1] for data in infer_data ]
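
Note on the rename: only the method name changes from infer_probs_batch to infer_batch_probs; the signature and return value are untouched, so every call site above is a one-token edit. A minimal usage sketch of the renamed entry point, mirroring the updated call in deploy/demo_server.py (the filename variable is hypothetical, and passing vocab_list to decode_batch_greedy is an assumption not shown in the hunks above):

    # Featurize one utterance, infer its probability matrix with the
    # renamed method, then greedy-decode the single-element batch.
    feature = data_generator.process_utterance(filename, "")
    probs_split = ds2_model.infer_batch_probs(
        infer_data=[feature],
        feeding_dict=data_generator.feeding)
    transcript = ds2_model.decode_batch_greedy(
        probs_split=probs_split,
        vocab_list=vocab_list)[0]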