From 3a36c8a69ea50200439794f7cb87a97267044887 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 23:22:08 +0800 Subject: [PATCH] Adapt demo_server to the decoupling in infer_batch() --- deploy/demo_server.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/deploy/demo_server.py b/deploy/demo_server.py index d64f9f015..53be16f77 100644 --- a/deploy/demo_server.py +++ b/deploy/demo_server.py @@ -160,22 +160,30 @@ def start_server(): vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + if args.decoding_method == "ctc_beam_search": + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) # prepare ASR inference handler def file_to_transcript(filename): feature = data_generator.process_utterance(filename, "") - - result_transcript = ds2_model.infer_batch( + probs_split = ds2_model.infer_probs_batch( infer_data=[feature], - decoding_method=args.decoding_method, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - language_model_path=args.lang_model_path, - num_processes=1, feeding_dict=data_generator.feeding) + + if args.decoding_method == "ctc_greedy": + result_transcript = ds2_model.infer_batch_greedy( + probs_split=probs_split, + vocab_list=vocab_list) + else: + result_transcript = ds2_model.infer_batch_beam_search( + probs_split=probs_split, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, + num_processes=1) return result_transcript[0] # warming up with utterrances sampled from Librispeech