Adapt demo_server to the decoupling in infer_batch()

pull/122/head
Yibing Liu 7 years ago
parent 10d3370970
commit 3a36c8a69e

@ -160,22 +160,30 @@ def start_server():
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
if args.decoding_method == "ctc_beam_search":
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
vocab_list)
# prepare ASR inference handler # prepare ASR inference handler
def file_to_transcript(filename): def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "") feature = data_generator.process_utterance(filename, "")
probs_split = ds2_model.infer_probs_batch(
result_transcript = ds2_model.infer_batch(
infer_data=[feature], infer_data=[feature],
decoding_method=args.decoding_method, feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
result_transcript = ds2_model.infer_batch_greedy(
probs_split=probs_split,
vocab_list=vocab_list)
else:
result_transcript = ds2_model.infer_batch_beam_search(
probs_split=probs_split,
beam_alpha=args.alpha, beam_alpha=args.alpha,
beam_beta=args.beta, beam_beta=args.beta,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n, cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list, vocab_list=vocab_list,
language_model_path=args.lang_model_path, num_processes=1)
num_processes=1,
feeding_dict=data_generator.feeding)
return result_transcript[0] return result_transcript[0]
# warming up with utterrances sampled from Librispeech # warming up with utterrances sampled from Librispeech

Loading…
Cancel
Save