fix decode json file

pull/852/head
Hui Zhang 3 years ago
parent 9abe33b4bd
commit 9d5eb74066

@@ -18,8 +18,8 @@ from collections import defaultdict
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
-import jsonlines

+import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
@@ -306,7 +306,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref", target, "hyp": result})
+                fout.write({"utt": utt, "ref": target, "hyp": result})
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")

@@ -21,8 +21,8 @@ from collections import OrderedDict
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
-import jsonlines

+import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
@@ -467,7 +467,7 @@ class U2Tester(U2Trainer):
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref", target, "hyp": result})
+                fout.write({"utt": utt, "ref": target, "hyp": result})
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")

@@ -20,8 +20,8 @@ from collections import defaultdict
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
-import jsonlines

+import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
@@ -446,7 +446,7 @@ class U2Tester(U2Trainer):
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref", target, "hyp": result})
+                fout.write({"utt": utt, "ref": target, "hyp": result})
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")

@@ -20,8 +20,8 @@ from collections import defaultdict
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
-import jsonlines

+import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
@@ -480,7 +480,7 @@ class U2STTester(U2STTrainer):
             len_refs += len(target.split())
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref", target, "hyp": result})
+                fout.write({"utt": utt, "ref": target, "hyp": result})
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")

@@ -1,3 +1,4 @@
 # Utils
+* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
 * [espnet utils](https://github.com/espnet/espnet/tree/master/utils)

@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
-
 # Apache 2.0
-
 import argparse
 import codecs
 import sys
@@ -12,15 +10,13 @@ is_python2 = sys.version_info[0] == 2
 def get_parser():
     parser = argparse.ArgumentParser(
         description="filter words in a text file",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
     parser.add_argument(
         "--exclude",
         "-v",
         dest="exclude",
         action="store_true",
-        help="exclude filter words",
-    )
+        help="exclude filter words", )
     parser.add_argument("filt", type=str, help="filter list")
     parser.add_argument("infile", type=str, help="input file")
     return parser
@@ -37,29 +33,20 @@ def filter_file(infile, filt, exclude):
         for line in vocabfile:
             vocab.add(line.strip())

-    sys.stdout = codecs.getwriter("utf-8")(
-        sys.stdout if is_python2 else sys.stdout.buffer
-    )
+    sys.stdout = codecs.getwriter("utf-8")(sys.stdout
+                                           if is_python2 else sys.stdout.buffer)
     with codecs.open(infile, "r", encoding="utf-8") as textfile:
         for line in textfile:
             if exclude:
-                print(
-                    " ".join(
-                        map(
-                            lambda word: word if word not in vocab else "",
-                            line.strip().split(),
-                        )
-                    )
-                )
+                print(" ".join(
+                    map(
+                        lambda word: word if word not in vocab else "",
+                        line.strip().split(), )))
             else:
-                print(
-                    " ".join(
-                        map(
-                            lambda word: word if word in vocab else "<UNK>",
-                            line.strip().split(),
-                        )
-                    )
-                )
+                print(" ".join(
+                    map(
+                        lambda word: word if word in vocab else "<UNK>",
+                        line.strip().split(), )))


 if __name__ == "__main__":
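
These last hunks are pure formatting changes to utils/filt.py; the behavior is unchanged. For readers unfamiliar with the script: by default it replaces every word not found in the filter list with `<UNK>`, and with `--exclude` it blanks out the listed words instead. A small sketch of that per-line mapping (the vocabulary and sentence are made-up examples):

    vocab = {"hello", "world"}
    line = "hello brave new world"

    # Default mode: out-of-vocabulary words become "<UNK>".
    print(" ".join(word if word in vocab else "<UNK>" for word in line.split()))
    # -> hello <UNK> <UNK> world

    # --exclude mode: listed words are replaced by empty strings, which leaves
    # the extra spaces exactly as the original map/lambda version does.
    print(" ".join(word if word not in vocab else "" for word in line.split()))
    # -> " brave new "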
