dump decode result as jsonlines

pull/852/head
Hui Zhang 3 years ago
parent c6e8a33b73
commit ae87bc8c7a

@ -18,6 +18,7 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts,

@ -21,6 +21,7 @@ from collections import OrderedDict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
@ -466,9 +467,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
@ -493,7 +495,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']

@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
@ -445,9 +446,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
@ -472,7 +474,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']

@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import numpy as np
import paddle
@ -479,8 +480,10 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split())
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nReference: %s\nHypothesis: %s" % (target, result))
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example BLEU = %s" %
(bleu_func([result], [[target]]).prec_str))
@ -508,7 +511,7 @@ class U2STTester(U2STTrainer):
len_refs, num_ins = 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics(
*batch, bleu_func=bleu_func, fout=fout)

Loading…
Cancel
Save