PaddleSpeech/utils/merge_scp2json.py

291 lines
9.7 KiB

#!/usr/bin/env python3
# encoding: utf-8
import argparse
import codecs
import json
import logging
import sys
from io import open
from distutils.util import strtobool
from paddlespeech.s2t.utils.cli_utils import get_commandline_args
# True when running under Python 2; selects which stream object is wrapped below.
PY2 = sys.version_info[0] == 2
# Rebind the standard streams at import time to explicit UTF-8 codec
# reader/writer objects. On Python 2 the stream objects themselves are wrapped;
# on Python 3 their binary `.buffer` is wrapped. When --out is omitted, the
# merged JSON is written through this rebound sys.stdout.
sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout if PY2 else sys.stdout.buffer)
# Special types:
def shape(x):
    """Parse a shape string into a list of ints.

    Surrounding whitespace and an optional enclosing pair of square brackets
    are ignored; the remaining text is split on commas and each piece is
    converted with ``int``. An empty shape (``""`` or ``"[]"``) gives ``[]``
    instead of raising IndexError/ValueError.

    >>> shape('3,5')
    [3, 5]
    >>> shape(' [3, 5] ')
    [3, 5]
    >>> shape('[]')
    []

    Raises:
        ValueError: if a comma-separated piece is not an integer.
    """
    # x: ' [3, 5] ' -> '3, 5'
    x = x.strip()
    # startswith/endswith (unlike x[0] / x[-1]) do not raise on an empty string.
    if x.startswith("["):
        x = x[1:]
    if x.endswith("]"):
        x = x[:-1]
    if not x.strip():
        return []
    # int() tolerates the spaces left around each piece, e.g. ' 5 '.
    return [int(dim) for dim in x.split(",")]
