#!/usr/bin/env python3
# encoding: utf-8
"""Merge Kaldi-style "<utt-id> <value>" text files into one JSON document.

The output has the form::

    {
        "utts": {
            "sample_id1": {...},
            "sample_id2": {...}
        }
    }

All given files are read in lock step, one line at a time, so they must list
the same utterance ids in the same order.
"""
import argparse
import codecs
import json
import logging
import sys
from io import open

try:
    from distutils.util import strtobool
except ImportError:  # distutils is no longer shipped with Python >= 3.12

    def strtobool(val):
        """Convert a truth-value string to 1 or 0 (distutils-compatible)."""
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))


from paddlespeech.s2t.utils.cli_utils import get_commandline_args

PY2 = sys.version_info[0] == 2
# Force UTF-8 on the standard streams regardless of the locale.
sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
                                       if PY2 else sys.stdout.buffer)


# Special types:
def shape(x):
    """Change str to List[int]

    Surrounding whitespace and one optional pair of square brackets are
    ignored.  Raises ``ValueError`` if an element is not an integer
    (including the empty string).

    >>> shape('3,5')
    [3, 5]
    >>> shape(' [3, 5] ')
    [3, 5]
    """
    # x: ' [3, 5] ' -> '3, 5'
    x = x.strip()
    # startswith/endswith are safe on the empty string, unlike x[0] / x[-1].
    if x.startswith("["):
        x = x[1:]
    if x.endswith("]"):
        x = x[:-1]
    return list(map(int, x.split(",")))


