parent 4a3768ad18
commit a1b04176ac
@@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Record wav from Microphone"""
# http://people.csail.mit.edu/hubert/pyaudio/
import pyaudio
import wave

CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000
RECORD_SECONDS = 5
WAVE_OUTPUT_FILENAME = "output.wav"

p = pyaudio.PyAudio()

stream = p.open(
    format=FORMAT,
    channels=CHANNELS,
    rate=RATE,
    input=True,
    frames_per_buffer=CHUNK)

print("* recording")

frames = []

for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
    data = stream.read(CHUNK)
    frames.append(data)

print("* done recording")

stream.stop_stream()
stream.close()
p.terminate()

wf = wave.open(WAVE_OUTPUT_FILENAME, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(p.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()
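Note: the recorder above produces 16 kHz, mono, 16-bit PCM (paInt16 is 2 bytes per sample), which is the format the server in this commit assumes when it re-writes received bytes to disk. A minimal sanity check of the produced file, a sketch using only the stdlib wave module and the output.wav written above:

import wave

# Verify output.wav matches the format the ASR server expects:
# mono, 2 bytes per sample (paInt16), 16 kHz.
with wave.open("output.wav", "rb") as wf:
    assert wf.getnchannels() == 1
    assert wf.getsampwidth() == 2
    assert wf.getframerate() == 16000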
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Socket client to send wav to ASR server."""
import struct
import socket
import argparse
import wave

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--host_ip",
    default="localhost",
    type=str,
    help="Server IP address. (default: %(default)s)")
parser.add_argument(
    "--host_port",
    default=8086,
    type=int,
    help="Server Port. (default: %(default)s)")
args = parser.parse_args()

WAVE_OUTPUT_FILENAME = "output.wav"


def main():
    wf = wave.open(WAVE_OUTPUT_FILENAME, 'rb')
    nframe = wf.getnframes()
    data = wf.readframes(nframe)
    print(f"Wave: {WAVE_OUTPUT_FILENAME}")
    print(f"Wave samples: {nframe}")
    print(f"Wave channels: {wf.getnchannels()}")
    print(f"Wave sample rate: {wf.getframerate()}")
    print(f"Wave sample width: {wf.getsampwidth()}")
    assert isinstance(data, bytes)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((args.host_ip, args.host_port))
    sent = data
    sock.sendall(struct.pack('>i', len(sent)) + sent)
    print('Speech[length=%d] Sent.' % len(sent))
    # Receive data from the server and shut down
    received = sock.recv(1024)
    print("Recognition Results: {}".format(received.decode('utf8')))
    sock.close()


if __name__ == "__main__":
    main()
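The client frames its payload with a four-byte big-endian length prefix (struct.pack('>i', ...)) followed by the raw wav bytes, and expects a UTF-8 transcript in reply. Because TCP gives no message boundaries, a receiver must loop until the advertised byte count arrives; the server below implements essentially this loop. A minimal sketch of such a reader under the same framing assumption (recv_framed is an illustrative helper, not part of this commit):

import socket
import struct

def recv_framed(sock: socket.socket) -> bytes:
    # Read the 4-byte big-endian length prefix, then the payload.
    header = b''
    while len(header) < 4:
        chunk = sock.recv(4 - len(header))
        if not chunk:
            raise ConnectionError("socket closed before length prefix arrived")
        header += chunk
    (length,) = struct.unpack('>i', header)
    payload = b''
    while len(payload) < length:
        chunk = sock.recv(min(4096, length - len(payload)))
        if not chunk:
            raise ConnectionError("socket closed mid-payload")
        payload += chunk
    return payload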
@@ -0,0 +1,216 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import functools
from time import gmtime, strftime
import socketserver
import struct
import wave
import paddle
import numpy as np

from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults

from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments

from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.io.dataset import ManifestDataset


class AsrTCPServer(socketserver.TCPServer):
    """The ASR TCP Server."""

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        socketserver.TCPServer.__init__(
            self, server_address, RequestHandlerClass, bind_and_activate)


class AsrRequestHandler(socketserver.BaseRequestHandler):
    """The ASR request handler."""

    def handle(self):
        # receive data through TCP socket
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript.encode('utf-8'))

    def _write_to_file(self, data):
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
            timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file
        file = wave.open(out_filename, 'wb')
        file.setnchannels(1)
        file.setsampwidth(2)
        file.setframerate(16000)
        file.writeframes(data)
        file.close()
        return out_filename


def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test."""
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    samples = rng.sample(manifest, num_test_cases)
    for idx, sample in enumerate(samples):
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


def start_server(config, args):
    """Start the ASR server"""
    dataset = ManifestDataset(
        config.data.test_manifest,
        config.data.vocab_filepath,
        config.data.mean_std_filepath,
        augmentation_config="{}",
        max_duration=config.data.max_duration,
        min_duration=config.data.min_duration,
        stride_ms=config.data.stride_ms,
        window_ms=config.data.window_ms,
        n_fft=config.data.n_fft,
        max_freq=config.data.max_freq,
        target_sample_rate=config.data.target_sample_rate,
        specgram_type=config.data.specgram_type,
        use_dB_normalization=config.data.use_dB_normalization,
        target_dB=config.data.target_dB,
        random_seed=config.data.random_seed,
        keep_transcription_text=True)

    model = DeepSpeech2Model(
        feat_size=dataset.feature_size,
        dict_size=dataset.vocab_size,
        num_conv_layers=config.model.num_conv_layers,
        num_rnn_layers=config.model.num_rnn_layers,
        rnn_size=config.model.rnn_layer_size,
        use_gru=config.model.use_gru,
        share_rnn_weights=config.model.share_rnn_weights)
    model.from_pretrained(args.checkpoint_path)
    model.eval()

    # prepare ASR inference handler
    def file_to_transcript(filename):
        feature = dataset.process_utterance(filename, "")
        audio = np.array([feature[0]]).astype('float32')  # [1, D, T]
        audio_len = feature[0].shape[1]
        audio_len = np.array([audio_len]).astype('int64')  # [1]

        result_transcript = model.decode(
            paddle.to_tensor(audio),
            paddle.to_tensor(audio_len),
            vocab_list=dataset.vocab_list,
            decoding_method=config.decoding.decoding_method,
            lang_model_path=config.decoding.lang_model_path,
            beam_alpha=config.decoding.alpha,
            beam_beta=config.decoding.beta,
            beam_size=config.decoding.beam_size,
            cutoff_prob=config.decoding.cutoff_prob,
            cutoff_top_n=config.decoding.cutoff_top_n,
            num_processes=config.decoding.num_proc_bsearch)
        return result_transcript[0]

    # warming up with utterances sampled from Librispeech
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    server.serve_forever()


def main(config, args):
    start_server(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    add_arg = functools.partial(add_arguments, argparser=parser)
    # yapf: disable
    add_arg('host_ip', str,
            'localhost',
            "Server's IP address.")
    add_arg('host_port', int, 8086, "Server's IP port.")
    add_arg('speech_save_dir', str,
            'demo_cache',
            "Directory to save demo audios.")
    add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.")
    args = parser.parse_args()
    print_arguments(args)

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)

    args.warmup_manifest = config.data.test_manifest
    print_arguments(args)

    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
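The request handler reaches the recognizer through self.server.audio_process_handler: socketserver gives every BaseRequestHandler a reference to the server that accepted the connection, so state attached in the server's __init__ is visible inside handle(). A standalone sketch of that pattern, with a hypothetical echo_upper callback standing in for file_to_transcript (CallbackTCPServer and CallbackHandler are illustrative names, not part of this commit):

import socketserver

class CallbackTCPServer(socketserver.TCPServer):
    def __init__(self, server_address, RequestHandlerClass, process_handler,
                 bind_and_activate=True):
        # Attach per-server state before the base class binds the socket.
        self.process_handler = process_handler
        super().__init__(server_address, RequestHandlerClass, bind_and_activate)

class CallbackHandler(socketserver.BaseRequestHandler):
    def handle(self):
        data = self.request.recv(1024)
        # self.server is the CallbackTCPServer instance created below.
        self.request.sendall(self.server.process_handler(data))

def echo_upper(data: bytes) -> bytes:  # hypothetical stand-in for the ASR callback
    return data.upper()

# CallbackTCPServer(('localhost', 8086), CallbackHandler, echo_upper).serve_forever()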
@@ -1,252 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import functools
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.fluid as fluid
import numpy as np

from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments

from deepspeech.exps.deepspeech2.model import DeepSpeech2Model
from deepspeech.exps.deepspeech2.dataset import DataGenerator

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights', bool, True, "Share input-hidden weights across "
        "bi-directional RNNs. Not for GRU.")
add_arg('host_ip', str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir', str,
        'demo_cache',
        "Directory to save demo audios.")
add_arg('warmup_manifest', str,
        'data/librispeech/manifest.test-clean',
        "Filepath of manifest to warm up.")
add_arg('mean_std_path', str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
        'data/librispeech/eng_vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path', str,
        './checkpoints/libri/step_final',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path', str,
        'lm/data/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoding_method', str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type', str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: disable
args = parser.parse_args()


class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP Server."""

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        SocketServer.TCPServer.__init__(
            self, server_address, RequestHandlerClass, bind_and_activate=True)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """The ASR request handler."""

    def handle(self):
        # receive data through TCP socket
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript.encode('utf-8'))

    def _write_to_file(self, data):
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
            timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file
        file = wave.open(out_filename, 'wb')
        file.setnchannels(1)
        file.setsampwidth(4)
        file.setframerate(16000)
        file.writeframes(data)
        file.close()
        return out_filename


def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test."""
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    samples = rng.sample(manifest, num_test_cases)
    for idx, sample in enumerate(samples):
        print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


def start_server():
    """Start the ASR server"""
    # prepare data generator
    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        keep_transcription_text=True,
        place=place,
        is_training=False)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        init_from_pretrained_model=args.model_path,
        place=place,
        share_rnn_weights=args.share_rnn_weights)

    vocab_list = [chars for chars in data_generator.vocab_list]

    if args.decoding_method == "ctc_beam_search":
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)
    # prepare ASR inference handler
    def file_to_transcript(filename):
        feature = data_generator.process_utterance(filename, "")
        audio_len = feature[0].shape[1]
        mask_shape0 = (feature[0].shape[0] - 1) // 2 + 1
        mask_shape1 = (feature[0].shape[1] - 1) // 3 + 1
        mask_max_len = (audio_len - 1) // 3 + 1
        mask_ones = np.ones((mask_shape0, mask_shape1))
        mask_zeros = np.zeros((mask_shape0, mask_max_len - mask_shape1))
        mask = np.repeat(
            np.reshape(
                np.concatenate((mask_ones, mask_zeros), axis=1),
                (1, mask_shape0, mask_max_len)),
            32,
            axis=0)
        feature = (np.array([feature[0]]).astype('float32'),
                   None,
                   np.array([audio_len]).astype('int64').reshape([-1, 1]),
                   np.array([mask]).astype('float32'))
        probs_split = ds2_model.infer_batch_probs(
            infer_data=feature,
            feeding_dict=data_generator.feeding)

        if args.decoding_method == "ctc_greedy":
            result_transcript = ds2_model.decode_batch_greedy(
                probs_split=probs_split,
                vocab_list=vocab_list)
        else:
            result_transcript = ds2_model.decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=args.alpha,
                beam_beta=args.beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=1)
        return result_transcript[0]

    # warming up with utterrances sampled from Librispeech
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    server.serve_forever()


def main():
    print_arguments(args)
    start_server()


if __name__ == "__main__":
    main()
@@ -1,2 +1,4 @@
data
ckpt*
demo_cache
*.log
@@ -0,0 +1,40 @@
#! /usr/bin/env bash
# TODO: replace the model with a mandarin model

if [[ $# != 1 ]];then
    echo "usage: server.sh checkpoint_path"
    exit -1
fi

source path.sh

# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi

# download well-trained model
bash local/download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi

# start demo server
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/deploy/server.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--host_ip="localhost" \
--host_port=8086 \
--speech_save_dir="demo_cache" \
--checkpoint_path ${1}

if [ $? -ne 0 ]; then
    echo "Failed in starting demo server!"
    exit 1
fi


exit 0
@@ -1,8 +0,0 @@
export MAIN_ROOT=${PWD}/../../

export PATH=${MAIN_ROOT}:${PWD}/tools:${PATH}
export LC_ALL=C

# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
@@ -1,54 +0,0 @@
#! /usr/bin/env bash
# TODO: replace the model with a mandarin model

source path.sh

# download language model
cd ${MAIN_ROOT}/models/lm > /dev/null
bash download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd ${MAIN_ROOT}/models/baidu_en8k > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# start demo server
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${MAIN_ROOT}/deploy/demo_server.py \
--host_ip="localhost" \
--host_port=8086 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.15 \
--beta=0.15 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--speech_save_dir="demo_cache" \
--warmup_manifest="${MAIN_ROOT}/examples/tiny/data/manifest.test-clean" \
--mean_std_path="${MAIN_ROOT}/models/baidu_en8k/mean_std.npz" \
--vocab_path="${MAIN_ROOT}/models/baidu_en8k/vocab.txt" \
--model_path="${MAIN_ROOT}/models/baidu_en8k" \
--lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \
--decoding_method="ctc_beam_search" \
--specgram_type="linear"

if [ $? -ne 0 ]; then
    echo "Failed in starting demo server!"
    exit 1
fi


exit 0