refactor socket server

new model from pretrain
pull/538/head
Hui Zhang 5 years ago
parent ac6a4da2e0
commit 4b6d4a1c06

@@ -19,6 +19,8 @@ import sys
 import argparse
 import pyaudio
+
+from deepspeech.utils.socket_server import socket_send
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
     "--host_ip",
@@ -61,16 +63,7 @@ def callback(in_data, frame_count, time_info, status):
         data_list.append(in_data)
         enable_trigger_record = False
     elif len(data_list) > 0:
-        # Connect to server and send data
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.connect((args.host_ip, args.host_port))
-        sent = ''.join(data_list)
-        sock.sendall(struct.pack('>i', len(sent)) + sent)
-        print('Speech[length=%d] Sent.' % len(sent))
-        # Receive data from the server and shut down
-        received = sock.recv(1024)
-        print("Recognition Results: {}".format(received))
-        sock.close()
+        socket_send(args.host_ip, args.host_port, b''.join(data_list))
         data_list = []
         enable_trigger_record = True
     return (in_data, pyaudio.paContinue)

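For reference, the wire format the helper inherits from the old inline code is a 4-byte big-endian signed length prefix followed by the raw audio bytes. A minimal sketch of packing and unpacking that frame with only the standard library (the payload here is an illustrative stand-in, not data from the patch):

    import struct

    payload = b'\x00\x01' * 8000  # stand-in for raw PCM bytes from the recorder

    # Frame: 4-byte big-endian signed length, then the payload itself.
    framed = struct.pack('>i', len(payload)) + payload

    # Receiver side: read the header first, then exactly that many body bytes.
    (length, ) = struct.unpack('>i', framed[:4])
    assert framed[4:4 + length] == payload
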
@@ -17,6 +17,8 @@ import socket
 import argparse
 import wave
+
+from deepspeech.utils.socket_server import socket_send
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
     "--host_ip",
@@ -43,16 +45,7 @@ def main():
     print(f"Wave sample rate: {wf.getframerate()}")
     print(f"Wave sample width: {wf.getsampwidth()}")
     assert isinstance(data, bytes)
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect((args.host_ip, args.host_port))
-    sent = data
-    sock.sendall(struct.pack('>i', len(sent)) + sent)
-    print('Speech[length=%d] Sent.' % len(sent))
-    # Receive data from the server and shut down
-    received = sock.recv(1024)
-    print("Recognition Results: {}".format(received.decode('utf8')))
-    sock.close()
+    socket_send(args.host_ip, args.host_port, data)
 
 
 if __name__ == "__main__":

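After this change the wave-file client is essentially "read frames, call socket_send". A hedged sketch of that flow (the function and path below are illustrative, not part of the patch; the format comment reflects what AsrRequestHandler writes back out):

    import wave

    from deepspeech.utils.socket_server import socket_send

    def send_wav(path: str, host_ip: str, host_port: int):
        # Read the whole utterance as raw bytes; the server side assumes
        # 16 kHz, 16-bit mono PCM.
        with wave.open(path, 'rb') as wf:
            data = wf.readframes(wf.getnframes())
        socket_send(host_ip, host_port, data)
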
@@ -14,16 +14,15 @@
 """Server-end for the ASR demo."""
 import os
 import time
-import random
 import argparse
 import functools
-from time import gmtime, strftime
-import socketserver
-import struct
-import wave
 
 import paddle
 import numpy as np
 
+from deepspeech.utils.socket_server import warm_up_test
+from deepspeech.utils.socket_server import AsrTCPServer
+from deepspeech.utils.socket_server import AsrRequestHandler
+
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
@@ -34,79 +33,6 @@ from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.io.dataset import ManifestDataset
 
-
-class AsrTCPServer(socketserver.TCPServer):
-    """The ASR TCP Server."""
-
-    def __init__(self,
-                 server_address,
-                 RequestHandlerClass,
-                 speech_save_dir,
-                 audio_process_handler,
-                 bind_and_activate=True):
-        self.speech_save_dir = speech_save_dir
-        self.audio_process_handler = audio_process_handler
-        socketserver.TCPServer.__init__(
-            self, server_address, RequestHandlerClass, bind_and_activate=True)
-
-
-class AsrRequestHandler(socketserver.BaseRequestHandler):
-    """The ASR request handler."""
-
-    def handle(self):
-        # receive data through TCP socket
-        chunk = self.request.recv(1024)
-        target_len = struct.unpack('>i', chunk[:4])[0]
-        data = chunk[4:]
-        while len(data) < target_len:
-            chunk = self.request.recv(1024)
-            data += chunk
-        # write to file
-        filename = self._write_to_file(data)
-
-        print("Received utterance[length=%d] from %s, saved to %s." %
-              (len(data), self.client_address[0], filename))
-        start_time = time.time()
-        transcript = self.server.audio_process_handler(filename)
-        finish_time = time.time()
-        print("Response Time: %f, Transcript: %s" %
-              (finish_time - start_time, transcript))
-        self.request.sendall(transcript.encode('utf-8'))
-
-    def _write_to_file(self, data):
-        # prepare save dir and filename
-        if not os.path.exists(self.server.speech_save_dir):
-            os.mkdir(self.server.speech_save_dir)
-        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
-        out_filename = os.path.join(
-            self.server.speech_save_dir,
-            timestamp + "_" + self.client_address[0] + ".wav")
-        # write to wav file
-        file = wave.open(out_filename, 'wb')
-        file.setnchannels(1)
-        file.setsampwidth(2)
-        file.setframerate(16000)
-        file.writeframes(data)
-        file.close()
-        return out_filename
-
-
-def warm_up_test(audio_process_handler,
-                 manifest_path,
-                 num_test_cases,
-                 random_seed=0):
-    """Warming-up test."""
-    manifest = read_manifest(manifest_path)
-    rng = random.Random(random_seed)
-    samples = rng.sample(manifest, num_test_cases)
-    for idx, sample in enumerate(samples):
-        print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
-        start_time = time.time()
-        transcript = audio_process_handler(sample['audio_filepath'])
-        finish_time = time.time()
-        print("Response Time: %f, Transcript: %s" %
-              (finish_time - start_time, transcript))
-
 
 def start_server(config, args):
     """Start the ASR server"""
     dataset = ManifestDataset(
@@ -127,15 +53,8 @@ def start_server(config, args):
         random_seed=config.data.random_seed,
         keep_transcription_text=True)
-    model = DeepSpeech2Model(
-        feat_size=dataset.feature_size,
-        dict_size=dataset.vocab_size,
-        num_conv_layers=config.model.num_conv_layers,
-        num_rnn_layers=config.model.num_rnn_layers,
-        rnn_size=config.model.rnn_layer_size,
-        use_gru=config.model.use_gru,
-        share_rnn_weights=config.model.share_rnn_weights)
-    model.from_pretrained(args.checkpoint_path)
+    model = DeepSpeech2Model.from_pretrained(dataset, config,
+                                             args.checkpoint_path)
     model.eval()
 
     # prepare ASR inference handler

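With the classes moved out, server.py only wires the shared pieces together. A rough sketch of that wiring under assumed values (the address, paths, and the stub handler below are placeholders, not from the patch):

    from deepspeech.utils.socket_server import warm_up_test
    from deepspeech.utils.socket_server import AsrTCPServer
    from deepspeech.utils.socket_server import AsrRequestHandler

    def audio_process_handler(filename):
        # Placeholder: the real handler runs DeepSpeech2 inference on the wav.
        return "dummy transcript"

    # Optionally exercise the handler on a few manifest samples before serving.
    warm_up_test(audio_process_handler, "data/manifest.dev", num_test_cases=3)

    server = AsrTCPServer(
        server_address=("0.0.0.0", 8086),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir="demo_cache",
        audio_process_handler=audio_process_handler)
    server.serve_forever()
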
@@ -23,14 +23,14 @@ import logging
 from paddle.io import DataLoader
 
-from deepspeech.training.cli import default_argument_parser
-from deepspeech.utils.error_rate import char_errors, word_errors
+from deepspeech.utils import error_rate
 from deepspeech.utils.utility import add_arguments, print_arguments
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
+from deepspeech.training.cli import default_argument_parser
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
@@ -66,20 +66,13 @@ def tune(config, args):
         drop_last=False,
         collate_fn=SpeechCollator(is_training=False))
 
-    model = DeepSpeech2Model(
-        feat_size=valid_loader.dataset.feature_size,
-        dict_size=valid_loader.dataset.vocab_size,
-        num_conv_layers=config.model.num_conv_layers,
-        num_rnn_layers=config.model.num_rnn_layers,
-        rnn_size=config.model.rnn_layer_size,
-        use_gru=config.model.use_gru,
-        share_rnn_weights=config.model.share_rnn_weights)
-    model.from_pretrained(args.checkpoint_path)
+    model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
+                                             args.checkpoint_path)
     model.eval()
 
     # decoders only accept string encoded in utf-8
     vocab_list = valid_loader.dataset.vocab_list
-    errors_func = char_errors if config.decoding.error_rate_type == 'cer' else word_errors
+    errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors
 
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
@@ -168,12 +161,8 @@ def tune(config, args):
     print("finish tuning")
 
 
-def main_sp(config, args):
-    tune(config, args)
-
-
 def main(config, args):
-    main_sp(config, args)
+    tune(config, args)
 
 
 if __name__ == "__main__":

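The tuner sweeps a Cartesian grid of language-model weights alpha and beta. A compressed sketch of that loop under the same names (the ranges are illustrative and the decode-and-score step is deliberately elided):

    import numpy as np

    from deepspeech.utils import error_rate

    # Build the (alpha, beta) grid the same way tune() does.
    cand_alphas = np.linspace(1.0, 3.0, 10)  # args.alpha_from/alpha_to/num_alphas
    cand_betas = np.linspace(0.1, 0.5, 5)    # args.beta_from/beta_to/num_betas
    params_grid = [(a, b) for a in cand_alphas for b in cand_betas]

    errors_func = error_rate.char_errors  # or error_rate.word_errors for WER

    for alpha, beta in params_grid:
        # ... run beam-search decoding with this (alpha, beta), accumulate
        # errors_func(reference, hypothesis) over the dev set, keep the best.
        pass
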
@@ -282,7 +282,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         for i, batch in enumerate(self.test_loader):
             metrics = self.compute_metrics(*batch)
             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
@@ -306,7 +305,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             exit(-1)
 
     def export(self):
-        self.infer_model.from_pretrained(self.args.checkpoint_path)
         self.infer_model.eval()
         feat_dim = self.test_loader.dataset.feature_size
         # static_model = paddle.jit.to_static(
@@ -358,14 +357,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights)
 
-        infer_model = DeepSpeech2InferModel(
-            feat_size=self.test_loader.dataset.feature_size,
-            dict_size=self.test_loader.dataset.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+        infer_model = DeepSpeech2InferModel.from_pretrained(
+            self.test_loader.dataset, config, self.args.checkpoint_path)
 
         self.model = model
         self.infer_model = infer_model

@@ -59,9 +59,10 @@ class SpeechCollator():
             # text
             padded_text = np.zeros([max_text_length])
             if self._is_training:
-                padded_text[:len(text)] = text #ids
+                padded_text[:len(text)] = text  # token ids
             else:
-                padded_text[:len(text)] = [ord(t) for t in text] # string
+                padded_text[:len(text)] = [ord(t)
+                                           for t in text]  # string, unicode ord
             texts.append(padded_text)
             text_lens.append(len(text))

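The non-training branch stores the raw transcript as Unicode code points so it can ride along in the same numeric array as token ids. A small round-trip sketch of that encoding, using only numpy and the builtin ord()/chr() pair:

    import numpy as np

    text = "hello"
    max_text_length = 8

    # Pad with zeros, then store each character's Unicode code point.
    padded_text = np.zeros([max_text_length])
    padded_text[:len(text)] = [ord(t) for t in text]

    # A consumer recovers the string by reversing ord() with chr().
    recovered = ''.join(chr(int(c)) for c in padded_text[:len(text)])
    assert recovered == text
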
@@ -53,6 +53,26 @@ class ManifestDataset(Dataset):
                  target_dB=-20,
                  random_seed=0,
                  keep_transcription_text=False):
+        """Manifest Dataset
+
+        Args:
+            manifest_path (str): manifest json file path
+            vocab_filepath (str): vocab file path
+            mean_std_filepath (str): mean and std file path, whose suffix is *.npy
+            augmentation_config (str, optional): augmentation json string. Defaults to '{}'.
+            max_duration (float, optional): audio length in seconds must be less than this. Defaults to float('inf').
+            min_duration (float, optional): audio length in seconds must be greater than this. Defaults to 0.0.
+            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
+            window_ms (float, optional): window size in ms. Defaults to 20.0.
+            n_fft (int, optional): fft points for rfft. Defaults to None.
+            max_freq (int, optional): max cutoff frequency. Defaults to None.
+            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
+            specgram_type (str, optional): 'linear' or 'mfcc'. Defaults to 'linear'.
+            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
+            target_dB (int, optional): target dB. Defaults to -20.
+            random_seed (int, optional): seed for the random generator. Defaults to 0.
+            keep_transcription_text (bool, optional): if True (non-training mode), keep the raw transcript instead of tokenizing it. Defaults to False.
+        """
         super().__init__()
         self._max_duration = max_duration

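A minimal construction sketch against the docstring above, assuming the three required files exist (the paths are placeholders; every other argument falls back to the documented defaults):

    from deepspeech.io.dataset import ManifestDataset

    dataset = ManifestDataset(
        manifest_path="data/manifest.train",
        vocab_filepath="data/vocab.txt",
        mean_std_filepath="data/mean_std.npy")
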
@@ -29,6 +29,7 @@ from deepspeech.modules.rnn import RNNStack
 from deepspeech.modules.mask import sequence_mask
 from deepspeech.modules.activation import brelu
 from deepspeech.utils import checkpoint
+from deepspeech.utils import layer_tools
 from deepspeech.decoders.swig_wrapper import Scorer
 from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
 from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
@@ -378,21 +379,6 @@ class DeepSpeech2Model(nn.Layer):
             lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
             cutoff_top_n, num_processes)
 
-    def from_pretrained(self, checkpoint_path):
-        """Build a model from a pretrained model.
-
-        Parameters
-        ----------
-        checkpoint_path: Path or str
-            The path of pretrained model checkpoint, without extension name.
-
-        Returns
-        -------
-        DeepSpeech2Model
-            The model build from pretrined result.
-        """
-        checkpoint.load_parameters(self, checkpoint_path=checkpoint_path)
-        return self
-
     @classmethod
     def from_pretrained(cls, dataset, config, checkpoint_path):
         """Build a DeepSpeech2Model model from a pretrained model.
@@ -418,7 +404,7 @@ class DeepSpeech2Model(nn.Layer):
             rnn_size=config.model.rnn_layer_size,
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights)
-        model.from_pretrained(checkpoint_path)
+        checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
         layer_tools.summary(model)
        return model

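Call sites now go through the classmethod, which builds the layer from config, loads the checkpoint, and prints a parameter summary in one step. A usage sketch, assuming dataset and config were built earlier as in the other hunks and that the checkpoint path (given without extension) is a placeholder:

    from deepspeech.models.deepspeech2 import DeepSpeech2Model

    # dataset supplies feature_size/vocab_size, config supplies the model
    # hyperparameters; "checkpoints/step-final" is an illustrative path.
    model = DeepSpeech2Model.from_pretrained(dataset, config,
                                             "checkpoints/step-final")
    model.eval()
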
@@ -0,0 +1,111 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import random
+import socket
+import time
+from time import gmtime, strftime
+import socketserver
+import struct
+import wave
+
+from deepspeech.frontend.utility import read_manifest
+
+__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]
+
+
+def socket_send(server_ip: str, server_port: int, data: bytes):
+    # Connect to server and send data
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.connect((server_ip, server_port))
+    sent = data
+    sock.sendall(struct.pack('>i', len(sent)) + sent)
+    print('Speech[length=%d] Sent.' % len(sent))
+    # Receive data from the server and shut down
+    received = sock.recv(1024)
+    print("Recognition Results: {}".format(received.decode('utf8')))
+    sock.close()
+
+
+def warm_up_test(audio_process_handler,
+                 manifest_path,
+                 num_test_cases,
+                 random_seed=0):
+    """Warming-up test."""
+    manifest = read_manifest(manifest_path)
+    rng = random.Random(random_seed)
+    samples = rng.sample(manifest, num_test_cases)
+    for idx, sample in enumerate(samples):
+        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
+        start_time = time.time()
+        transcript = audio_process_handler(sample['audio_filepath'])
+        finish_time = time.time()
+        print("Response Time: %f, Transcript: %s" %
+              (finish_time - start_time, transcript))
+
+
+class AsrTCPServer(socketserver.TCPServer):
+    """The ASR TCP Server."""
+
+    def __init__(self,
+                 server_address,
+                 RequestHandlerClass,
+                 speech_save_dir,
+                 audio_process_handler,
+                 bind_and_activate=True):
+        self.speech_save_dir = speech_save_dir
+        self.audio_process_handler = audio_process_handler
+        socketserver.TCPServer.__init__(
+            self, server_address, RequestHandlerClass, bind_and_activate=True)
+
+
+class AsrRequestHandler(socketserver.BaseRequestHandler):
+    """The ASR request handler."""
+
+    def handle(self):
+        # receive data through TCP socket
+        chunk = self.request.recv(1024)
+        target_len = struct.unpack('>i', chunk[:4])[0]
+        data = chunk[4:]
+        while len(data) < target_len:
+            chunk = self.request.recv(1024)
+            data += chunk
+        # write to file
+        filename = self._write_to_file(data)
+
+        print("Received utterance[length=%d] from %s, saved to %s." %
+              (len(data), self.client_address[0], filename))
+        start_time = time.time()
+        transcript = self.server.audio_process_handler(filename)
+        finish_time = time.time()
+        print("Response Time: %f, Transcript: %s" %
+              (finish_time - start_time, transcript))
+        self.request.sendall(transcript.encode('utf-8'))
+
+    def _write_to_file(self, data):
+        # prepare save dir and filename
+        if not os.path.exists(self.server.speech_save_dir):
+            os.mkdir(self.server.speech_save_dir)
+        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
+        out_filename = os.path.join(
+            self.server.speech_save_dir,
+            timestamp + "_" + self.client_address[0] + ".wav")
+        # write to wav file
+        file = wave.open(out_filename, 'wb')
+        file.setnchannels(1)
+        file.setsampwidth(2)
+        file.setframerate(16000)
+        file.writeframes(data)
+        file.close()
+        return out_filename

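With client and server halves in one module, a local smoke test is straightforward. A hedged end-to-end sketch (the port, cache dir, and stub handler are made up for the example; the payload is one second of 16 kHz, 16-bit mono silence):

    import threading

    from deepspeech.utils.socket_server import (AsrTCPServer, AsrRequestHandler,
                                                socket_send)

    server = AsrTCPServer(
        server_address=("127.0.0.1", 8086),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir="demo_cache",
        audio_process_handler=lambda wav_path: "stub transcript")
    threading.Thread(target=server.serve_forever, daemon=True).start()

    # Sends the framed payload, then prints the transcript echoed back.
    socket_send("127.0.0.1", 8086, b"\x00\x00" * 16000)
    server.shutdown()
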
@@ -0,0 +1,20 @@
+#! /usr/bin/env bash
+
+if [ $# != 2 ];then
+    echo "usage: export ckpt_path jit_model_path"
+    exit -1
+fi
+
+python3 -u ${BIN_DIR}/export.py \
+--config conf/deepspeech2.yaml \
+--checkpoint_path ${1} \
+--export_path ${2}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in export!"
+    exit 1
+fi
+
+exit 0