You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
290 lines
9.7 KiB
290 lines
9.7 KiB
#!/usr/bin/env python3
|
|
# encoding: utf-8
|
|
import argparse
|
|
import codecs
|
|
import json
|
|
import logging
|
|
import sys
|
|
from distutils.util import strtobool
|
|
from io import open
|
|
|
|
from paddlespeech.s2t.utils.cli_utils import get_commandline_args
|
|
|
|
# Force UTF-8 on the standard streams. Python 2 exposes the raw byte streams
# directly as sys.stdin/sys.stdout, while Python 3 keeps them under
# ``.buffer``; either way the byte stream is wrapped in a UTF-8 codec.
PY2 = sys.version_info[0] == 2
if PY2:
    sys.stdin = codecs.getreader("utf-8")(sys.stdin)
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout)
else:
    sys.stdin = codecs.getreader("utf-8")(sys.stdin.buffer)
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
|
|
|
|
|
|
# Special types: callables that may be named as <type> in a
# "<key>:<file>:<type>" spec (the type name is resolved with eval() in the
# __main__ section below).


def shape(x):
    """Change str to List[int]

    Accepts a comma-separated list of integers, optionally surrounded by
    whitespace and by a single pair of square brackets. An empty body
    (e.g. '' or '[]') yields an empty list.

    >>> shape('3,5')
    [3, 5]
    >>> shape(' [3, 5] ')
    [3, 5]
    >>> shape('[]')
    []
    """
    # x: ' [3, 5] ' -> '3, 5'
    x = x.strip()
    if x.startswith("["):
        x = x[1:]
    if x.endswith("]"):
        x = x[:-1]

    # Without this guard, '' would raise IndexError (x[0]) and '[]' would
    # raise ValueError (int('')); both denote an empty shape.
    if not x.strip():
        return []

    return [int(v) for v in x.split(",")]
