From a1b04176ac67a4a6ea94daddd47fbebed4d5de33 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 3 Mar 2021 08:44:16 +0000 Subject: [PATCH] add deploy bin and test it --- .gitignore | 1 + .../exps/deepspeech2/bin/deploy/client.py | 2 +- .../exps/deepspeech2/bin/deploy/record.py | 54 ++++ .../exps/deepspeech2/bin/deploy/send.py | 59 ++++ .../exps/deepspeech2/bin/deploy/server.py | 216 +++++++++++++++ deepspeech/exps/deepspeech2/bin/tune.py | 1 + deepspeech/exps/deepspeech2/model.py | 16 +- deploy/demo_server.py | 252 ------------------ examples/aishell/.gitignore | 2 + .../local/client.sh} | 9 +- examples/aishell/local/infer_golden.sh | 29 +- examples/aishell/local/server.sh | 40 +++ examples/deploy_demo/path.sh | 8 - .../deploy_demo/run_english_demo_server.sh | 54 ---- 14 files changed, 395 insertions(+), 348 deletions(-) rename deploy/demo_client.py => deepspeech/exps/deepspeech2/bin/deploy/client.py (98%) create mode 100644 deepspeech/exps/deepspeech2/bin/deploy/record.py create mode 100644 deepspeech/exps/deepspeech2/bin/deploy/send.py create mode 100644 deepspeech/exps/deepspeech2/bin/deploy/server.py delete mode 100644 deploy/demo_server.py rename examples/{deploy_demo/run_demo_client.sh => aishell/local/client.sh} (60%) create mode 100644 examples/aishell/local/server.sh delete mode 100644 examples/deploy_demo/path.sh delete mode 100644 examples/deploy_demo/run_english_demo_server.sh diff --git a/.gitignore b/.gitignore index dee7e4b33..24b4c4de6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.pyc tools/venv .vscode +*.log diff --git a/deploy/demo_client.py b/deepspeech/exps/deepspeech2/bin/deploy/client.py similarity index 98% rename from deploy/demo_client.py rename to deepspeech/exps/deepspeech2/bin/deploy/client.py index b4aa50e8e..ae86e97f4 100644 --- a/deploy/demo_client.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/client.py @@ -80,7 +80,7 @@ def main(): # prepare audio recorder p = pyaudio.PyAudio() stream = p.open( - format=pyaudio.paInt32, + format=pyaudio.paInt16, channels=1, rate=16000, input=True, diff --git a/deepspeech/exps/deepspeech2/bin/deploy/record.py b/deepspeech/exps/deepspeech2/bin/deploy/record.py new file mode 100644 index 000000000..717747593 --- /dev/null +++ b/deepspeech/exps/deepspeech2/bin/deploy/record.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Record wav from Microphone""" +# http://people.csail.mit.edu/hubert/pyaudio/ +import pyaudio +import wave + +CHUNK = 1024 +FORMAT = pyaudio.paInt16 +CHANNELS = 1 +RATE = 16000 +RECORD_SECONDS = 5 +WAVE_OUTPUT_FILENAME = "output.wav" + +p = pyaudio.PyAudio() + +stream = p.open( + format=FORMAT, + channels=CHANNELS, + rate=RATE, + input=True, + frames_per_buffer=CHUNK) + +print("* recording") + +frames = [] + +for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)): + data = stream.read(CHUNK) + frames.append(data) + +print("* done recording") + +stream.stop_stream() +stream.close() +p.terminate() + +wf = wave.open(WAVE_OUTPUT_FILENAME, 'wb') +wf.setnchannels(CHANNELS) +wf.setsampwidth(p.get_sample_size(FORMAT)) +wf.setframerate(RATE) +wf.writeframes(b''.join(frames)) +wf.close() diff --git a/deepspeech/exps/deepspeech2/bin/deploy/send.py b/deepspeech/exps/deepspeech2/bin/deploy/send.py new file mode 100644 index 000000000..09a269ceb --- /dev/null +++ b/deepspeech/exps/deepspeech2/bin/deploy/send.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Socket client to send wav to ASR server.""" +import struct +import socket +import argparse +import wave + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--host_ip", + default="localhost", + type=str, + help="Server IP address. (default: %(default)s)") +parser.add_argument( + "--host_port", + default=8086, + type=int, + help="Server Port. (default: %(default)s)") +args = parser.parse_args() + +WAVE_OUTPUT_FILENAME = "output.wav" + + +def main(): + wf = wave.open(WAVE_OUTPUT_FILENAME, 'rb') + nframe = wf.getnframes() + data = wf.readframes(nframe) + print(f"Wave: {WAVE_OUTPUT_FILENAME}") + print(f"Wave samples: {nframe}") + print(f"Wave channels: {wf.getnchannels()}") + print(f"Wave sample rate: {wf.getframerate()}") + print(f"Wave sample width: {wf.getsampwidth()}") + assert isinstance(data, bytes) + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((args.host_ip, args.host_port)) + sent = data + sock.sendall(struct.pack('>i', len(sent)) + sent) + print('Speech[length=%d] Sent.' % len(sent)) + # Receive data from the server and shut down + received = sock.recv(1024) + print("Recognition Results: {}".format(received.decode('utf8'))) + sock.close() + + +if __name__ == "__main__": + main() diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py new file mode 100644 index 000000000..6e0e9a603 --- /dev/null +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -0,0 +1,216 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Server-end for the ASR demo."""
+import os
+import time
+import random
+import argparse
+import functools
+from time import gmtime, strftime
+import socketserver
+import struct
+import wave
+import paddle
+import numpy as np
+
+from deepspeech.training.cli import default_argument_parser
+from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.utils.utility import add_arguments, print_arguments
+
+from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.io.dataset import ManifestDataset
+
+
+class AsrTCPServer(socketserver.TCPServer):
+    """The ASR TCP Server."""
+
+    def __init__(self,
+                 server_address,
+                 RequestHandlerClass,
+                 speech_save_dir,
+                 audio_process_handler,
+                 bind_and_activate=True):
+        self.speech_save_dir = speech_save_dir
+        self.audio_process_handler = audio_process_handler
+        socketserver.TCPServer.__init__(
+            self, server_address, RequestHandlerClass, bind_and_activate=bind_and_activate)
+
+
+class AsrRequestHandler(socketserver.BaseRequestHandler):
+    """The ASR request handler."""
+
+    def handle(self):
+        # receive data through TCP socket
+        chunk = self.request.recv(1024)
+        target_len = struct.unpack('>i', chunk[:4])[0]
+        data = chunk[4:]
+        while len(data) < target_len:
+            chunk = self.request.recv(1024)
+            data += chunk
+        # write to file
+        filename = self._write_to_file(data)
+
+        print("Received utterance[length=%d] from %s, saved to %s." %
+              (len(data), self.client_address[0], filename))
+        start_time = time.time()
+        transcript = self.server.audio_process_handler(filename)
+        finish_time = time.time()
+        print("Response Time: %f, Transcript: %s" %
+              (finish_time - start_time, transcript))
+        self.request.sendall(transcript.encode('utf-8'))
+
+    def _write_to_file(self, data):
+        # prepare save dir and filename
+        if not os.path.exists(self.server.speech_save_dir):
+            os.mkdir(self.server.speech_save_dir)
+        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
+        out_filename = os.path.join(
+            self.server.speech_save_dir,
+            timestamp + "_" + self.client_address[0] + ".wav")
+        # write to wav file
+        file = wave.open(out_filename, 'wb')
+        file.setnchannels(1)
+        file.setsampwidth(2)
+        file.setframerate(16000)
+        file.writeframes(data)
+        file.close()
+        return out_filename
+
+
+def warm_up_test(audio_process_handler,
+                 manifest_path,
+                 num_test_cases,
+                 random_seed=0):
+    """Warming-up test."""
+    manifest = read_manifest(manifest_path)
+    rng = random.Random(random_seed)
+    samples = rng.sample(manifest, num_test_cases)
+    for idx, sample in enumerate(samples):
+        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
+        start_time = time.time()
+        transcript = audio_process_handler(sample['audio_filepath'])
+        finish_time = time.time()
+        print("Response Time: %f, Transcript: %s" %
+              (finish_time - start_time, transcript))
+
+
+def start_server(config, args):
+    """Start the ASR server"""
+    dataset = ManifestDataset(
+        config.data.test_manifest,
+        config.data.vocab_filepath,
+        config.data.mean_std_filepath,
+        augmentation_config="{}",
+        max_duration=config.data.max_duration,
+        min_duration=config.data.min_duration,
+        stride_ms=config.data.stride_ms,
+        window_ms=config.data.window_ms,
+        n_fft=config.data.n_fft,
+        max_freq=config.data.max_freq,
+        target_sample_rate=config.data.target_sample_rate,
+        specgram_type=config.data.specgram_type,
+        use_dB_normalization=config.data.use_dB_normalization,
+        target_dB=config.data.target_dB,
+        random_seed=config.data.random_seed,
+        keep_transcription_text=True)
+
+    model = DeepSpeech2Model(
+        feat_size=dataset.feature_size,
+        dict_size=dataset.vocab_size,
+        num_conv_layers=config.model.num_conv_layers,
+        num_rnn_layers=config.model.num_rnn_layers,
+        rnn_size=config.model.rnn_layer_size,
+        use_gru=config.model.use_gru,
+        share_rnn_weights=config.model.share_rnn_weights)
+    model.from_pretrained(args.checkpoint_path)
+    model.eval()
+
+    # prepare ASR inference handler
+    def file_to_transcript(filename):
+        feature = dataset.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32')  # [1, D, T]
+        audio_len = feature[0].shape[1]
+        audio_len = np.array([audio_len]).astype('int64')  # [1]
+
+        result_transcript = model.decode(
+            paddle.to_tensor(audio),
+            paddle.to_tensor(audio_len),
+            vocab_list=dataset.vocab_list,
+            decoding_method=config.decoding.decoding_method,
+            lang_model_path=config.decoding.lang_model_path,
+            beam_alpha=config.decoding.alpha,
+            beam_beta=config.decoding.beta,
+            beam_size=config.decoding.beam_size,
+            cutoff_prob=config.decoding.cutoff_prob,
+            cutoff_top_n=config.decoding.cutoff_top_n,
+            num_processes=config.decoding.num_proc_bsearch)
+        return result_transcript[0]
+
+    # warming up with utterances sampled from the warm-up manifest
+    print('-----------------------------------------------------------')
+    print('Warming up ...')
+    warm_up_test(
+        audio_process_handler=file_to_transcript,
+        manifest_path=args.warmup_manifest,
+        num_test_cases=3)
+    
print('-----------------------------------------------------------') + + # start the server + server = AsrTCPServer( + server_address=(args.host_ip, args.host_port), + RequestHandlerClass=AsrRequestHandler, + speech_save_dir=args.speech_save_dir, + audio_process_handler=file_to_transcript) + print("ASR Server Started.") + server.serve_forever() + + +def main(config, args): + start_server(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + add_arg = functools.partial(add_arguments, argparser=parser) + # yapf: disable + add_arg('host_ip', str, + 'localhost', + "Server's IP address.") + add_arg('host_port', int, 8086, "Server's IP port.") + add_arg('speech_save_dir', str, + 'demo_cache', + "Directory to save demo audios.") + add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.") + args = parser.parse_args() + print_arguments(args) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + + args.warmup_manifest = config.data.test_manifest + print_arguments(args) + + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index eb6bddd9a..a8bcc1855 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -72,6 +72,7 @@ def tune(config, args): num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) model.from_pretrained(args.checkpoint_path) model.eval() diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index eb34c43e5..c414ff130 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -302,7 +302,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): error_rate_type, num_ins, num_ins, errors_sum / len_refs) self.logger.info(msg) + def run_test(self): + self.resume_or_load() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + def export(self): + self.infer_model.from_pretrained(self.args.checkpoint_path) self.infer_model.eval() feat_dim = self.test_loader.dataset.feature_size # static_model = paddle.jit.to_static( @@ -322,15 +330,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): dtype='int64'), # audio_length, [B] ]) - def run_test(self): - self.resume_or_load() - try: - self.test() - except KeyboardInterrupt: - exit(-1) - def run_export(self): - self.resume_or_load() try: self.export() except KeyboardInterrupt: diff --git a/deploy/demo_server.py b/deploy/demo_server.py deleted file mode 100644 index 299b58091..000000000 --- a/deploy/demo_server.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -"""Server-end for the ASR demo.""" -import os -import time -import random -import argparse -import functools -from time import gmtime, strftime -import SocketServer -import struct -import wave -import paddle.fluid as fluid -import numpy as np - -from deepspeech.frontend.utility import read_manifest -from deepspeech.utils.utility import add_arguments, print_arguments - -from deepspeech.exps.deepspeech2.model import DeepSpeech2Model -from deepspeech.exps.deepspeech2.dataset import DataGenerator - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('host_port', int, 8086, "Server's IP port.") -add_arg('beam_size', int, 500, "Beam search width.") -add_arg('num_conv_layers', int, 2, "# of convolution layers.") -add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") -add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") -add_arg('alpha', float, 2.5, "Coef of LM for beam search.") -add_arg('beta', float, 0.3, "Coef of WC for beam search.") -add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") -add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") -add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") -add_arg('use_gpu', bool, True, "Use GPU or not.") -add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " - "bi-directional RNNs. Not for GRU.") -add_arg('host_ip', str, - 'localhost', - "Server's IP address.") -add_arg('speech_save_dir', str, - 'demo_cache', - "Directory to save demo audios.") -add_arg('warmup_manifest', str, - 'data/librispeech/manifest.test-clean', - "Filepath of manifest to warm up.") -add_arg('mean_std_path', str, - 'data/librispeech/mean_std.npz', - "Filepath of normalizer's mean & std.") -add_arg('vocab_path', str, - 'data/librispeech/eng_vocab.txt', - "Filepath of vocabulary.") -add_arg('model_path', str, - './checkpoints/libri/step_final', - "If None, the training starts from scratch, " - "otherwise, it resumes from the pre-trained model.") -add_arg('lang_model_path', str, - 'lm/data/common_crawl_00.prune01111.trie.klm', - "Filepath for language model.") -add_arg('decoding_method', str, - 'ctc_beam_search', - "Decoding method. Options: ctc_beam_search, ctc_greedy", - choices = ['ctc_beam_search', 'ctc_greedy']) -add_arg('specgram_type', str, - 'linear', - "Audio feature type. Options: linear, mfcc.", - choices=['linear', 'mfcc']) -# yapf: disable -args = parser.parse_args() - - -class AsrTCPServer(SocketServer.TCPServer): - """The ASR TCP Server.""" - - def __init__(self, - server_address, - RequestHandlerClass, - speech_save_dir, - audio_process_handler, - bind_and_activate=True): - self.speech_save_dir = speech_save_dir - self.audio_process_handler = audio_process_handler - SocketServer.TCPServer.__init__( - self, server_address, RequestHandlerClass, bind_and_activate=True) - - -class AsrRequestHandler(SocketServer.BaseRequestHandler): - """The ASR request handler.""" - - def handle(self): - # receive data through TCP socket - chunk = self.request.recv(1024) - target_len = struct.unpack('>i', chunk[:4])[0] - data = chunk[4:] - while len(data) < target_len: - chunk = self.request.recv(1024) - data += chunk - # write to file - filename = self._write_to_file(data) - - print("Received utterance[length=%d] from %s, saved to %s." 
% - (len(data), self.client_address[0], filename)) - start_time = time.time() - transcript = self.server.audio_process_handler(filename) - finish_time = time.time() - print("Response Time: %f, Transcript: %s" % - (finish_time - start_time, transcript)) - self.request.sendall(transcript.encode('utf-8')) - - def _write_to_file(self, data): - # prepare save dir and filename - if not os.path.exists(self.server.speech_save_dir): - os.mkdir(self.server.speech_save_dir) - timestamp = strftime("%Y%m%d%H%M%S", gmtime()) - out_filename = os.path.join( - self.server.speech_save_dir, - timestamp + "_" + self.client_address[0] + ".wav") - # write to wav file - file = wave.open(out_filename, 'wb') - file.setnchannels(1) - file.setsampwidth(4) - file.setframerate(16000) - file.writeframes(data) - file.close() - return out_filename - - -def warm_up_test(audio_process_handler, - manifest_path, - num_test_cases, - random_seed=0): - """Warming-up test.""" - manifest = read_manifest(manifest_path) - rng = random.Random(random_seed) - samples = rng.sample(manifest, num_test_cases) - for idx, sample in enumerate(samples): - print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) - start_time = time.time() - transcript = audio_process_handler(sample['audio_filepath']) - finish_time = time.time() - print("Response Time: %f, Transcript: %s" % - (finish_time - start_time, transcript)) - - -def start_server(): - """Start the ASR server""" - # prepare data generator - if args.use_gpu: - place = fluid.CUDAPlace(0) - else: - place = fluid.CPUPlace() - - data_generator = DataGenerator( - vocab_filepath=args.vocab_path, - mean_std_filepath=args.mean_std_path, - augmentation_config='{}', - specgram_type=args.specgram_type, - keep_transcription_text=True, - place = place, - is_training = False) - # prepare ASR model - ds2_model = DeepSpeech2Model( - vocab_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_layer_size=args.rnn_layer_size, - use_gru=args.use_gru, - init_from_pretrained_model=args.model_path, - place=place, - share_rnn_weights=args.share_rnn_weights) - - vocab_list = [chars for chars in data_generator.vocab_list] - - if args.decoding_method == "ctc_beam_search": - ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, - vocab_list) - # prepare ASR inference handler - def file_to_transcript(filename): - feature = data_generator.process_utterance(filename, "") - audio_len = feature[0].shape[1] - mask_shape0 = (feature[0].shape[0] - 1) // 2 + 1 - mask_shape1 = (feature[0].shape[1] - 1) // 3 + 1 - mask_max_len = (audio_len - 1) // 3 + 1 - mask_ones = np.ones((mask_shape0, mask_shape1)) - mask_zeros = np.zeros((mask_shape0, mask_max_len - mask_shape1)) - mask = np.repeat( - np.reshape( - np.concatenate((mask_ones, mask_zeros), axis=1), - (1, mask_shape0, mask_max_len)), - 32, - axis=0) - feature = (np.array([feature[0]]).astype('float32'), - None, - np.array([audio_len]).astype('int64').reshape([-1,1]), - np.array([mask]).astype('float32')) - probs_split = ds2_model.infer_batch_probs( - infer_data=feature, - feeding_dict=data_generator.feeding) - - if args.decoding_method == "ctc_greedy": - result_transcript = ds2_model.decode_batch_greedy( - probs_split=probs_split, - vocab_list=vocab_list) - else: - result_transcript = ds2_model.decode_batch_beam_search( - probs_split=probs_split, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - 
vocab_list=vocab_list, - num_processes=1) - return result_transcript[0] - - # warming up with utterrances sampled from Librispeech - print('-----------------------------------------------------------') - print('Warming up ...') - warm_up_test( - audio_process_handler=file_to_transcript, - manifest_path=args.warmup_manifest, - num_test_cases=3) - print('-----------------------------------------------------------') - - # start the server - server = AsrTCPServer( - server_address=(args.host_ip, args.host_port), - RequestHandlerClass=AsrRequestHandler, - speech_save_dir=args.speech_save_dir, - audio_process_handler=file_to_transcript) - print("ASR Server Started.") - server.serve_forever() - - -def main(): - print_arguments(args) - start_server() - - -if __name__ == "__main__": - main() diff --git a/examples/aishell/.gitignore b/examples/aishell/.gitignore index 44038ca5b..389676a70 100644 --- a/examples/aishell/.gitignore +++ b/examples/aishell/.gitignore @@ -1,2 +1,4 @@ data ckpt* +demo_cache +*.log diff --git a/examples/deploy_demo/run_demo_client.sh b/examples/aishell/local/client.sh similarity index 60% rename from examples/deploy_demo/run_demo_client.sh rename to examples/aishell/local/client.sh index 60581c661..d626ecc75 100644 --- a/examples/deploy_demo/run_demo_client.sh +++ b/examples/aishell/local/client.sh @@ -2,9 +2,13 @@ source path.sh +# run on MacOS +# brew install portaudio +# pip install pyaudio +# pip install keyboard + # start demo client -CUDA_VISIBLE_DEVICES=0 \ -python3 -u ${MAIN_ROOT}/deploy/demo_client.py \ +python3 -u ${BIN_DIR}/deploy/client.py \ --host_ip="localhost" \ --host_port=8086 \ @@ -13,5 +17,4 @@ if [ $? -ne 0 ]; then exit 1 fi - exit 0 diff --git a/examples/aishell/local/infer_golden.sh b/examples/aishell/local/infer_golden.sh index 1727bcbad..3fdcd1b5e 100644 --- a/examples/aishell/local/infer_golden.sh +++ b/examples/aishell/local/infer_golden.sh @@ -14,28 +14,13 @@ fi # infer CUDA_VISIBLE_DEVICES=0 \ -python3 -u ${MAIN_ROOT}/infer.py \ ---num_samples=10 \ ---beam_size=300 \ ---num_proc_bsearch=8 \ ---num_conv_layers=2 \ ---num_rnn_layers=3 \ ---rnn_layer_size=1024 \ ---alpha=2.6 \ ---beta=5.0 \ ---cutoff_prob=0.99 \ ---cutoff_top_n=40 \ ---use_gru=True \ ---use_gpu=False \ ---share_rnn_weights=False \ ---infer_manifest="data/manifest.test" \ ---mean_std_path="data/pretrain/mean_std.npz" \ ---vocab_path="data/pretrain/vocab.txt" \ ---model_path="data/pretrain" \ ---lang_model_path="data/lm/zh_giga.no_cna_cmn.prune01244.klm" \ ---decoding_method="ctc_beam_search" \ ---error_rate_type="cer" \ ---specgram_type="linear" +python3 -u ${BIN_DIR}/infer.py \ +--device 'gpu' \ +--nproc 1 \ +--config conf/deepspeech2.yaml \ +--checkpoint_path data/pretrain/params.pdparams \ +--opts data.mean_std_filepath data/pretrain/mean_std.npz \ +--opts data.vocab_filepath data/pretrain/vocab.txt if [ $? -ne 0 ]; then echo "Failed in inference!" diff --git a/examples/aishell/local/server.sh b/examples/aishell/local/server.sh new file mode 100644 index 000000000..379684075 --- /dev/null +++ b/examples/aishell/local/server.sh @@ -0,0 +1,40 @@ +#! /usr/bin/env bash +# TODO: replace the model with a mandarin model + +if [[ $# != 1 ]];then + echo "usage: server.sh checkpoint_path" + exit -1 +fi + +source path.sh + +# download language model +bash local/download_lm_ch.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +# download well-trained model +bash local/download_model.sh +if [ $? 
-ne 0 ]; then + exit 1 +fi + +# start demo server +CUDA_VISIBLE_DEVICES=0 \ +python3 -u ${BIN_DIR}/deploy/server.py \ +--device 'gpu' \ +--nproc 1 \ +--config conf/deepspeech2.yaml \ +--host_ip="localhost" \ +--host_port=8086 \ +--speech_save_dir="demo_cache" \ +--checkpoint_path ${1} + +if [ $? -ne 0 ]; then + echo "Failed in starting demo server!" + exit 1 +fi + + +exit 0 diff --git a/examples/deploy_demo/path.sh b/examples/deploy_demo/path.sh deleted file mode 100644 index fd1cebba8..000000000 --- a/examples/deploy_demo/path.sh +++ /dev/null @@ -1,8 +0,0 @@ -export MAIN_ROOT=${PWD}/../../ - -export PATH=${MAIN_ROOT}:${PWD}/tools:${PATH} -export LC_ALL=C - -# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 -export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} diff --git a/examples/deploy_demo/run_english_demo_server.sh b/examples/deploy_demo/run_english_demo_server.sh deleted file mode 100644 index ae092dbce..000000000 --- a/examples/deploy_demo/run_english_demo_server.sh +++ /dev/null @@ -1,54 +0,0 @@ -#! /usr/bin/env bash -# TODO: replace the model with a mandarin model - -source path.sh - -# download language model -cd ${MAIN_ROOT}/models/lm > /dev/null -bash download_lm_en.sh -if [ $? -ne 0 ]; then - exit 1 -fi -cd - > /dev/null - - -# download well-trained model -cd ${MAIN_ROOT}/models/baidu_en8k > /dev/null -bash download_model.sh -if [ $? -ne 0 ]; then - exit 1 -fi -cd - > /dev/null - - -# start demo server -CUDA_VISIBLE_DEVICES=0 \ -python3 -u ${MAIN_ROOT}/deploy/demo_server.py \ ---host_ip="localhost" \ ---host_port=8086 \ ---num_conv_layers=2 \ ---num_rnn_layers=3 \ ---rnn_layer_size=1024 \ ---alpha=1.15 \ ---beta=0.15 \ ---cutoff_prob=1.0 \ ---cutoff_top_n=40 \ ---use_gru=True \ ---use_gpu=True \ ---share_rnn_weights=False \ ---speech_save_dir="demo_cache" \ ---warmup_manifest="${MAIN_ROOT}/examples/tiny/data/manifest.test-clean" \ ---mean_std_path="${MAIN_ROOT}/models/baidu_en8k/mean_std.npz" \ ---vocab_path="${MAIN_ROOT}/models/baidu_en8k/vocab.txt" \ ---model_path="${MAIN_ROOT}/models/baidu_en8k" \ ---lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \ ---decoding_method="ctc_beam_search" \ ---specgram_type="linear" - -if [ $? -ne 0 ]; then - echo "Failed in starting demo server!" - exit 1 -fi - - -exit 0
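
---
Usage sketch (reviewer note, ignored by `git am`): with this patch the demo
flow is record -> send -> server. Assuming the aishell example directory, the
deploy dependencies installed (portaudio/pyaudio, per client.sh), and a
pretrained checkpoint such as data/pretrain/params.pdparams (that path is
taken from infer_golden.sh; substitute your own), one plausible run is:

    # terminal A: download the LM and model, then start the TCP server on port 8086
    cd examples/aishell
    bash local/server.sh data/pretrain/params.pdparams

    # terminal B: record 5 s of 16 kHz / 16-bit mono audio to output.wav,
    # then send it (4-byte big-endian length header + raw PCM, as in send.py)
    # and print the transcript returned by the server
    python3 -u ${BIN_DIR}/deploy/record.py
    python3 -u ${BIN_DIR}/deploy/send.py --host_ip="localhost" --host_port=8086

BIN_DIR is assumed to be exported by examples/aishell/path.sh, as in the
client.sh and server.sh scripts above.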