|
|
|
@ -18,6 +18,7 @@ from collections import defaultdict
|
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Optional
|
|
|
|
|
import jsonlines
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle
|
|
|
|
@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|
num_ins += 1
|
|
|
|
|
if fout:
|
|
|
|
|
fout.write(utt + " " + result + "\n")
|
|
|
|
|
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
|
|
|
|
|
(target, result))
|
|
|
|
|
fout.write({"utt": utt, "ref": target, "hyp": result})
|
|
|
|
|
logger.info(f"Utt: {utt}")
|
|
|
|
|
logger.info(f"Ref: {target}")
|
|
|
|
|
logger.info(f"Hyp: {result}")
|
|
|
|
|
logger.info("Current error rate [%s] = %f" %
|
|
|
|
|
(cfg.error_rate_type, error_rate_func(target, result)))
|
|
|
|
|
|
|
|
|
@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
cfg = self.config
|
|
|
|
|
error_rate_type = None
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
|
with jsonlines.open(self.args.result_file, 'w') as fout:
|
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
|
utts, audio, audio_len, texts, texts_len = batch
|
|
|
|
|
metrics = self.compute_metrics(utts, audio, audio_len, texts,
|
|
|
|
|