Fix bugs in demo_server.py.

pull/2/head
yangyaming 7 years ago
parent 8dcf1d5667
commit 42c58daf5f

@ -0,0 +1,19 @@
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import sys
def add_path(path):
    """Prepend *path* to ``sys.path`` unless it is already listed there."""
    if path in sys.path:
        return
    sys.path.insert(0, path)
# Directory that holds this file (may be relative to the CWD).
this_dir = os.path.dirname(__file__)

# Add project path to PYTHONPATH: the project root sits one level above
# this directory, so put it on sys.path to make its packages importable.
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)

@ -12,7 +12,7 @@ import paddle.v2 as paddle
 import _init_paths
 from data_utils.data import DataGenerator
 from model_utils.model import DeepSpeech2Model
-from data_utils.utils import read_manifest
+from data_utils.utility import read_manifest
 from utils.utility import add_arguments, print_arguments
 parser = argparse.ArgumentParser(description=__doc__)
@ -26,6 +26,7 @@ add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
 add_arg('alpha',           float,  0.36,   "Coef of LM for beam search.")
 add_arg('beta',            float,  0.25,   "Coef of WC for beam search.")
 add_arg('cutoff_prob',     float,  0.99,   "Cutoff probability for pruning.")
+add_arg('cutoff_top_n',    int,    40,     "Cutoff number for pruning.")
 add_arg('use_gru',         bool,   False,  "Use GRUs instead of simple RNNs.")
 add_arg('use_gpu',         bool,   True,   "Use GPU or not.")
 add_arg('share_rnn_weights',bool,  True,   "Share input-hidden weights across "
@ -156,6 +157,8 @@ def start_server():
         pretrained_model_path=args.model_path,
         share_rnn_weights=args.share_rnn_weights)
+    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
     # prepare ASR inference handler
     def file_to_transcript(filename):
         feature = data_generator.process_utterance(filename, "")
@ -166,7 +169,8 @@ def start_server():
             beam_beta=args.beta,
             beam_size=args.beam_size,
             cutoff_prob=args.cutoff_prob,
-            vocab_list=data_generator.vocab_list,
+            cutoff_top_n=args.cutoff_top_n,
+            vocab_list=vocab_list,
             language_model_path=args.lang_model_path,
             num_processes=1)
         return result_transcript[0]

Loading…
Cancel
Save