You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
153 lines
4.9 KiB
153 lines
4.9 KiB
#!/usr/bin/env python3
|
|
# encoding: utf-8
|
|
# Copyright 2018 Nagoya University (Tomoki Hayashi)
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import argparse
|
|
import codecs
|
|
import json
|
|
import logging
|
|
import sys
|
|
from distutils.util import strtobool
|
|
|
|
from espnet.utils.cli_utils import get_commandline_args
|
|
|
|
# True when running under a Python 2 interpreter; the output code uses it to
# pick the stream (sys.stdout vs. sys.stdout.buffer) to wrap in a UTF-8 writer.
is_python2 = sys.version_info.major == 2
|
|
|
|
|
|
def get_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser with one or more positional ``jsons``
        paths, ``-i/--is-input`` (converted with ``strtobool``; defaults to
        ``True``) and ``-V/--verbose`` (int, defaults to ``0``).
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="add multiple json values to an input or output value",
    )
    # The first path is the base json; the remaining ones supply the values
    # that get added to it.
    arg_parser.add_argument("jsons", nargs="+", type=str, help="json files")
    arg_parser.add_argument(
        "-i", "--is-input",
        type=strtobool,
        default=True,
        help="If true, add to input. If false, add to output",
    )
    arg_parser.add_argument(
        "--verbose", "-V", type=int, default=0, help="Verbose option"
    )
    return arg_parser
|
|
|
|
|
|
def load_and_intersect(paths):
    """Load json files and compute the utterance ids that all of them share.

    Args:
        paths (list of str): Paths to json files, each with a top-level
            ``"utts"`` dict keyed by utterance id.

    Returns:
        tuple: ``(js, keys)`` where ``js`` is the list of loaded json dicts
        and ``keys`` is the set of utterance ids present in every file.
        Loading stops at the first file that makes ``keys`` empty; ``js`` is
        then incomplete, but nothing is looked up in it for an empty ``keys``.
    """
    js = []
    # None means "no file read yet"; it must stay distinct from an empty
    # intersection, otherwise an empty first file would be ignored.
    common = None
    for path in paths:
        with codecs.open(path, "r", encoding="utf-8") as f:
            j = json.load(f)
        ks = set(j["utts"])
        logging.info("%s: has %d utterances", path, len(ks))
        common = ks if common is None else common & ks
        if not common:
            logging.warning("Empty intersection")
            break
        js.append(j)
    return js, common if common is not None else set()


def build_entry(orgdic, adddic, is_input):
    """Return a new utterance entry with ``adddic`` appended as one more stream.

    Args:
        orgdic (dict): Entry from the base json. Its ``"input"`` and
            ``"output"`` lists and optional ``"utt2spk"``/``"lang"`` values
            are read; the dict itself is not modified.
        adddic (dict): Merged values from the additional json(s) for the same
            utterance.
        is_input (bool): If true, append to ``"input"`` (``idim``/``ilen``
            become ``shape``, name ``inputN``); otherwise append to
            ``"output"`` (``odim``/``olen`` become ``shape``, name ``targetN``).

    Returns:
        dict: Entry with ``"input"``, ``"output"``, ``"utt2spk"`` (``""`` when
        missing from ``orgdic``) and, if present in ``orgdic``, ``"lang"``.
    """
    if is_input:
        dim_key, len_key, side, prefix = "idim", "ilen", "input", "input"
    else:
        dim_key, len_key, side, prefix = "odim", "olen", "output", "target"

    added = {}
    if dim_key in adddic:
        # shape is [length, dim] when a length is given, otherwise just [dim]
        if len_key in adddic:
            added["shape"] = [int(adddic[len_key]), int(adddic[dim_key])]
        else:
            added["shape"] = [int(adddic[dim_key])]
    # Copy every other value verbatim; the raw dim/len keys are replaced by
    # "shape". A "shape" key in adddic overrides the computed one, as before.
    for key, value in adddic.items():
        if key not in (dim_key, len_key):
            added[key] = value
    existing = orgdic[side]
    # Streams are numbered from 1, so the new one follows the existing ones.
    added["name"] = "%s%d" % (prefix, len(existing) + 1)

    entry = {
        "input": orgdic["input"],
        "output": orgdic["output"],
        "utt2spk": orgdic.get("utt2spk", ""),
    }
    # Build a new list so the loaded base json is left untouched.
    entry[side] = list(existing) + [added]
    if "lang" in orgdic:
        entry["lang"] = orgdic["lang"]
    return entry


def main():
    """Merge the jsons named on the command line and print the result.

    The first json is the base. For every utterance id present in all given
    jsons, the values from the remaining jsons are merged (later files win on
    duplicate keys) and appended as one extra input or output stream. The
    merged json is written to stdout as UTF-8.
    """
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    level = logging.INFO if args.verbose > 0 else logging.WARNING
    logging.basicConfig(level=level, format=logfmt)
    logging.info(get_commandline_args())

    # js[0] is the base json and js[1:] are the jsons to add, so one json alone
    # has nothing to add.
    if len(args.jsons) < 2:
        parser.error("at least two json files are required")

    js, intersec_ks = load_and_intersect(args.jsons)
    logging.info("new json has %d utterances", len(intersec_ks))

    new_dic = {}
    for key_id in intersec_ks:
        # Merge the additional jsons' values for this utterance; successive
        # updates mean later files win on duplicate keys.
        adddic = {}
        for j in js[1:]:
            adddic.update(j["utts"][key_id])
        new_dic[key_id] = build_entry(
            js[0]["utts"][key_id], adddic, args.is_input
        )

    # ensure_ascii=False writes non-ASCII text as-is rather than as \uXXXX
    # escapes; stdout is wrapped in a UTF-8 writer so the output is UTF-8
    # encoded regardless of the locale's default encoding.
    jsonstring = json.dumps(
        {"utts": new_dic},
        indent=4,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ": "),
    )
    sys.stdout = codecs.getwriter("utf-8")(
        sys.stdout if is_python2 else sys.stdout.buffer
    )
    print(jsonstring)


if __name__ == "__main__":
    main()
|