def _str2bool(value):
    """Parse a command-line truth value into ``True`` or ``False``.

    Accepts, case-insensitively, the same spellings as
    ``distutils.util.strtobool``: y, yes, t, true, on, 1 and
    n, no, f, false, off, 0. ``distutils`` is deprecated (PEP 632) and
    removed in Python 3.12, so the conversion is kept local.

    Raises:
        ValueError: for any other string; argparse reports it as an invalid
            argument value.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value {!r}".format(value))


def get_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with:
        --input-scps / --output-scps (each occurrence appends one group of
        ``<key>:<file>[:<type>]`` specs), --scps (a single group of such
        specs), --verbose/-V, --allow-one-column and --out/-O.
    """
    parser = argparse.ArgumentParser(
        description="Given file paths in the format "
        "<key>:<file>:<type>. <type> can be omitted and the default "
        'is "str". e.g. {} '
        "--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
        "--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
        "--output-scps text:data/text shape:data/utt2text_shape:shape "
        "--scps utt2spk:data/utt2spk".format(sys.argv[0]),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    # action="append" + nargs="*": every occurrence becomes one group (one input).
    parser.add_argument(
        "--input-scps",
        type=str,
        nargs="*",
        action="append",
        default=[],
        help="Scp files for the inputs, given as <key>:<file>[:<type>]. "
        "Repeat the option to add another input.", )
    parser.add_argument(
        "--output-scps",
        type=str,
        nargs="*",
        action="append",
        default=[],
        help="Scp files for the outputs (targets), given as "
        "<key>:<file>[:<type>]. Repeat the option to add another output.", )
    # A single group only (no "append").
    parser.add_argument(
        "--scps",
        type=str,
        nargs="+",
        default=[],
        help="Scp files for entries other than the inputs and outputs, "
        "given as <key>:<file>[:<type>]", )
    parser.add_argument(
        "--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--allow-one-column",
        type=_str2bool,
        default=False,
        help="Allow one column in input scp files. "
        "In this case, the value will be empty string.", )
    parser.add_argument(
        "--out",
        "-O",
        type=str,
        help="The output filename. "
        "If omitted, then output to sys.stdout", )
    return parser
def _parse_scp_specs(key_scps):
    """Parse one group of ``<key>:<file>[:<type>]`` command-line specs.

    Args:
        key_scps: The strings given to one ``--input-scps`` /
            ``--output-scps`` occurrence, or to ``--scps``.

    Returns:
        A list of ``(key, scp_path, type_func, key_scp, type_func_str)``
        tuples in the given order. ``type_func`` is ``None`` (and
        ``type_func_str`` is ``"none"``) when no type was given.

    Raises:
        RuntimeError: if a spec is malformed, its type does not evaluate to a
            callable, or a key appears twice within the group.
    """
    lis = []
    for key_scp in key_scps:
        sps = key_scp.split(":")
        if len(sps) == 2:
            key, scp = sps
            type_func = None
            type_func_str = "none"
        elif len(sps) == 3:
            key, scp, type_func_str = sps
            try:
                # type_func: Callable[[str], Any]
                # e.g. type_func_str = "int" -> type_func = int,
                #      type_func_str = "shape" -> type_func = shape
                # Evaluated against the module globals, as the former
                # top-level script did.
                # NOTE(review): eval() executes arbitrary command-line text;
                # only run this script with trusted arguments.
                type_func = eval(type_func_str, globals())
            except Exception:
                raise RuntimeError("Unknown type: {}".format(type_func_str))
            if not callable(type_func):
                raise RuntimeError("Unknown type: {}".format(type_func_str))
        else:
            raise RuntimeError(
                "Format <key>:<filepath> "
                "or <key>:<filepath>:<type> "
                "e.g. feat:data/feat.scp "
                "or shape:data/feat.scp:shape: {}".format(key_scp))
        for item in lis:
            if key == item[0]:
                raise RuntimeError('The key "{}" is duplicated: {} {}'.format(
                    key, item[3], key_scp))
        lis.append((key, scp, type_func, key_scp, type_func_str))
    return lis


def _open_scps(stack, info_groups):
    """Open every scp file of ``info_groups`` as UTF-8 text, keeping the
    group nesting; each file is registered with ``stack`` so it gets closed."""
    return [[
        stack.enter_context(open(info[1], "r", encoding="utf-8"))
        for info in group
    ] for group in info_groups]


def _check_keys(nutt, first, concat, flat_infos):
    """Check that the current (``nutt``-th) lines of all scp files agree.

    Args:
        nutt: 1-based line number, used in error messages.
        first: ``concat[0]``, the reference line.
        concat: The current line of every file, flattened in
            input/output/other order ("" means EOF).
        flat_infos: The spec tuples aligned index-by-index with ``concat``.

    Raises:
        RuntimeError: if some files hit EOF before others, a line is blank,
            or the first-column keys differ.
    """
    for count, line in enumerate(concat):
        if line == "" or first == "":
            # readline() returns "" only at EOF: all files must end together.
            if line != first:
                raise RuntimeError("The number of lines mismatch "
                                   'between: "{}" and "{}"'.format(
                                       flat_infos[0][1], flat_infos[count][1]))
        elif not line.strip():
            # A whitespace-only line has no key; report it instead of
            # failing with an IndexError on line.split()[0].
            raise RuntimeError('Empty line at {}th line in "{}"'.format(
                nutt, flat_infos[count][1]))
        elif line.split()[0] != first.split()[0]:
            raise RuntimeError(
                "The keys are mismatch at {}th line "
                'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
                    nutt,
                    flat_infos[0][1],
                    flat_infos[count][1],
                    first.rstrip(),
                    line.rstrip(), ))


def _make_entry(nutt, line_groups, info_groups, allow_one_column):
    """Build the dict for one utterance from its current scp lines.

    Args:
        nutt: 1-based line number, used in error messages.
        line_groups: ``(input_lines, output_lines, lines)``; each is a list
            of groups, each group holding the current line of every file.
        info_groups: ``(input_infos, output_infos, infos)`` with the spec
            tuples matching ``line_groups``.
        allow_one_column: If true, a line holding only ``<key>`` yields the
            value ``""`` instead of an error.

    Returns:
        A dict with ``"input"`` and ``"output"`` lists, where the k-th input
        group becomes ``{"name": "input<k>", <key>: <value>, ...}`` and the
        k-th output group ``{"name": "target<k>", ...}``; the keys of the
        "other" group are merged into the top level.

    Raises:
        RuntimeError: on a one-column line when one column is not allowed.
        Exception: whatever a type function raises (logged, then re-raised).
    """
    entry = {}
    for inout, _lines, _infos in zip(("input", "output", "other"),
                                     line_groups, info_groups):
        lis = []
        for idx, (line_list, info_list) in enumerate(zip(_lines, _infos), 1):
            if inout == "input":
                d = {"name": "input{}".format(idx)}
            elif inout == "output":
                d = {"name": "target{}".format(idx)}
            else:
                d = {}
            # info_list: List[Tuple[str, str, Callable]]
            # line_list: List[str]
            for line, info in zip(line_list, info_list):
                sps = line.split(None, 1)
                if len(sps) < 2:
                    if not allow_one_column:
                        raise RuntimeError(
                            "Format error {}th line in {}: "
                            ' Expecting "<key> <value>":\n>>> {}'.format(
                                nutt, info[1], line))
                    value = ""
                else:
                    value = sps[1]
                key = info[0]
                type_func = info[2]
                value = value.rstrip()
                if type_func is not None:
                    try:
                        # type_func: Callable[[str], Any]
                        value = type_func(value)
                    except Exception:
                        logging.error(
                            '"{}" is an invalid function '
                            "for the {} th line in {}: \n>>> {}".format(
                                info[4], nutt, info[1], line))
                        raise
                d[key] = value
            lis.append(d)
        if inout != "other":
            entry[inout] = lis
        else:
            # The "other" specs form a single group; merge it into the top level.
            entry.update(lis[0])
    return entry


