diff --git a/infer.py b/infer.py index 1539fbaaf..5dd9b406d 100644 --- a/infer.py +++ b/infer.py @@ -90,17 +90,18 @@ def infer(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + if args.decoding_method == "ctc_beam_search": + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) + + ds2_model.logger.info("start inference ...") probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, feeding_dict=data_generator.feeding) if args.decoding_method == "ctc_greedy": - ds2_model.logger.info("start inference ...") result_transcripts = ds2_model.infer_batch_greedy( probs_split=probs_split, vocab_list=vocab_list) else: - ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, - vocab_list) - ds2_model.logger.info("start inference ...") result_transcripts = ds2_model.infer_batch_beam_search( probs_split=probs_split, beam_alpha=args.alpha,