def get_parser():
    """Build the command-line argument parser of this script."""
    parser = argparse.ArgumentParser(
        description="Given each file paths with such format as "
        "<key>:<file>:<type>. <type> can be omitted and the default "
        'is "str". e.g. {} '
        "--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
        "--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
        "--output-scps text:data/text shape:data/utt2text_shape:shape "
        "--scps utt2spk:data/utt2spk".format(sys.argv[0]),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument(
        "--input-scps",
        type=str,
        nargs="*",
        action="append",
        default=[],
        help="Json files for the inputs", )
    parser.add_argument(
        "--output-scps",
        type=str,
        nargs="*",
        action="append",
        default=[],
        help="Json files for the outputs", )
    parser.add_argument(
        "--scps",
        type=str,
        nargs="+",
        default=[],
        help="The json files except for the input and outputs", )
    parser.add_argument(
        "--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--allow-one-column",
        type=strtobool,
        default=False,
        help="Allow one column in input scp files. "
        "In this case, the value will be empty string.", )
    parser.add_argument(
        "--out",
        "-O",
        type=str,
        help="The output filename. "
        "If omitted, then output to sys.stdout", )
    return parser


def _parse_scp_specs(key_scps):
    """Parse "<key>:<filepath>[:<type>]" strings into info tuples.

    Returns:
        List[Tuple[str, str, Optional[Callable[[str], Any]], str, str]]:
        ``(key, filepath, type_func, original_spec, type_func_str)`` for
        every spec; ``type_func`` is ``None`` when no type was given.

    Raises:
        RuntimeError: On a malformed spec, an unknown or non-callable type,
            or a key that appears twice in ``key_scps``.
    """
    specs = []
    seen = {}  # key -> the spec string that first used it
    for key_scp in key_scps:
        fields = key_scp.split(":")
        if len(fields) == 2:
            key, scp = fields
            type_func = None
            type_func_str = "none"
        elif len(fields) == 3:
            key, scp, type_func_str = fields
            try:
                # type_func: Callable[[str], Any]
                # e.g. type_func_str = "int" -> type_func = int
                # SECURITY: eval() executes arbitrary code taken from the
                # command line; only run this script with trusted arguments.
                # Module globals are passed so that names such as `shape`
                # resolve and local variables here do not leak in.
                type_func = eval(type_func_str, globals())
            except Exception:
                raise RuntimeError("Unknown type: {}".format(type_func_str))
            if not callable(type_func):
                raise RuntimeError("Unknown type: {}".format(type_func_str))
        else:
            raise RuntimeError(
                "Format <key>:<filepath> "
                "or <key>:<filepath>:<type> "
                "e.g. feat:data/feat.scp "
                "or shape:data/feat.scp:shape: {}".format(key_scp))

        if key in seen:
            raise RuntimeError('The key "{}" is duplicated: {} {}'.format(
                key, seen[key], key_scp))
        seen[key] = key_scp
        specs.append((key, scp, type_func, key_scp, type_func_str))
    return specs


def _parse_value(line, info, nutt, allow_one_column):
    """Extract and convert the value column of one scp line.

    Args:
        line: A non-blank line "<key> <value>" read from ``info[1]``.
        info: The info tuple produced by ``_parse_scp_specs``.
        nutt: 1-based line number, used only in error messages.
        allow_one_column: If true, a line without a value yields "".

    Raises:
        RuntimeError: If the value is missing and not allowed to be.
        Exception: Whatever the type function raises (logged, re-raised).
    """
    _, scp, type_func, _, type_func_str = info
    columns = line.split(None, 1)
    if len(columns) < 2:
        if not allow_one_column:
            raise RuntimeError("Format error {}th line in {}: "
                               ' Expecting "<key> <value>":\n>>> {}'.format(
                                   nutt, scp, line))
        value = ""
    else:
        value = columns[1]
    value = value.rstrip()

    if type_func is not None:
        try:
            # type_func: Callable[[str], Any]
            value = type_func(value)
        except Exception:
            logging.error('"{}" is an invalid function '
                          "for the {} th line in {}: \n>>> {}".format(
                              type_func_str, nutt, scp, line))
            raise
    return value


def main():
    """Read the scp files in lock step and write the merged JSON."""
    parser = get_parser()
    args = parser.parse_args()
    # --scps is a flat list; wrap it so all three options share one shape.
    args.scps = [args.scps]

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
    input_infos = [_parse_scp_specs(group) for group in args.input_scps]
    output_infos = [_parse_scp_specs(group) for group in args.output_scps]
    infos = [_parse_scp_specs(group) for group in args.scps]
    # Flattened in the same order in which the lines are flattened below,
    # so an index into one list addresses the matching item of the other.
    flat_infos = sum(input_infos + output_infos + infos, [])

    opened = []  # every file opened here, closed in the `finally` below

    def open_groups(groups):
        fgroups = []
        for group in groups:
            files = []
            for info in group:
                f = open(info[1], "r", encoding="utf-8")
                opened.append(f)
                files.append(f)
            fgroups.append(files)
        return fgroups

    try:
        # Open scp files
        input_fscps = open_groups(input_infos)
        output_fscps = open_groups(output_infos)
        fscps = open_groups(infos)

        if args.out is None:
            out = sys.stdout
        else:
            out = open(args.out, "w", encoding="utf-8")
            opened.append(out)

        # To reduce memory usage, read the input text files line by line
        # and write one JSON element per sample.
        out.write('{\n    "utts": {\n')
        n_entries = 0
        while True:
            nutt = n_entries + 1  # 1-based number of the line being read
            # List[List[str]]
            input_lines = [[f.readline() for f in fl] for fl in input_fscps]
            output_lines = [[f.readline() for f in fl]
                            for fl in output_fscps]
            lines = [[f.readline() for f in fl] for fl in fscps]

            # Get the first line
            concat = sum(input_lines + output_lines + lines, [])
            if not concat:
                break
            first = concat[0]

            # Sanity check: Must be sorted by the first column and have
            # same keys
            for count, line in enumerate(concat):
                if line == "" or first == "":
                    if line != first:
                        raise RuntimeError(
                            "The number of lines mismatch "
                            'between: "{}" and "{}"'.format(
                                flat_infos[0][1], flat_infos[count][1]))
                elif line.split()[0] != first.split()[0]:
                    raise RuntimeError(
                        "The keys are mismatch at {}th line "
                        'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
                            nutt,
                            flat_infos[0][1],
                            flat_infos[count][1],
                            first.rstrip(),
                            line.rstrip(), ))

            # The end of file
            if first == "":
                if n_entries:
                    out.write("\n")
                break
            if n_entries:
                out.write(",\n")

            entry = {}
            for inout, _lines, _infos in [
                ("input", input_lines, input_infos),
                ("output", output_lines, output_infos),
                ("other", lines, infos),
            ]:
                lis = []
                for idx, (line_list, info_list) in enumerate(
                        zip(_lines, _infos), 1):
                    if inout == "input":
                        d = {"name": "input{}".format(idx)}
                    elif inout == "output":
                        d = {"name": "target{}".format(idx)}
                    else:
                        d = {}

                    # info_list: List[Tuple[str, str, Callable]]
                    # line_list: List[str]
                    for line, info in zip(line_list, info_list):
                        d[info[0]] = _parse_value(line, info, nutt,
                                                  args.allow_one_column)
                    lis.append(d)

                if inout != "other":
                    entry[inout] = lis
                else:
                    # "other" always has exactly one group (see args.scps
                    # above); its keys go directly into the entry.
                    entry.update(lis[0])

            entry = json.dumps(
                entry,
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ": "))
            # Add indent
            indent = "    " * 2
            entry = ("\n" + indent).join(entry.split("\n"))

            uttid = first.split()[0]
            # json.dumps escapes quotes/backslashes in the id, keeping the
            # output valid JSON; ordinary ids are written unchanged.
            out.write('        {}: {}'.format(
                json.dumps(uttid, ensure_ascii=False), entry))
            n_entries += 1

        out.write("    }\n}\n")

        logging.info("{} entries in {}".format(n_entries, out.name))
    finally:
        for f in opened:
            f.close()


if __name__ == "__main__":
    main()