Add function docs and comments to demo_server.py and demo_client.py.

pull/2/head
Xinghai Sun 7 years ago
parent cb9370f308
commit a4c2dd7de2

@ -11,6 +11,7 @@ enable_trigger_record = True
def on_press(key): def on_press(key):
"""On-press keyboard callback function."""
global is_recording, enable_trigger_record global is_recording, enable_trigger_record
if key == keyboard.Key.space: if key == keyboard.Key.space:
if (not is_recording) and enable_trigger_record: if (not is_recording) and enable_trigger_record:
@ -20,6 +21,7 @@ def on_press(key):
def on_release(key): def on_release(key):
"""On-release keyboard callback function."""
global is_recording, enable_trigger_record global is_recording, enable_trigger_record
if key == keyboard.Key.esc: if key == keyboard.Key.esc:
return False return False
@ -32,6 +34,7 @@ data_list = []
def callback(in_data, frame_count, time_info, status): def callback(in_data, frame_count, time_info, status):
"""Audio recorder's stream callback function."""
global data_list, is_recording, enable_trigger_record global data_list, is_recording, enable_trigger_record
if is_recording: if is_recording:
data_list.append(in_data) data_list.append(in_data)
@ -53,6 +56,7 @@ def callback(in_data, frame_count, time_info, status):
def main(): def main():
# prepare audio recorder
p = pyaudio.PyAudio() p = pyaudio.PyAudio()
stream = p.open( stream = p.open(
format=pyaudio.paInt32, format=pyaudio.paInt32,
@ -62,10 +66,12 @@ def main():
stream_callback=callback) stream_callback=callback)
stream.start_stream() stream.start_stream()
# prepare keyboard listener
with keyboard.Listener( with keyboard.Listener(
on_press=on_press, on_release=on_release) as listener: on_press=on_press, on_release=on_release) as listener:
listener.join() listener.join()
# close up
stream.stop_stream() stream.stop_stream()
stream.close() stream.close()
p.terminate() p.terminate()

@ -112,6 +112,8 @@ args = parser.parse_args()
class AsrTCPServer(SocketServer.TCPServer): class AsrTCPServer(SocketServer.TCPServer):
"""The ASR TCP Server."""
def __init__(self, def __init__(self,
server_address, server_address,
RequestHandlerClass, RequestHandlerClass,
@ -125,8 +127,7 @@ class AsrTCPServer(SocketServer.TCPServer):
class AsrRequestHandler(SocketServer.BaseRequestHandler): class AsrRequestHandler(SocketServer.BaseRequestHandler):
"""The ASR request handler. """The ASR request handler."""
"""
def handle(self): def handle(self):
# receive data through TCP socket # receive data through TCP socket
@ -170,6 +171,7 @@ def warm_up_test(audio_process_handler,
manifest_path, manifest_path,
num_test_cases, num_test_cases,
random_seed=0): random_seed=0):
"""Warming-up test."""
manifest = read_manifest(manifest_path) manifest = read_manifest(manifest_path)
rng = random.Random(random_seed) rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases) samples = rng.sample(manifest, num_test_cases)
@ -183,12 +185,15 @@ def warm_up_test(audio_process_handler,
def start_server(): def start_server():
"""Start the ASR server"""
# prepare data generator
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath, mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}', augmentation_config='{}',
specgram_type=args.specgram_type, specgram_type=args.specgram_type,
num_threads=1) num_threads=1)
# prepare ASR model
ds2_model = DeepSpeech2Model( ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size, vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers, num_conv_layers=args.num_conv_layers,
@ -196,6 +201,7 @@ def start_server():
rnn_layer_size=args.rnn_layer_size, rnn_layer_size=args.rnn_layer_size,
pretrained_model_path=args.model_filepath) pretrained_model_path=args.model_filepath)
# prepare ASR inference handler
def file_to_transcript(filename): def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "") feature = data_generator.process_utterance(filename, "")
result_transcript = ds2_model.infer_batch( result_transcript = ds2_model.infer_batch(
@ -210,6 +216,7 @@ def start_server():
num_processes=1) num_processes=1)
return result_transcript[0] return result_transcript[0]
# warming up with utterances sampled from Librispeech
print('-----------------------------------------------------------') print('-----------------------------------------------------------')
print('Warming up ...') print('Warming up ...')
warm_up_test( warm_up_test(
@ -218,12 +225,12 @@ def start_server():
num_test_cases=3) num_test_cases=3)
print('-----------------------------------------------------------') print('-----------------------------------------------------------')
# start the server
server = AsrTCPServer( server = AsrTCPServer(
server_address=(args.host_ip, args.host_port), server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler, RequestHandlerClass=AsrRequestHandler,
speech_save_dir=args.speech_save_dir, speech_save_dir=args.speech_save_dir,
audio_process_handler=file_to_transcript) audio_process_handler=file_to_transcript)
print("ASR Server Started.") print("ASR Server Started.")
server.serve_forever() server.serve_forever()

Loading…
Cancel
Save