add deploy bin and test it

pull/538/head
Hui Zhang 5 years ago
parent 4a3768ad18
commit a1b04176ac

1
.gitignore vendored

@ -2,3 +2,4 @@
*.pyc *.pyc
tools/venv tools/venv
.vscode .vscode
*.log

@ -80,7 +80,7 @@ def main():
# prepare audio recorder # prepare audio recorder
p = pyaudio.PyAudio() p = pyaudio.PyAudio()
stream = p.open( stream = p.open(
format=pyaudio.paInt32, format=pyaudio.paInt16,
channels=1, channels=1,
rate=16000, rate=16000,
input=True, input=True,

@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Record wav from Microphone"""
# http://people.csail.mit.edu/hubert/pyaudio/
import pyaudio
import wave
CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000
RECORD_SECONDS = 5
WAVE_OUTPUT_FILENAME = "output.wav"
p = pyaudio.PyAudio()
stream = p.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=CHUNK)
print("* recording")
frames = []
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
frames.append(data)
print("* done recording")
stream.stop_stream()
stream.close()
p.terminate()
wf = wave.open(WAVE_OUTPUT_FILENAME, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(p.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()

@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Socket client to send wav to ASR server."""
import struct
import socket
import argparse
import wave
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--host_ip",
default="localhost",
type=str,
help="Server IP address. (default: %(default)s)")
parser.add_argument(
"--host_port",
default=8086,
type=int,
help="Server Port. (default: %(default)s)")
args = parser.parse_args()
WAVE_OUTPUT_FILENAME = "output.wav"
def main():
wf = wave.open(WAVE_OUTPUT_FILENAME, 'rb')
nframe = wf.getnframes()
data = wf.readframes(nframe)
print(f"Wave: {WAVE_OUTPUT_FILENAME}")
print(f"Wave samples: {nframe}")
print(f"Wave channels: {wf.getnchannels()}")
print(f"Wave sample rate: {wf.getframerate()}")
print(f"Wave sample width: {wf.getsampwidth()}")
assert isinstance(data, bytes)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((args.host_ip, args.host_port))
sent = data
sock.sendall(struct.pack('>i', len(sent)) + sent)
print('Speech[length=%d] Sent.' % len(sent))
# Receive data from the server and shut down
received = sock.recv(1024)
print("Recognition Results: {}".format(received.decode('utf8')))
sock.close()
if __name__ == "__main__":
main()

@ -0,0 +1,216 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import functools
from time import gmtime, strftime
import socketserver
import struct
import wave
import paddle
import numpy as np
from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.io.dataset import ManifestDataset
class AsrTCPServer(socketserver.TCPServer):
"""The ASR TCP Server."""
def __init__(self,
server_address,
RequestHandlerClass,
speech_save_dir,
audio_process_handler,
bind_and_activate=True):
self.speech_save_dir = speech_save_dir
self.audio_process_handler = audio_process_handler
socketserver.TCPServer.__init__(
self, server_address, RequestHandlerClass, bind_and_activate=True)
class AsrRequestHandler(socketserver.BaseRequestHandler):
"""The ASR request handler."""
def handle(self):
# receive data through TCP socket
chunk = self.request.recv(1024)
target_len = struct.unpack('>i', chunk[:4])[0]
data = chunk[4:]
while len(data) < target_len:
chunk = self.request.recv(1024)
data += chunk
# write to file
filename = self._write_to_file(data)
print("Received utterance[length=%d] from %s, saved to %s." %
(len(data), self.client_address[0], filename))
start_time = time.time()
transcript = self.server.audio_process_handler(filename)
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
self.request.sendall(transcript.encode('utf-8'))
def _write_to_file(self, data):
# prepare save dir and filename
if not os.path.exists(self.server.speech_save_dir):
os.mkdir(self.server.speech_save_dir)
timestamp = strftime("%Y%m%d%H%M%S", gmtime())
out_filename = os.path.join(
self.server.speech_save_dir,
timestamp + "_" + self.client_address[0] + ".wav")
# write to wav file
file = wave.open(out_filename, 'wb')
file.setnchannels(1)
file.setsampwidth(2)
file.setframerate(16000)
file.writeframes(data)
file.close()
return out_filename
def warm_up_test(audio_process_handler,
manifest_path,
num_test_cases,
random_seed=0):
"""Warming-up test."""
manifest = read_manifest(manifest_path)
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.vocab_filepath,
config.data.mean_std_filepath,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
model = DeepSpeech2Model(
feat_size=dataset.feature_size,
dict_size=dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
model.from_pretrained(args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
beam_beta=config.decoding.beta,
beam_size=config.decoding.beam_size,
cutoff_prob=config.decoding.cutoff_prob,
cutoff_top_n=config.decoding.cutoff_top_n,
num_processes=config.decoding.num_proc_bsearch)
return result_transcript[0]
# warming up with utterrances sampled from Librispeech
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest,
num_test_cases=3)
print('-----------------------------------------------------------')
# start the server
server = AsrTCPServer(
server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler,
speech_save_dir=args.speech_save_dir,
audio_process_handler=file_to_transcript)
print("ASR Server Started.")
server.serve_forever()
def main(config, args):
start_server(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.")
args = parser.parse_args()
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
args.warmup_manifest = config.data.test_manifest
print_arguments(args)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)

@ -72,6 +72,7 @@ def tune(config, args):
num_conv_layers=config.model.num_conv_layers, num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights)
model.from_pretrained(args.checkpoint_path) model.from_pretrained(args.checkpoint_path)
model.eval() model.eval()

