|
|
@ -26,6 +26,7 @@ from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
|
|
|
|
from src_deepspeech2x.models.ds2 import DeepSpeech2Model
|
|
|
|
from src_deepspeech2x.models.ds2 import DeepSpeech2Model
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
|
|
|
|
from deepspeech.io.collator import SpeechCollator
|
|
|
|
from deepspeech.io.collator import SpeechCollator
|
|
|
|
from deepspeech.io.dataset import ManifestDataset
|
|
|
|
from deepspeech.io.dataset import ManifestDataset
|
|
|
|
from deepspeech.io.sampler import SortagradBatchSampler
|
|
|
|
from deepspeech.io.sampler import SortagradBatchSampler
|
|
|
@ -38,7 +39,6 @@ from deepspeech.utils import error_rate
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
#from deepspeech.utils.log import Autolog
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
|
|
@ -272,6 +272,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
return default
|
|
|
|
return default
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, config, args):
|
|
|
|
def __init__(self, config, args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._text_featurizer = TextFeaturizer(
|
|
|
|
|
|
|
|
unit_type=config.collator.unit_type, vocab_filepath=None)
|
|
|
|
super().__init__(config, args)
|
|
|
|
super().__init__(config, args)
|
|
|
|
|
|
|
|
|
|
|
|
def ordid2token(self, texts, texts_len):
|
|
|
|
def ordid2token(self, texts, texts_len):
|
|
|
@ -296,9 +299,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
|
|
|
|
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
|
|
|
|
|
|
|
|
|
|
|
|
vocab_list = self.test_loader.collate_fn.vocab_list
|
|
|
|
vocab_list = self.test_loader.collate_fn.vocab_list
|
|
|
|
if "" in vocab_list:
|
|
|
|
|
|
|
|
space_id = vocab_list.index("")
|
|
|
|
|
|
|
|
vocab_list[space_id] = " "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
|
|
|
|
|
|
|
@ -337,6 +337,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
result_transcripts = [
|
|
|
|
|
|
|
|
self._text_featurizer.detokenize(item)
|
|
|
|
|
|
|
|
for item in result_transcripts
|
|
|
|
|
|
|
|
]
|
|
|
|
return result_transcripts
|
|
|
|
return result_transcripts
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
@ -367,8 +371,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
logger.info(msg)
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
|
|
|
|
# self.autolog.report()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_test(self):
|
|
|
|
def run_test(self):
|
|
|
|
self.resume_or_scratch()
|
|
|
|
self.resume_or_scratch()
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|