From 4b6d4a1c06269c9535cda5ae817e22b99d7ff63c Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 3 Mar 2021 12:05:08 +0000 Subject: [PATCH] refactor socket server new model from pretrain --- .../exps/deepspeech2/bin/deploy/client.py | 13 +- .../exps/deepspeech2/bin/deploy/send.py | 13 +- .../exps/deepspeech2/bin/deploy/server.py | 93 +-------------- deepspeech/exps/deepspeech2/bin/tune.py | 23 +--- deepspeech/exps/deepspeech2/model.py | 13 +- deepspeech/io/collator.py | 5 +- deepspeech/io/dataset.py | 20 ++++ deepspeech/models/deepspeech2.py | 18 +-- deepspeech/utils/socket_server.py | 111 ++++++++++++++++++ examples/aishell/local/export.sh | 20 ++++ 10 files changed, 177 insertions(+), 152 deletions(-) create mode 100644 deepspeech/utils/socket_server.py create mode 100644 examples/aishell/local/export.sh diff --git a/deepspeech/exps/deepspeech2/bin/deploy/client.py b/deepspeech/exps/deepspeech2/bin/deploy/client.py index ae86e97f4..766fdc5a9 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/client.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/client.py @@ -19,6 +19,8 @@ import sys import argparse import pyaudio +from deepspeech.utils.socket_server import socket_send + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--host_ip", @@ -61,16 +63,7 @@ def callback(in_data, frame_count, time_info, status): data_list.append(in_data) enable_trigger_record = False elif len(data_list) > 0: - # Connect to server and send data - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect((args.host_ip, args.host_port)) - sent = ''.join(data_list) - sock.sendall(struct.pack('>i', len(sent)) + sent) - print('Speech[length=%d] Sent.' % len(sent)) - # Receive data from the server and shut down - received = sock.recv(1024) - print("Recognition Results: {}".format(received)) - sock.close() + socket_send(args.host_ip, args.host_port, ''.join(data_list)) data_list = [] enable_trigger_record = True return (in_data, pyaudio.paContinue) diff --git a/deepspeech/exps/deepspeech2/bin/deploy/send.py b/deepspeech/exps/deepspeech2/bin/deploy/send.py index 09a269ceb..84411f91f 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/send.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/send.py @@ -17,6 +17,8 @@ import socket import argparse import wave +from deepspeech.utils.socket_server import socket_send + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--host_ip", @@ -43,16 +45,7 @@ def main(): print(f"Wave sample rate: {wf.getframerate()}") print(f"Wave sample width: {wf.getsampwidth()}") assert isinstance(data, bytes) - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect((args.host_ip, args.host_port)) - sent = data - sock.sendall(struct.pack('>i', len(sent)) + sent) - print('Speech[length=%d] Sent.' 
% len(sent)) - # Receive data from the server and shut down - received = sock.recv(1024) - print("Recognition Results: {}".format(received.decode('utf8'))) - sock.close() + socket_send(args.host_ip, args.host_port, data) if __name__ == "__main__": diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index 6e0e9a603..109beece6 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -14,16 +14,15 @@ """Server-end for the ASR demo.""" import os import time -import random import argparse import functools -from time import gmtime, strftime -import socketserver -import struct -import wave import paddle import numpy as np +from deepspeech.utils.socket_server import warm_up_test +from deepspeech.utils.socket_server import AsrTCPServer +from deepspeech.utils.socket_server import AsrRequestHandler + from deepspeech.training.cli import default_argument_parser from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -34,79 +33,6 @@ from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.io.dataset import ManifestDataset -class AsrTCPServer(socketserver.TCPServer): - """The ASR TCP Server.""" - - def __init__(self, - server_address, - RequestHandlerClass, - speech_save_dir, - audio_process_handler, - bind_and_activate=True): - self.speech_save_dir = speech_save_dir - self.audio_process_handler = audio_process_handler - socketserver.TCPServer.__init__( - self, server_address, RequestHandlerClass, bind_and_activate=True) - - -class AsrRequestHandler(socketserver.BaseRequestHandler): - """The ASR request handler.""" - - def handle(self): - # receive data through TCP socket - chunk = self.request.recv(1024) - target_len = struct.unpack('>i', chunk[:4])[0] - data = chunk[4:] - while len(data) < target_len: - chunk = self.request.recv(1024) - data += chunk - # write to file - filename = self._write_to_file(data) - - print("Received utterance[length=%d] from %s, saved to %s." 
% - (len(data), self.client_address[0], filename)) - start_time = time.time() - transcript = self.server.audio_process_handler(filename) - finish_time = time.time() - print("Response Time: %f, Transcript: %s" % - (finish_time - start_time, transcript)) - self.request.sendall(transcript.encode('utf-8')) - - def _write_to_file(self, data): - # prepare save dir and filename - if not os.path.exists(self.server.speech_save_dir): - os.mkdir(self.server.speech_save_dir) - timestamp = strftime("%Y%m%d%H%M%S", gmtime()) - out_filename = os.path.join( - self.server.speech_save_dir, - timestamp + "_" + self.client_address[0] + ".wav") - # write to wav file - file = wave.open(out_filename, 'wb') - file.setnchannels(1) - file.setsampwidth(2) - file.setframerate(16000) - file.writeframes(data) - file.close() - return out_filename - - -def warm_up_test(audio_process_handler, - manifest_path, - num_test_cases, - random_seed=0): - """Warming-up test.""" - manifest = read_manifest(manifest_path) - rng = random.Random(random_seed) - samples = rng.sample(manifest, num_test_cases) - for idx, sample in enumerate(samples): - print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) - start_time = time.time() - transcript = audio_process_handler(sample['audio_filepath']) - finish_time = time.time() - print("Response Time: %f, Transcript: %s" % - (finish_time - start_time, transcript)) - - def start_server(config, args): """Start the ASR server""" dataset = ManifestDataset( @@ -127,15 +53,8 @@ def start_server(config, args): random_seed=config.data.random_seed, keep_transcription_text=True) - model = DeepSpeech2Model( - feat_size=dataset.feature_size, - dict_size=dataset.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - model.from_pretrained(args.checkpoint_path) + model = DeepSpeech2Model.from_pretrained(dataset, config, + args.checkpoint_path) model.eval() # prepare ASR inference handler diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index a8bcc1855..afd8d646a 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -23,14 +23,14 @@ import logging from paddle.io import DataLoader -from deepspeech.training.cli import default_argument_parser -from deepspeech.utils.error_rate import char_errors, word_errors +from deepspeech.utils import error_rate from deepspeech.utils.utility import add_arguments, print_arguments from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset +from deepspeech.training.cli import default_argument_parser from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -66,20 +66,13 @@ def tune(config, args): drop_last=False, collate_fn=SpeechCollator(is_training=False)) - model = DeepSpeech2Model( - feat_size=valid_loader.dataset.feature_size, - dict_size=valid_loader.dataset.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - model.from_pretrained(args.checkpoint_path) + model = DeepSpeech2Model.from_pretrained(dev_dataset, config, + args.checkpoint_path) model.eval() # decoders only accept string encoded in utf-8 vocab_list = 
valid_loader.dataset.vocab_list
-    errors_func = char_errors if config.decoding.error_rate_type == 'cer' else word_errors
+    errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors
 
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
@@ -168,12 +161,8 @@ def tune(config, args):
     print("finish tuning")
 
 
-def main_sp(config, args):
-    tune(config, args)
-
-
 def main(config, args):
-    main_sp(config, args)
+    tune(config, args)
 
 
 if __name__ == "__main__":
diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 3db144341..020ed1b58 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -282,7 +282,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
 
         for i, batch in enumerate(self.test_loader):
             metrics = self.compute_metrics(*batch)
-
             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
@@ -306,7 +305,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             exit(-1)
 
     def export(self):
-        self.infer_model.from_pretrained(self.args.checkpoint_path)
+        self.infer_model.eval()
         feat_dim = self.test_loader.dataset.feature_size
 
         # static_model = paddle.jit.to_static(
@@ -358,14 +357,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights)
 
-        infer_model = DeepSpeech2InferModel(
-            feat_size=self.test_loader.dataset.feature_size,
-            dict_size=self.test_loader.dataset.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+        infer_model = DeepSpeech2InferModel.from_pretrained(
+            self.test_loader.dataset, config, self.args.checkpoint_path)
 
         self.model = model
         self.infer_model = infer_model
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 9e50e170e..10f838fb2 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -59,9 +59,10 @@ class SpeechCollator():
             # text
             padded_text = np.zeros([max_text_length])
             if self._is_training:
-                padded_text[:len(text)] = text #ids
+                padded_text[:len(text)] = text  # token ids
             else:
-                padded_text[:len(text)] = [ord(t) for t in text]  # string
+                padded_text[:len(text)] = [ord(t)
+                                           for t in text]  # string, unicode ord
             texts.append(padded_text)
             text_lens.append(len(text))
 
diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py
index 3762f0d93..b4c1c7afd 100644
--- a/deepspeech/io/dataset.py
+++ b/deepspeech/io/dataset.py
@@ -53,6 +53,26 @@ class ManifestDataset(Dataset):
                  target_dB=-20,
                  random_seed=0,
                  keep_transcription_text=False):
+        """Manifest Dataset
+
+        Args:
+            manifest_path (str): manifest JSON file path
+            vocab_filepath (str): vocab file path
+            mean_std_filepath (str): mean and std file path (*.npy)
+            augmentation_config (str, optional): augmentation JSON string. Defaults to '{}'.
+            max_duration (float, optional): audio length in seconds must be less than this. Defaults to float('inf').
+            min_duration (float, optional): audio length in seconds must be greater than this. Defaults to 0.0.
+            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
+            window_ms (float, optional): window size in ms. Defaults to 20.0.
+            n_fft (int, optional): fft points for rfft. Defaults to None.
+            max_freq (int, optional): max cutoff freq. Defaults to None.
+            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
+            specgram_type (str, optional): 'linear' or 'mfcc'. Defaults to 'linear'.
+            use_dB_normalization (bool, optional): whether to do dB normalization. Defaults to True.
+            target_dB (int, optional): target dB. Defaults to -20.
+            random_seed (int, optional): seed for the random generator. Defaults to 0.
+            keep_transcription_text (bool, optional): if True (non-training mode), keep the raw transcription text rather than token ids. Defaults to False.
+        """
         super().__init__()
         self._max_duration = max_duration
 
diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py
index 1173c852d..b58260749 100644
--- a/deepspeech/models/deepspeech2.py
+++ b/deepspeech/models/deepspeech2.py
@@ -29,6 +29,7 @@ from deepspeech.modules.rnn import RNNStack
 from deepspeech.modules.mask import sequence_mask
 from deepspeech.modules.activation import brelu
 from deepspeech.utils import checkpoint
+from deepspeech.utils import layer_tools
 from deepspeech.decoders.swig_wrapper import Scorer
 from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
 from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
@@ -378,21 +379,6 @@ class DeepSpeech2Model(nn.Layer):
                                     lang_model_path, beam_alpha, beam_beta, beam_size,
                                     cutoff_prob, cutoff_top_n, num_processes)
 
-    def from_pretrained(self, checkpoint_path):
-        """Build a model from a pretrained model.
-        Parameters
-        ----------
-        checkpoint_path: Path or str
-            The path of pretrained model checkpoint, without extension name.
-
-        Returns
-        -------
-        DeepSpeech2Model
-            The model build from pretrined result.
-        """
-        checkpoint.load_parameters(self, checkpoint_path=checkpoint_path)
-        return self
-
     @classmethod
     def from_pretrained(cls, dataset, config, checkpoint_path):
         """Build a DeepSpeech2Model model from a pretrained model.
@@ -418,7 +404,7 @@
             rnn_size=config.model.rnn_layer_size,
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights)
-        model.from_pretrained(checkpoint_path)
+        checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
         layer_tools.summary(model)
         return model
 
diff --git a/deepspeech/utils/socket_server.py b/deepspeech/utils/socket_server.py
new file mode 100644
index 000000000..2a0a62d01
--- /dev/null
+++ b/deepspeech/utils/socket_server.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+import time
+from time import gmtime, strftime
+import socket
+import socketserver
+import struct
+import wave
+
+from deepspeech.frontend.utility import read_manifest
+
+__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]
+
+
+def socket_send(server_ip: str, server_port: int, data: bytes):
+    # Connect to server and send data
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.connect((server_ip, server_port))
+    sent = data
+    sock.sendall(struct.pack('>i', len(sent)) + sent)
+    print('Speech[length=%d] Sent.' % len(sent))
+    # Receive data from the server and shut down
+    received = sock.recv(1024)
+    print("Recognition Results: {}".format(received.decode('utf8')))
+    sock.close()
+
+
+def warm_up_test(audio_process_handler,
+                 manifest_path,
+                 num_test_cases,
+                 random_seed=0):
+    """Warming-up test."""
+    manifest = read_manifest(manifest_path)
+    rng = random.Random(random_seed)
+    samples = rng.sample(manifest, num_test_cases)
+    for idx, sample in enumerate(samples):
+        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
+        start_time = time.time()
+        transcript = audio_process_handler(sample['audio_filepath'])
+        finish_time = time.time()
+        print("Response Time: %f, Transcript: %s" %
+              (finish_time - start_time, transcript))
+
+
+class AsrTCPServer(socketserver.TCPServer):
+    """The ASR TCP Server."""
+
+    def __init__(self,
+                 server_address,
+                 RequestHandlerClass,
+                 speech_save_dir,
+                 audio_process_handler,
+                 bind_and_activate=True):
+        self.speech_save_dir = speech_save_dir
+        self.audio_process_handler = audio_process_handler
+        socketserver.TCPServer.__init__(
+            self, server_address, RequestHandlerClass, bind_and_activate=True)
+
+
+class AsrRequestHandler(socketserver.BaseRequestHandler):
+    """The ASR request handler."""
+
+    def handle(self):
+        # receive data through TCP socket
+        chunk = self.request.recv(1024)
+        target_len = struct.unpack('>i', chunk[:4])[0]
+        data = chunk[4:]
+        while len(data) < target_len:
+            chunk = self.request.recv(1024)
+            data += chunk
+        # write to file
+        filename = self._write_to_file(data)
+
+        print("Received utterance[length=%d] from %s, saved to %s." %
+              (len(data), self.client_address[0], filename))
+        start_time = time.time()
+        transcript = self.server.audio_process_handler(filename)
+        finish_time = time.time()
+        print("Response Time: %f, Transcript: %s" %
+              (finish_time - start_time, transcript))
+        self.request.sendall(transcript.encode('utf-8'))
+
+    def _write_to_file(self, data):
+        # prepare save dir and filename
+        if not os.path.exists(self.server.speech_save_dir):
+            os.mkdir(self.server.speech_save_dir)
+        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
+        out_filename = os.path.join(
+            self.server.speech_save_dir,
+            timestamp + "_" + self.client_address[0] + ".wav")
+        # write to wav file
+        file = wave.open(out_filename, 'wb')
+        file.setnchannels(1)
+        file.setsampwidth(2)
+        file.setframerate(16000)
+        file.writeframes(data)
+        file.close()
+        return out_filename
diff --git a/examples/aishell/local/export.sh b/examples/aishell/local/export.sh
new file mode 100644
index 000000000..1b5533916
--- /dev/null
+++ b/examples/aishell/local/export.sh
@@ -0,0 +1,20 @@
+#! /usr/bin/env bash
+
+if [ $# != 2 ];then
+    echo "usage: export ckpt_path jit_model_path"
+    exit -1
+fi
+
+python3 -u ${BIN_DIR}/export.py \
+--config conf/deepspeech2.yaml \
+--checkpoint_path ${1} \
+--export_path ${2}
+
+
+if [ $? -ne 0 ]; then
+    echo "Failed in export!"
+    exit 1
+fi
+
+
+exit 0
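
Usage note: the snippet below is a minimal sketch (not taken from this patch) of how the relocated helpers in deepspeech/utils/socket_server.py are meant to be wired together after this refactor; the handler body, save directory, manifest path, host and port are hypothetical placeholders.

    from deepspeech.utils.socket_server import AsrRequestHandler
    from deepspeech.utils.socket_server import AsrTCPServer
    from deepspeech.utils.socket_server import socket_send
    from deepspeech.utils.socket_server import warm_up_test

    def file_to_transcript(filename):
        # placeholder for the real handler that server.py builds around
        # DeepSpeech2Model.from_pretrained(dataset, config, checkpoint_path)
        return "transcript of " + filename

    server = AsrTCPServer(
        server_address=("0.0.0.0", 8086),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir="demo_cache",
        audio_process_handler=file_to_transcript)
    # optional warm-up against a manifest before serving:
    # warm_up_test(file_to_transcript, "data/manifest.dev", num_test_cases=3)
    server.serve_forever()

    # A client sends length-prefixed audio bytes and prints the recognized text:
    # socket_send("127.0.0.1", 8086, wav_bytes)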