@@ -35,6 +35,9 @@ from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
+from deepspeech.utils.log import Autolog
 
 logger = Log(__name__).getlog()
 
@@ -223,7 +226,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
+        self.autolog = Autolog(batch_size = config.decoding.batch_size, model_name = "deepspeech2", model_precision = "fp32").getlog()
 
     def ordid2token(self, texts, texts_len):
         """ ord() id to chr() chr """
         trans = []
@@ -248,6 +252,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         vocab_list = self.test_loader.collate_fn.vocab_list
 
         target_transcripts = self.ordid2token(texts, texts_len)
+        self.autolog.times.start()
+        self.autolog.times.stamp()
         result_transcripts = self.model.decode(
             audio,
             audio_len,
@@ -260,6 +266,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             cutoff_prob=cfg.cutoff_prob,
             cutoff_top_n=cfg.cutoff_top_n,
             num_processes=cfg.num_proc_bsearch)
+        self.autolog.times.stamp()
+        self.autolog.times.stamp()
+        self.autolog.times.end()
 
         for utt, target, result in zip(utts, target_transcripts,
                                        result_transcripts):
@@ -308,6 +317,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         msg += "Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         logger.info(msg)
+        self.autolog.report()
 
     def run_test(self):
         self.resume_or_scratch()
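
For context, the instrumentation added in this patch follows a simple start/stamp/end/report pattern around the decode call. The sketch below restates that pattern outside the tester class as a minimal standalone loop; only the Autolog calls mirror the patch, while run_benchmark, the batch-tuple layout, and the model/test_loader/cfg objects are hypothetical stand-ins, and the decoding keyword arguments elided in the hunk above are likewise omitted here.

# Minimal sketch of the Autolog timing pattern introduced in this patch.
# Only the Autolog calls mirror the change; everything else is a stand-in.
from deepspeech.utils.log import Autolog


def run_benchmark(model, test_loader, cfg):
    # Same constructor arguments as in DeepSpeech2Tester.__init__ above.
    autolog = Autolog(
        batch_size=cfg.decoding.batch_size,
        model_name="deepspeech2",
        model_precision="fp32").getlog()

    # Assumed batch layout; the real collate_fn may differ.
    for utts, audio, audio_len, texts, texts_len in test_loader:
        autolog.times.start()  # open the timing record for this batch
        autolog.times.stamp()  # mark preprocessing as finished
        result_transcripts = model.decode(
            audio,
            audio_len,
            test_loader.collate_fn.vocab_list,
            cutoff_prob=cfg.decoding.cutoff_prob,
            cutoff_top_n=cfg.decoding.cutoff_top_n,
            num_processes=cfg.decoding.num_proc_bsearch)
        autolog.times.stamp()  # mark inference as finished
        autolog.times.stamp()  # mark postprocessing as finished
        autolog.times.end()    # close the timing record for this batch

    autolog.report()  # emit the aggregated benchmark log once at the end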