diff --git a/infer.py b/infer.py index ff45a5dc8..4a5f8cb05 100644 --- a/infer.py +++ b/infer.py @@ -90,18 +90,19 @@ def infer(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] - if args.decoding_method == "ctc_beam_search": - ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, - vocab_list) - - ds2_model.logger.info("start inference ...") - probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, - feeding_dict=data_generator.feeding) if args.decoding_method == "ctc_greedy": + ds2_model.logger.info("start inference ...") + probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + feeding_dict=data_generator.feeding) result_transcripts = ds2_model.decode_batch_greedy( probs_split=probs_split, vocab_list=vocab_list) else: + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) + ds2_model.logger.info("start inference ...") + probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + feeding_dict=data_generator.feeding) result_transcripts = ds2_model.decode_batch_beam_search( probs_split=probs_split, beam_alpha=args.alpha,