|
|
@ -37,6 +37,11 @@ from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import auto_log
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
from paddle import inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepSpeech2Trainer(Trainer):
|
|
|
|
class DeepSpeech2Trainer(Trainer):
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
@ -223,6 +228,27 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, config, args):
    """Initialise the tester and attach an AutoLogger for benchmark timing.

    After delegating to the parent constructor, this builds a Paddle
    inference ``Config`` with GPU enabled and wraps it in an
    ``auto_log.AutoLogger`` that is stored on ``self.autolog``.

    Args:
        config: Experiment configuration. ``config.decoding.batch_size``
            is read here and passed to the AutoLogger.
        args: Forwarded unchanged to the parent constructor.
    """
    super().__init__(config, args)

    # NOTE(review): the GPU device id is hard-coded. It is kept in a single
    # local so the inference config and the AutoLogger always agree;
    # consider deriving it from ``args``/``config`` instead.
    gpu_id = 2
    # First argument of ``enable_use_gpu``; presumably the initial GPU
    # memory pool size in MB -- confirm against the Paddle Inference docs.
    gpu_mem_pool_size = 10000

    infer_config = inference.Config()
    infer_config.enable_use_gpu(gpu_mem_pool_size, gpu_id)

    # The module-level ``logger`` is intentionally not rebound here: a plain
    # assignment inside this method would only create an unused local.
    # The AutoLogger's own logger stays reachable as ``self.autolog.logger``.
    #
    # NOTE(review): the ".lpg" suffix in ``save_path`` looks like a typo for
    # ".log"; it is left untouched because the output path is runtime
    # behaviour -- confirm before renaming.
    self.autolog = auto_log.AutoLogger(
        model_name="tiny_s0",
        model_precision="fp32",
        batch_size=config.decoding.batch_size,
        data_shape="dynamic",
        save_path="./output/auto_log.lpg",
        inference_config=infer_config,
        pids=os.getpid(),
        process_name=None,
        gpu_ids=gpu_id,
        time_keys=[
            'preprocess_time', 'inference_time', 'postprocess_time'
        ],
        warmup=0)
|
|
|
|
|
|
|
|
|
|
|
|
def ordid2token(self, texts, texts_len):
|
|
|
|
def ordid2token(self, texts, texts_len):
|
|
|
|
""" ord() id to chr() chr """
|
|
|
|
""" ord() id to chr() chr """
|
|
|
@ -248,6 +274,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
vocab_list = self.test_loader.collate_fn.vocab_list
|
|
|
|
vocab_list = self.test_loader.collate_fn.vocab_list
|
|
|
|
|
|
|
|
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
|
|
|
|
self.autolog.times.start()
|
|
|
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
result_transcripts = self.model.decode(
|
|
|
|
result_transcripts = self.model.decode(
|
|
|
|
audio,
|
|
|
|
audio,
|
|
|
|
audio_len,
|
|
|
|
audio_len,
|
|
|
@ -260,7 +288,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
|
|
|
self.autolog.times.end()
|
|
|
|
for utt, target, result in zip(utts, target_transcripts,
|
|
|
|
for utt, target, result in zip(utts, target_transcripts,
|
|
|
|
result_transcripts):
|
|
|
|
result_transcripts):
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
@ -291,6 +321,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
|
|
|
|
logger.info("batch: {}".format(i))
|
|
|
|
utts, audio, audio_len, texts, texts_len = batch
|
|
|
|
utts, audio, audio_len, texts, texts_len = batch
|
|
|
|
metrics = self.compute_metrics(utts, audio, audio_len, texts,
|
|
|
|
metrics = self.compute_metrics(utts, audio, audio_len, texts,
|
|
|
|
texts_len, fout)
|
|
|
|
texts_len, fout)
|
|
|
@ -308,6 +339,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
msg += "Final error rate [%s] (%d/%d) = %f" % (
|
|
|
|
msg += "Final error rate [%s] (%d/%d) = %f" % (
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
logger.info(msg)
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
self.autolog.report()
|
|
|
|
|
|
|
|
|
|
|
|
def run_test(self):
|
|
|
|
def run_test(self):
|
|
|
|
self.resume_or_scratch()
|
|
|
|
self.resume_or_scratch()
|
|
|
|