|
|
|
|
|
|
def _strtobool(val):
    """Convert a truth-value string to 1 (true) or 0 (false).

    Drop-in replacement for ``distutils.util.strtobool`` so that the parser
    does not depend on distutils (deprecated by PEP 632, removed in Python
    3.12). True values are y, yes, t, true, on and 1; false values are n, no,
    f, false, off and 0 (case-insensitive).

    Raises:
        ValueError: if ``val`` is not a recognised truth value.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError("invalid truth value %r" % (val,))


def get_parser():
    """Build the command-line parser for merging scp files into one JSON.

    Returns:
        argparse.ArgumentParser. After parsing, ``input_scps`` and
        ``output_scps`` are List[List[str]] (one inner list per occurrence
        of the flag, via action="append"), while ``scps`` is a flat
        List[str].
    """
    parser = argparse.ArgumentParser(
        description="Given each file path in the format "
        "<key>:<file>:<type>. <type> can be omitted and the default "
        'is "str". e.g. {} '
        "--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
        "--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
        "--output-scps text:data/text shape:data/utt2text_shape:shape "
        "--scps utt2spk:data/utt2spk".format(sys.argv[0]),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument(
        "--input-scps",
        type=str,
        nargs="*",
        action="append",
        default=[],
        help="The scp files for the inputs", )
    parser.add_argument(
        "--output-scps",
        type=str,
        nargs="*",
        action="append",
        default=[],
        help="The scp files for the outputs", )
    parser.add_argument(
        "--scps",
        type=str,
        nargs="+",
        default=[],
        help="The scp files except for the inputs and outputs", )
    parser.add_argument(
        "--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--allow-one-column",
        type=_strtobool,
        default=False,
        help="Allow one column in input scp files. "
        "In this case, the value will be empty string.", )
    parser.add_argument(
        "--out",
        "-O",
        type=str,
        help="The output filename. "
        "If omitted, then output to sys.stdout", )
    return parser
|
|
|
|
|
|
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    # --scps is parsed as a flat List[str] (no action="append"); wrap it so
    # it has the same List[List[str]] shape as --input-scps/--output-scps.
    args.scps = [args.scps]

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # Parse every "<key>:<file>[:<type>]" spec into a tuple
    # (key, scp_path, type_func, raw_spec, type_func_str).
    # List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
    input_infos = []
    output_infos = []
    infos = []
    for lis_list, key_scps_list in [
        (input_infos, args.input_scps),
        (output_infos, args.output_scps),
        (infos, args.scps),
    ]:
        for key_scps in key_scps_list:
            lis = []
            for key_scp in key_scps:
                sps = key_scp.split(":")
                if len(sps) == 2:
                    key, scp = sps
                    type_func = None
                    type_func_str = "none"
                elif len(sps) == 3:
                    key, scp, type_func_str = sps
                    try:
                        # type_func: Callable[[str], Any]
                        # e.g. type_func_str = "int" -> type_func = int
                        # NOTE(security): eval() executes an arbitrary
                        # expression taken from the command line; this is
                        # only acceptable because the arguments are trusted.
                        type_func = eval(type_func_str)
                    except Exception:
                        raise RuntimeError(
                            "Unknown type: {}".format(type_func_str))

                    if not callable(type_func):
                        raise RuntimeError(
                            "Unknown type: {}".format(type_func_str))
                else:
                    raise RuntimeError(
                        "Format <key>:<filepath> "
                        "or <key>:<filepath>:<type> "
                        "e.g. feat:data/feat.scp "
                        "or shape:data/feat.scp:shape: {}".format(key_scp))

                # Keys must be unique within one group because they become
                # keys of the same JSON object.
                for item in lis:
                    if key == item[0]:
                        raise RuntimeError('The key "{}" is duplicated: {} {}'.
                                           format(key, item[3], key_scp))

                lis.append((key, scp, type_func, key_scp, type_func_str))
            lis_list.append(lis)

    # Open scp files
    input_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
                   for il in input_infos]
    output_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
                    for il in output_infos]
    fscps = [[open(i[1], "r", encoding="utf-8") for i in il] for il in infos]

    # Note(kamo): What is done here?
    # The final goal is to create a JSON file like:
    # {
    #     "utts": {
    #         "sample_id1": {(omitted)},
    #         "sample_id2": {(omitted)},
    #         ....
    #     }
    # }
    #
    # To reduce memory usage, the input text files are read line by line
    # and one JSON element is written per sample.
    if args.out is None:
        out = sys.stdout
    else:
        out = open(args.out, "w", encoding="utf-8")
    out.write('{\n "utts": {\n')

    # nutt: 1-based index of the line currently being read from every scp.
    nutt = 0
    while True:
        nutt += 1
        # List[List[str]]
        input_lines = [[f.readline() for f in fl] for fl in input_fscps]
        output_lines = [[f.readline() for f in fl] for fl in output_fscps]
        lines = [[f.readline() for f in fl] for fl in fscps]

        # Get the first line
        concat = sum(input_lines + output_lines + lines, [])
        if len(concat) == 0:
            # No scp file was given at all.
            break
        first = concat[0]

        # Sanity check: Must be sorted by the first column and have same keys.
        # readline() returns "" at EOF, so all files must reach EOF together.
        count = 0
        for ls_list in (input_lines, output_lines, lines):
            for ls in ls_list:
                for line in ls:
                    if line == "" or first == "":
                        if line != first:
                            concat = sum(input_infos + output_infos + infos,
                                         [])
                            raise RuntimeError("The number of lines mismatch "
                                               'between: "{}" and "{}"'.format(
                                                   concat[0][1],
                                                   concat[count][1]))

                    elif line.split()[0] != first.split()[0]:
                        concat = sum(input_infos + output_infos + infos, [])
                        raise RuntimeError(
                            "The keys are mismatch at {}th line "
                            'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
                                nutt,
                                concat[0][1],
                                concat[count][1],
                                first.rstrip(),
                                line.rstrip(), ))
                    count += 1

        # The end of file
        if first == "":
            if nutt != 1:
                out.write("\n")
            break
        if nutt != 1:
            out.write(",\n")

        entry = {}
        for inout, _lines, _infos in [
            ("input", input_lines, input_infos),
            ("output", output_lines, output_infos),
            ("other", lines, infos),
        ]:

            lis = []
            for idx, (line_list, info_list) in enumerate(
                    zip(_lines, _infos), 1):
                if inout == "input":
                    d = {"name": "input{}".format(idx)}
                elif inout == "output":
                    d = {"name": "target{}".format(idx)}
                else:
                    d = {}

                # info_list: List[Tuple[str, str, Callable]]
                # line_list: List[str]
                for line, info in zip(line_list, info_list):
                    sps = line.split(None, 1)
                    if len(sps) < 2:
                        if not args.allow_one_column:
                            raise RuntimeError(
                                "Format error {}th line in {}: "
                                ' Expecting "<key> <value>":\n>>> {}'.format(
                                    nutt, info[1], line))
                        uttid = sps[0]
                        value = ""
                    else:
                        uttid, value = sps

                    key = info[0]
                    type_func = info[2]
                    value = value.rstrip()

                    if type_func is not None:
                        try:
                            # type_func: Callable[[str], Any]
                            value = type_func(value)
                        except Exception:
                            logging.error(
                                '"{}" is an invalid function '
                                "for the {} th line in {}: \n>>> {}".format(
                                    info[4], nutt, info[1], line))
                            raise

                    d[key] = value
                lis.append(d)

            if inout != "other":
                entry[inout] = lis
            else:
                # infos always holds exactly one group (args.scps was wrapped
                # above), so its keys are merged into the top level.
                entry.update(lis[0])

        entry = json.dumps(
            entry,
            indent=4,
            ensure_ascii=False,
            sort_keys=True,
            separators=(",", ": "))
        # Add indent
        indent = " " * 2
        entry = ("\n" + indent).join(entry.split("\n"))

        uttid = first.split()[0]
        # json.dumps quotes and escapes the id, so ids containing '"' or '\'
        # still yield valid JSON (output is unchanged for ordinary ids).
        out.write(' {}: {}'.format(
            json.dumps(uttid, ensure_ascii=False), entry))

    out.write(" }\n}\n")

    # nutt was also incremented for the final iteration that hit EOF (or
    # found no files), so the number of written entries is nutt - 1.
    logging.info("{} entries in {}".format(nutt - 1, out.name))

    # Release the scp handles and the output file (sys.stdout stays open).
    for fl in input_fscps + output_fscps + fscps:
        for f in fl:
            f.close()
    if out is not sys.stdout:
        out.close()
|