From 42c58daf5fa80c35d9aa5cd1ecee0ba157629fff Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 28 Sep 2017 11:03:47 +0800 Subject: [PATCH] Fix bugs for demo_server.py. --- deploy/_init_paths.py | 19 +++++++++++++++++++ deploy/demo_server.py | 8 ++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 deploy/_init_paths.py diff --git a/deploy/_init_paths.py b/deploy/_init_paths.py new file mode 100644 index 00000000..ddabb535 --- /dev/null +++ b/deploy/_init_paths.py @@ -0,0 +1,19 @@ +"""Set up paths for DS2""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import sys + + +def add_path(path): + if path not in sys.path: + sys.path.insert(0, path) + + +this_dir = os.path.dirname(__file__) + +# Add project path to PYTHONPATH +proj_path = os.path.join(this_dir, '..') +add_path(proj_path) diff --git a/deploy/demo_server.py b/deploy/demo_server.py index 7c558419..e3cc6705 100644 --- a/deploy/demo_server.py +++ b/deploy/demo_server.py @@ -12,7 +12,7 @@ import paddle.v2 as paddle import _init_paths from data_utils.data import DataGenerator from model_utils.model import DeepSpeech2Model -from data_utils.utils import read_manifest +from data_utils.utility import read_manifest from utils.utility import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) @@ -26,6 +26,7 @@ add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") add_arg('alpha', float, 0.36, "Coef of LM for beam search.") add_arg('beta', float, 0.25, "Coef of WC for beam search.") add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") +add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") add_arg('use_gpu', bool, True, "Use GPU or not.") add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " @@ -156,6 +157,8 @@ def start_server(): pretrained_model_path=args.model_path, share_rnn_weights=args.share_rnn_weights) + vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + # prepare ASR inference handler def file_to_transcript(filename): feature = data_generator.process_utterance(filename, "") @@ -166,7 +169,8 @@ def start_server(): beam_beta=args.beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, - vocab_list=data_generator.vocab_list, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, language_model_path=args.lang_model_path, num_processes=1) return result_transcript[0]