diff --git a/data_utils/audio.py b/data_utils/audio.py index 29fdd0bd..3891f5b9 100644 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -64,8 +64,6 @@ class AudioSegment(object): :rtype: AudioSegment """ samples, sample_rate = soundfile.read(file, dtype='float32') - print(samples) - print(sample_rate) return cls(samples, sample_rate) @classmethod diff --git a/demo_server.py b/demo_server.py index 4a3feb13..85f69483 100644 --- a/demo_server.py +++ b/demo_server.py @@ -1,5 +1,6 @@ import os import time +import random import argparse import distutils.util from time import gmtime, strftime @@ -8,9 +9,10 @@ import struct import wave import pyaudio import paddle.v2 as paddle +from utils import print_arguments from data_utils.data import DataGenerator from model import DeepSpeech2Model -import utils +from data_utils.utils import read_manifest parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -38,6 +40,11 @@ parser.add_argument( default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--warmup_manifest_path", + default='datasets/manifest.test', + type=str, + help="Manifest path for warmup test. (default: %(default)s)") parser.add_argument( "--specgram_type", default='linear', @@ -77,7 +84,7 @@ parser.add_argument( "(default: %(default)s)") parser.add_argument( "--beam_size", - default=500, + default=100, type=int, help="Width for beam search decoding. (default: %(default)d)") parser.add_argument( @@ -134,7 +141,6 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler): print("Received utterance[length=%d] from %s, saved to %s." % (len(data), self.client_address[0], filename)) - #filename = "/home/work/.cache/paddle/dataset/speech/Libri/train-other-500/LibriSpeech/train-other-500/811/130143/811-130143-0025.flac" start_time = time.time() transcript = self.server.audio_process_handler(filename) finish_time = time.time() @@ -149,7 +155,7 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler): timestamp = strftime("%Y%m%d%H%M%S", gmtime()) out_filename = os.path.join( self.server.speech_save_dir, - timestamp + "_" + self.client_address[0] + "_" + ".wav") + timestamp + "_" + self.client_address[0] + ".wav") # write to wav file file = wave.open(out_filename, 'wb') file.setnchannels(1) @@ -160,6 +166,22 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler): return out_filename +def warm_up_test(audio_process_handler, + manifest_path, + num_test_cases, + random_seed=0): + manifest = read_manifest(manifest_path) + rng = random.Random(random_seed) + samples = rng.sample(manifest, num_test_cases) + for idx, sample in enumerate(samples): + print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) + start_time = time.time() + transcript = audio_process_handler(sample['audio_filepath']) + finish_time = time.time() + print("Response Time: %f, Transcript: %s" % + (finish_time - start_time, transcript)) + + def start_server(): data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, @@ -188,6 +210,14 @@ def start_server(): num_processes=1) return result_transcript[0] + print('-----------------------------------------------------------') + print('Warming up ...') + warm_up_test( + audio_process_handler=file_to_transcript, + manifest_path=args.warmup_manifest_path, + num_test_cases=3) + print('-----------------------------------------------------------') + server = AsrTCPServer( server_address=(args.host_ip, args.host_port), RequestHandlerClass=AsrRequestHandler, @@ -199,7 +229,7 @@ def start_server(): def main(): - utils.print_arguments(args) + print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=1) start_server()