diff --git a/demo_client.py b/demo_client.py index 97649fd4..a789d816 100644 --- a/demo_client.py +++ b/demo_client.py @@ -11,6 +11,7 @@ enable_trigger_record = True def on_press(key): + """On-press keyboard callback function.""" global is_recording, enable_trigger_record if key == keyboard.Key.space: if (not is_recording) and enable_trigger_record: @@ -20,6 +21,7 @@ def on_press(key): def on_release(key): + """On-release keyboard callback function.""" global is_recording, enable_trigger_record if key == keyboard.Key.esc: return False @@ -32,6 +34,7 @@ data_list = [] def callback(in_data, frame_count, time_info, status): + """Audio recorder's stream callback function.""" global data_list, is_recording, enable_trigger_record if is_recording: data_list.append(in_data) @@ -53,6 +56,7 @@ def callback(in_data, frame_count, time_info, status): def main(): + # prepare audio recorder p = pyaudio.PyAudio() stream = p.open( format=pyaudio.paInt32, @@ -62,10 +66,12 @@ def main(): stream_callback=callback) stream.start_stream() + # prepare keyboard listener with keyboard.Listener( on_press=on_press, on_release=on_release) as listener: listener.join() + # close up stream.stop_stream() stream.close() p.terminate() diff --git a/demo_server.py b/demo_server.py index 85f69483..d6c0de40 100644 --- a/demo_server.py +++ b/demo_server.py @@ -112,6 +112,8 @@ args = parser.parse_args() class AsrTCPServer(SocketServer.TCPServer): + """The ASR TCP Server.""" + def __init__(self, server_address, RequestHandlerClass, @@ -125,8 +127,7 @@ class AsrTCPServer(SocketServer.TCPServer): class AsrRequestHandler(SocketServer.BaseRequestHandler): - """The ASR request handler. 
- """ + """The ASR request handler.""" def handle(self): # receive data through TCP socket @@ -170,6 +171,7 @@ def warm_up_test(audio_process_handler, manifest_path, num_test_cases, random_seed=0): + """Warming-up test.""" manifest = read_manifest(manifest_path) rng = random.Random(random_seed) samples = rng.sample(manifest, num_test_cases) @@ -183,12 +185,15 @@ def warm_up_test(audio_process_handler, def start_server(): + """Start the ASR server""" + # prepare data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', specgram_type=args.specgram_type, num_threads=1) + # prepare ASR model ds2_model = DeepSpeech2Model( vocab_size=data_generator.vocab_size, num_conv_layers=args.num_conv_layers, @@ -196,6 +201,7 @@ def start_server(): rnn_layer_size=args.rnn_layer_size, pretrained_model_path=args.model_filepath) + # prepare ASR inference handler def file_to_transcript(filename): feature = data_generator.process_utterance(filename, "") result_transcript = ds2_model.infer_batch( @@ -210,6 +216,7 @@ def start_server(): num_processes=1) return result_transcript[0] + # warming up with utterances sampled from Librispeech print('-----------------------------------------------------------') print('Warming up ...') warm_up_test( @@ -218,12 +225,12 @@ def start_server(): num_test_cases=3) print('-----------------------------------------------------------') + # start the server server = AsrTCPServer( server_address=(args.host_ip, args.host_port), RequestHandlerClass=AsrRequestHandler, speech_save_dir=args.speech_save_dir, audio_process_handler=file_to_transcript) - print("ASR Server Started.") server.serve_forever()