Add warming-up to demo_server.py for DS2 and clean codes.

pull/2/head
Xinghai Sun 8 years ago
parent 0ebf36b98f
commit cb9370f308

@ -64,8 +64,6 @@ class AudioSegment(object):
:rtype: AudioSegment :rtype: AudioSegment
""" """
samples, sample_rate = soundfile.read(file, dtype='float32') samples, sample_rate = soundfile.read(file, dtype='float32')
print(samples)
print(sample_rate)
return cls(samples, sample_rate) return cls(samples, sample_rate)
@classmethod @classmethod

@ -1,5 +1,6 @@
import os import os
import time import time
import random
import argparse import argparse
import distutils.util import distutils.util
from time import gmtime, strftime from time import gmtime, strftime
@ -8,9 +9,10 @@ import struct
import wave import wave
import pyaudio import pyaudio
import paddle.v2 as paddle import paddle.v2 as paddle
from utils import print_arguments
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import DeepSpeech2Model from model import DeepSpeech2Model
import utils from data_utils.utils import read_manifest
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
@ -38,6 +40,11 @@ parser.add_argument(
default='mean_std.npz', default='mean_std.npz',
type=str, type=str,
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--warmup_manifest_path",
default='datasets/manifest.test',
type=str,
help="Manifest path for warmup test. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--specgram_type", "--specgram_type",
default='linear', default='linear',
@ -77,7 +84,7 @@ parser.add_argument(
"(default: %(default)s)") "(default: %(default)s)")
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
default=500, default=100,
type=int, type=int,
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument( parser.add_argument(
@ -134,7 +141,6 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
print("Received utterance[length=%d] from %s, saved to %s." % print("Received utterance[length=%d] from %s, saved to %s." %
(len(data), self.client_address[0], filename)) (len(data), self.client_address[0], filename))
#filename = "/home/work/.cache/paddle/dataset/speech/Libri/train-other-500/LibriSpeech/train-other-500/811/130143/811-130143-0025.flac"
start_time = time.time() start_time = time.time()
transcript = self.server.audio_process_handler(filename) transcript = self.server.audio_process_handler(filename)
finish_time = time.time() finish_time = time.time()
@ -149,7 +155,7 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
timestamp = strftime("%Y%m%d%H%M%S", gmtime()) timestamp = strftime("%Y%m%d%H%M%S", gmtime())
out_filename = os.path.join( out_filename = os.path.join(
self.server.speech_save_dir, self.server.speech_save_dir,
timestamp + "_" + self.client_address[0] + "_" + ".wav") timestamp + "_" + self.client_address[0] + ".wav")
# write to wav file # write to wav file
file = wave.open(out_filename, 'wb') file = wave.open(out_filename, 'wb')
file.setnchannels(1) file.setnchannels(1)
@ -160,6 +166,22 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
return out_filename return out_filename
def warm_up_test(audio_process_handler,
manifest_path,
num_test_cases,
random_seed=0):
manifest = read_manifest(manifest_path)
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
def start_server(): def start_server():
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
@ -188,6 +210,14 @@ def start_server():
num_processes=1) num_processes=1)
return result_transcript[0] return result_transcript[0]
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest_path,
num_test_cases=3)
print('-----------------------------------------------------------')
server = AsrTCPServer( server = AsrTCPServer(
server_address=(args.host_ip, args.host_port), server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler, RequestHandlerClass=AsrRequestHandler,
@ -199,7 +229,7 @@ def start_server():
def main(): def main():
utils.print_arguments(args) print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=1)
start_server() start_server()

Loading…
Cancel
Save