def main():
    """Merge the scp files named on the command line into one JSON document.

    Reads all scp files line by line in lock step, checks that they share
    the same keys in the same order, and writes
    ``{"utts": {<uttid>: <entry>, ...}}`` to --out (or sys.stdout) one
    utterance at a time.
    """
    from contextlib import ExitStack

    parser = get_parser()
    args = parser.parse_args()
    # --scps is a single group; wrap it so all three options are List[List[str]].
    args.scps = [args.scps]

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
    input_infos = [_parse_scp_specs(key_scps) for key_scps in args.input_scps]
    output_infos = [
        _parse_scp_specs(key_scps) for key_scps in args.output_scps
    ]
    infos = [_parse_scp_specs(key_scps) for key_scps in args.scps]
    # All spec tuples, flattened in the same order as the lines in `concat`.
    flat_infos = sum(input_infos + output_infos + infos, [])

    # Note(kamo): What is done here?
    # The final goal is creating a JSON file such as:
    # {
    #     "utts": {
    #         "sample_id1": {(omitted)},
    #         "sample_id2": {(omitted)},
    #         ....
    #     }
    # }
    #
    # To reduce memory usage, the input text files are read one line at a
    # time and each sample's JSON element is written as soon as it is built.
    with ExitStack() as stack:
        # Every opened file (scp files and --out) is closed on exit, also on error.
        input_fscps = _open_scps(stack, input_infos)
        output_fscps = _open_scps(stack, output_infos)
        fscps = _open_scps(stack, infos)

        if args.out is None:
            out = sys.stdout
        else:
            out = stack.enter_context(open(args.out, "w", encoding="utf-8"))

        out.write('{\n "utts": {\n')
        nutt = 0
        while True:
            nutt += 1
            # List[List[str]]: the nutt-th line of every file ("" at EOF).
            input_lines = [[f.readline() for f in fl] for fl in input_fscps]
            output_lines = [[f.readline() for f in fl] for fl in output_fscps]
            lines = [[f.readline() for f in fl] for fl in fscps]

            concat = sum(input_lines + output_lines + lines, [])
            if len(concat) == 0:
                # No scp files were given at all.
                break
            first = concat[0]
            # Sanity check: Must be sorted by the first column and have same keys
            _check_keys(nutt, first, concat, flat_infos)

            # The end of file
            if first == "":
                if nutt != 1:
                    out.write("\n")
                break
            if nutt != 1:
                out.write(",\n")

            entry = _make_entry(nutt, (input_lines, output_lines, lines),
                                (input_infos, output_infos, infos),
                                args.allow_one_column)
            entry = json.dumps(
                entry,
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ": "))
            # Add indent
            indent = " " * 2
            entry = ("\n" + indent).join(entry.split("\n"))
            uttid = first.split()[0]
            # json.dumps quotes and escapes the key, so ids containing double
            # quotes or backslashes still produce valid JSON.
            out.write(' {}: {}'.format(
                json.dumps(uttid, ensure_ascii=False), entry))
        out.write(" }\n}\n")
        # nutt was incremented once more for the final EOF read.
        nentries = nutt - 1
        logging.info("{} entries in {}".format(nentries, out.name))


if __name__ == "__main__":
    main()