dump decode result as jsonlines

pull/852/head
Hui Zhang 3 years ago
parent c6e8a33b73
commit ae87bc8c7a

@ -18,6 +18,7 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
@ -305,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("Current error rate [%s] = %f" % logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
@ -350,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config cfg = self.config
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts, metrics = self.compute_metrics(utts, audio, audio_len, texts,

@ -21,6 +21,7 @@ from collections import OrderedDict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
@ -466,9 +467,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
@ -493,7 +495,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0 num_frames = 0.0
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout) metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames'] num_frames += metrics['num_frames']

@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
@ -445,9 +446,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
@ -472,7 +474,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0 num_frames = 0.0
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout) metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames'] num_frames += metrics['num_frames']

@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
@ -479,8 +480,10 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split()) len_refs += len(target.split())
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info("\nReference: %s\nHypothesis: %s" % (target, result)) logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example BLEU = %s" % logger.info("One example BLEU = %s" %
(bleu_func([result], [[target]]).prec_str)) (bleu_func([result], [[target]]).prec_str))
@ -508,7 +511,7 @@ class U2STTester(U2STTrainer):
len_refs, num_ins = 0, 0 len_refs, num_ins = 0, 0
num_frames = 0.0 num_frames = 0.0
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics( metrics = self.compute_translation_metrics(
*batch, bleu_func=bleu_func, fout=fout) *batch, bleu_func=bleu_func, fout=fout)

Loading…
Cancel
Save