@ -302,7 +302,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
self.logger.info(msg) self.logger.info(msg)
def run_test(self):
self.resume_or_load()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def export(self): def export(self):
self.infer_model.from_pretrained(self.args.checkpoint_path)
self.infer_model.eval() self.infer_model.eval()
feat_dim = self.test_loader.dataset.feature_size feat_dim = self.test_loader.dataset.feature_size
# static_model = paddle.jit.to_static( # static_model = paddle.jit.to_static(
@ -322,15 +330,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
dtype='int64'), # audio_length, [B] dtype='int64'), # audio_length, [B]
]) ])
def run_test(self):
self.resume_or_load()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def run_export(self): def run_export(self):
self.resume_or_load()
try: try:
self.export() self.export()
except KeyboardInterrupt: except KeyboardInterrupt:

@ -1,252 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import functools
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.fluid as fluid
import numpy as np
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.exps.deepspeech2.model import DeepSpeech2Model
from deepspeech.exps.deepspeech2.dataset import DataGenerator
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
"bi-directional RNNs. Not for GRU.")
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
add_arg('warmup_manifest', str,
'data/librispeech/manifest.test-clean',
"Filepath of manifest to warm up.")
add_arg('mean_std_path', str,
'data/librispeech/mean_std.npz',
"Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
'data/librispeech/eng_vocab.txt',
"Filepath of vocabulary.")
add_arg('model_path', str,
'./checkpoints/libri/step_final',
"If None, the training starts from scratch, "
"otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path', str,
'lm/data/common_crawl_00.prune01111.trie.klm',
"Filepath for language model.")
add_arg('decoding_method', str,
'ctc_beam_search',
"Decoding method. Options: ctc_beam_search, ctc_greedy",
choices = ['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type', str,
'linear',
"Audio feature type. Options: linear, mfcc.",
choices=['linear', 'mfcc'])
# yapf: disable
args = parser.parse_args()
class AsrTCPServer(SocketServer.TCPServer):
"""The ASR TCP Server."""
def __init__(self,
server_address,
RequestHandlerClass,
speech_save_dir,
audio_process_handler,
bind_and_activate=True):
self.speech_save_dir = speech_save_dir
self.audio_process_handler = audio_process_handler
SocketServer.TCPServer.__init__(
self, server_address, RequestHandlerClass, bind_and_activate=True)
class AsrRequestHandler(SocketServer.BaseRequestHandler):
"""The ASR request handler."""
def handle(self):
# receive data through TCP socket
chunk = self.request.recv(1024)
target_len = struct.unpack('>i', chunk[:4])[0]
data = chunk[4:]
while len(data) < target_len:
chunk = self.request.recv(1024)
data += chunk
# write to file
filename = self._write_to_file(data)
print("Received utterance[length=%d] from %s, saved to %s." %
(len(data), self.client_address[0], filename))
start_time = time.time()
transcript = self.server.audio_process_handler(filename)
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
self.request.sendall(transcript.encode('utf-8'))
def _write_to_file(self, data):
# prepare save dir and filename
if not os.path.exists(self.server.speech_save_dir):
os.mkdir(self.server.speech_save_dir)
timestamp = strftime("%Y%m%d%H%M%S", gmtime())
out_filename = os.path.join(
self.server.speech_save_dir,
timestamp + "_" + self.client_address[0] + ".wav")
# write to wav file
file = wave.open(out_filename, 'wb')
file.setnchannels(1)
file.setsampwidth(4)
file.setframerate(16000)
file.writeframes(data)
file.close()
return out_filename
def warm_up_test(audio_process_handler,
manifest_path,
num_test_cases,
random_seed=0):
"""Warming-up test."""
manifest = read_manifest(manifest_path)
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
def start_server():
"""Start the ASR server"""
# prepare data generator
if args.use_gpu:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
data_generator = DataGenerator(
vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
keep_transcription_text=True,
place = place,
is_training = False)
# prepare ASR model
ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
use_gru=args.use_gru,
init_from_pretrained_model=args.model_path,
place=place,
share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars for chars in data_generator.vocab_list]
if args.decoding_method == "ctc_beam_search":
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
vocab_list)
# prepare ASR inference handler
def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "")
audio_len = feature[0].shape[1]
mask_shape0 = (feature[0].shape[0] - 1) // 2 + 1
mask_shape1 = (feature[0].shape[1] - 1) // 3 + 1
mask_max_len = (audio_len - 1) // 3 + 1
mask_ones = np.ones((mask_shape0, mask_shape1))
mask_zeros = np.zeros((mask_shape0, mask_max_len - mask_shape1))
mask = np.repeat(
np.reshape(
np.concatenate((mask_ones, mask_zeros), axis=1),
(1, mask_shape0, mask_max_len)),
32,
axis=0)
feature = (np.array([feature[0]]).astype('float32'),
None,
np.array([audio_len]).astype('int64').reshape([-1,1]),
np.array([mask]).astype('float32'))
probs_split = ds2_model.infer_batch_probs(
infer_data=feature,
feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
result_transcript = ds2_model.decode_batch_greedy(
probs_split=probs_split,
vocab_list=vocab_list)
else:
result_transcript = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
beam_alpha=args.alpha,
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
num_processes=1)
return result_transcript[0]
# warming up with utterrances sampled from Librispeech
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest,
num_test_cases=3)
print('-----------------------------------------------------------')
# start the server
server = AsrTCPServer(
server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler,
speech_save_dir=args.speech_save_dir,
audio_process_handler=file_to_transcript)
print("ASR Server Started.")
server.serve_forever()
def main():
print_arguments(args)
start_server()
if __name__ == "__main__":
main()

