fix decode json file

pull/852/head
Hui Zhang 3 years ago
parent 9abe33b4bd
commit 9d5eb74066

@ -18,8 +18,8 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -306,7 +306,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")

@ -21,8 +21,8 @@ from collections import OrderedDict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -467,7 +467,7 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")

@ -20,8 +20,8 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -446,7 +446,7 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")

@ -20,8 +20,8 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -480,7 +480,7 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split()) len_refs += len(target.split())
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")

@ -1,3 +1,4 @@
# Utils # Utils
* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils) * [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
* [espnet utils)(https://github.com/espnet/espnet/tree/master/utils)

@ -1,7 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Apache 2.0 # Apache 2.0
import argparse import argparse
import codecs import codecs
import sys import sys
@ -12,15 +10,13 @@ is_python2 = sys.version_info[0] == 2
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="filter words in a text file", description="filter words in a text file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
)
parser.add_argument( parser.add_argument(
"--exclude", "--exclude",
"-v", "-v",
dest="exclude", dest="exclude",
action="store_true", action="store_true",
help="exclude filter words", help="exclude filter words", )
)
parser.add_argument("filt", type=str, help="filter list") parser.add_argument("filt", type=str, help="filter list")
parser.add_argument("infile", type=str, help="input file") parser.add_argument("infile", type=str, help="input file")
return parser return parser
@ -37,29 +33,20 @@ def filter_file(infile, filt, exclude):
for line in vocabfile: for line in vocabfile:
vocab.add(line.strip()) vocab.add(line.strip())
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(sys.stdout
sys.stdout if is_python2 else sys.stdout.buffer if is_python2 else sys.stdout.buffer)
)
with codecs.open(infile, "r", encoding="utf-8") as textfile: with codecs.open(infile, "r", encoding="utf-8") as textfile:
for line in textfile: for line in textfile:
if exclude: if exclude:
print( print(" ".join(
" ".join( map(
map( lambda word: word if word not in vocab else "",
lambda word: word if word not in vocab else "", line.strip().split(), )))
line.strip().split(),
)
)
)
else: else:
print( print(" ".join(
" ".join( map(
map( lambda word: word if word in vocab else "<UNK>",
lambda word: word if word in vocab else "<UNK>", line.strip().split(), )))
line.strip().split(),
)
)
)
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save