fix runtime.py and server.py

pull/684/head
Haoxin Ma 4 years ago
parent d55e6b5a0a
commit c753b9ddf2

@@ -81,15 +81,15 @@ def inference(config, args):
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
 
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
     config.collator.batch_size=1
     config.collator.num_workers=0
     collate_fn = SpeechCollator.from_config(config)
-    test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
 
     model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
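
A note on the config handling in this hunk: config.defrost() suggests a yacs-style CfgNode, which is immutable until defrosted. A minimal sketch of that override pattern, assuming yacs and using illustrative field names rather than the project's full schema:

from yacs.config import CfgNode

# Sketch only: a tiny config with the fields touched above.
config = CfgNode()
config.data = CfgNode()
config.data.manifest = ""
config.data.test_manifest = "data/manifest.test"
config.freeze()

config.defrost()                                  # make the node mutable again
config.data.manifest = config.data.test_manifest  # point the loader at the test set
config.freeze()                                   # lock it before handing it to builders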
@@ -97,15 +97,15 @@ def start_server(config, args):
 
     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32') #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32') #[1, T, D]
+        audio_len = feature[0].shape[0]
         audio_len = np.array([audio_len]).astype('int64') # [1]
 
         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,
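
The handler change above hinges on the feature layout: process_utterance now returns the spectrogram time-major, [T, D] instead of [D, T], so the frame count comes from shape[0]. A minimal NumPy sketch of the shape bookkeeping, with made-up sizes:

import numpy as np

# Illustrative sizes only: 200 frames, 161 frequency bins.
feature = np.random.rand(200, 161).astype('float32')      # [T, D], as process_utterance now returns

audio = np.array([feature]).astype('float32')              # [1, T, D]: add the batch dimension
audio_len = np.array([feature.shape[0]]).astype('int64')   # [1]: number of frames, i.e. T

print(audio.shape)   # (1, 200, 161)
print(audio_len)     # [200]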
@@ -146,7 +146,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str,
             'localhost',
             "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8089, "Server's IP port.")
     add_arg('speech_save_dir', str,
             'demo_cache',
             "Directory to save demo audios.")

@@ -34,15 +34,15 @@ from deepspeech.io.collator import SpeechCollator
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
 
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
     config.collator.batch_size=1
     config.collator.num_workers=0
     collate_fn = SpeechCollator.from_config(config)
-    test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
 
     model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
@@ -50,15 +50,19 @@ def start_server(config, args):
 
     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32') #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = test_loader.collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32') #[1, T, D]
+        # audio = audio.swapaxes(1,2)
+        print('---file_to_transcript feature----')
+        print(audio.shape)
+        audio_len = feature[0].shape[0]
+        print(audio_len)
         audio_len = np.array([audio_len]).astype('int64') # [1]
 
         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,
@@ -99,7 +103,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str,
             'localhost',
             "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8088, "Server's IP port.")
     add_arg('speech_save_dir', str,
             'demo_cache',
             "Directory to save demo audios.")

@@ -242,6 +242,7 @@ class SpeechCollator():
 
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
+        specgram=specgram.transpose([1,0])
         return specgram, transcript_part
 
     def __call__(self, batch):
@@ -269,7 +270,7 @@ class SpeechCollator():
             #utt
             utts.append(utt)
             # audio
-            audios.append(audio.T)  # [T, D]
+            audios.append(audio)  # [T, D]
             audio_lens.append(audio.shape[1])
             # text
             # for training, text is token ids
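
The SpeechCollator change moves the axis swap into process_utterance: the spectrogram produced as [D, T] is flipped to [T, D] before being returned, so __call__ can append it without the extra .T. A small NumPy sketch of that swap, with illustrative sizes:

import numpy as np

specgram = np.random.rand(161, 200)     # [D, T] straight from feature extraction (made-up sizes)
specgram = specgram.transpose([1, 0])   # -> [T, D], what process_utterance now returns

audios = []
audios.append(specgram)                 # appended as-is; no .T needed in __call__ anymore
print(audios[0].shape)                  # (200, 161)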

@@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler,
     rng = random.Random(random_seed)
     samples = rng.sample(manifest, num_test_cases)
     for idx, sample in enumerate(samples):
-        print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
+        print("Warm-up Test Case %d: %s"%(idx, sample['feat']))
         start_time = time.time()
-        transcript = audio_process_handler(sample['audio_filepath'])
+        transcript = audio_process_handler(sample['feat'])
         finish_time = time.time()
         print("Response Time: %f, Transcript: %s" %
               (finish_time - start_time, transcript))
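
The warm-up fix touches two things: the manifest key ('audio_filepath' -> 'feat') and the print call, which previously passed idx and the path as extra positional arguments instead of %-formatting them. Python's print is not printf-like, so the old call printed the literal format string followed by the extra values. A quick sketch of the difference, using a hypothetical manifest entry:

sample = {'feat': 'data/audio/demo_001.wav'}   # hypothetical manifest entry
idx = 0

# Old style: print() does not interpolate, so the placeholders are never filled in.
print("Warm-up Test Case %d: %s", idx, sample['feat'])
# -> Warm-up Test Case %d: %s 0 data/audio/demo_001.wav

# Fixed style: format with the % operator before printing.
print("Warm-up Test Case %d: %s" % (idx, sample['feat']))
# -> Warm-up Test Case 0: data/audio/demo_001.wav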