@ -1,2 +1,4 @@
data data
ckpt* ckpt*
demo_cache
*.log

@ -2,9 +2,13 @@
source path.sh source path.sh
# run on MacOS
# brew install portaudio
# pip install pyaudio
# pip install keyboard
# start demo client # start demo client
CUDA_VISIBLE_DEVICES=0 \ python3 -u ${BIN_DIR}/deploy/client.py \
python3 -u ${MAIN_ROOT}/deploy/demo_client.py \
--host_ip="localhost" \ --host_ip="localhost" \
--host_port=8086 \ --host_port=8086 \
@ -13,5 +17,4 @@ if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
exit 0 exit 0

@ -14,28 +14,13 @@ fi
# infer # infer
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
python3 -u ${MAIN_ROOT}/infer.py \ python3 -u ${BIN_DIR}/infer.py \
--num_samples=10 \ --device 'gpu' \
--beam_size=300 \ --nproc 1 \
--num_proc_bsearch=8 \ --config conf/deepspeech2.yaml \
--num_conv_layers=2 \ --checkpoint_path data/pretrain/params.pdparams \
--num_rnn_layers=3 \ --opts data.mean_std_filepath data/pretrain/mean_std.npz \
--rnn_layer_size=1024 \ --opts data.vocab_filepath data/pretrain/vocab.txt
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=False \
--share_rnn_weights=False \
--infer_manifest="data/manifest.test" \
--mean_std_path="data/pretrain/mean_std.npz" \
--vocab_path="data/pretrain/vocab.txt" \
--model_path="data/pretrain" \
--lang_model_path="data/lm/zh_giga.no_cna_cmn.prune01244.klm" \
--decoding_method="ctc_beam_search" \
--error_rate_type="cer" \
--specgram_type="linear"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in inference!" echo "Failed in inference!"

@ -0,0 +1,40 @@
#! /usr/bin/env bash
# TODO: replace the model with a mandarin model
if [[ $# != 1 ]];then
echo "usage: server.sh checkpoint_path"
exit -1
fi
source path.sh
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
# download well-trained model
bash local/download_model.sh
if [ $? -ne 0 ]; then
exit 1
fi
# start demo server
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/deploy/server.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--host_ip="localhost" \
--host_port=8086 \
--speech_save_dir="demo_cache" \
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in starting demo server!"
exit 1
fi
exit 0

@ -1,8 +0,0 @@
export MAIN_ROOT=${PWD}/../../
export PATH=${MAIN_ROOT}:${PWD}/tools:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -1,54 +0,0 @@
#! /usr/bin/env bash
# TODO: replace the model with a mandarin model
source path.sh
# download language model
cd ${MAIN_ROOT}/models/lm > /dev/null
bash download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
cd - > /dev/null
# download well-trained model
cd ${MAIN_ROOT}/models/baidu_en8k > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
exit 1
fi
cd - > /dev/null
# start demo server
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${MAIN_ROOT}/deploy/demo_server.py \
--host_ip="localhost" \
--host_port=8086 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.15 \
--beta=0.15 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--speech_save_dir="demo_cache" \
--warmup_manifest="${MAIN_ROOT}/examples/tiny/data/manifest.test-clean" \
--mean_std_path="${MAIN_ROOT}/models/baidu_en8k/mean_std.npz" \
--vocab_path="${MAIN_ROOT}/models/baidu_en8k/vocab.txt" \
--model_path="${MAIN_ROOT}/models/baidu_en8k" \
--lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \
--decoding_method="ctc_beam_search" \
--specgram_type="linear"
if [ $? -ne 0 ]; then
echo "Failed in starting demo server!"
exit 1
fi
exit 0
Loading…
Cancel
Save