#!/usr/bin/env python3 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) ''' Merge training configs into a single inference config. The single inference config is for CLI, which only takes a single config to do inferencing. The trainig configs includes: model config, preprocess config, decode config, vocab file and cmvn file. ''' import yaml import json import os import argparse import math from yacs.config import CfgNode from paddlespeech.s2t.frontend.utility import load_dict from contextlib import redirect_stdout def save(save_path, config): with open(save_path, 'w') as fp: with redirect_stdout(fp): print(config.dump()) def load(save_path): config = CfgNode(new_allowed=True) config.merge_from_file(save_path) return config def load_json(json_path): with open(json_path) as f: json_content = json.load(f) return json_content def remove_config_part(config, key_list): if len(key_list) == 0: return for i in range(len(key_list) -1): config = config[key_list[i]] config.pop(key_list[-1]) def load_cmvn_from_json(cmvn_stats): means = cmvn_stats['mean_stat'] variance = cmvn_stats['var_stat'] count = cmvn_stats['frame_num'] for i in range(len(means)): means[i] /= count variance[i] = variance[i] / count - means[i] * means[i] if variance[i] < 1.0e-20: variance[i] = 1.0e-20 variance[i] = 1.0 / math.sqrt(variance[i]) cmvn_stats = {"mean":means, "istd":variance} return cmvn_stats def merge_configs( conf_path = "conf/conformer.yaml", preprocess_path = "conf/preprocess.yaml", decode_path = "conf/tuning/decode.yaml", vocab_path = "data/vocab.txt", cmvn_path = "data/mean_std.json", save_path = "conf/conformer_infer.yaml", ): # Load the configs config = load(conf_path) decode_config = load(decode_path) vocab_list = load_dict(vocab_path) cmvn_stats = load_json(cmvn_path) if os.path.exists(preprocess_path): preprocess_config = load(preprocess_path) for idx, process in enumerate(preprocess_config["process"]): if process['type'] == "cmvn_json": preprocess_config["process"][idx][ "cmvn_path"] = cmvn_stats break config.preprocess_config = preprocess_config else: cmvn_stats = load_cmvn_from_json(cmvn_stats) config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] config.augmentation_config = '' # Updata the config config.vocab_filepath = vocab_list config.input_dim = config.feat_dim config.output_dim = len(config.vocab_filepath) config.decode = decode_config # Remove some parts of the config if os.path.exists(preprocess_path): remove_train_list = ["train_manifest", "dev_manifest", "test_manifest", "n_epoch", "accum_grad", "global_grad_clip", "optim", "optim_conf", "scheduler", "scheduler_conf", "log_interval", "checkpoint", "shuffle_method", "weight_decay", "ctc_grad_norm_type", "minibatches", "subsampling_factor", "batch_bins", "batch_count", "batch_frames_in", "batch_frames_inout", "batch_frames_out", "sortagrad", "feat_dim", "stride_ms", "window_ms", "batch_size", "maxlen_in", "maxlen_out", ] else: remove_train_list = ["train_manifest", "dev_manifest", "test_manifest", "n_epoch", "accum_grad", "global_grad_clip", "log_interval", "checkpoint", "lr", "lr_decay", "batch_size", "shuffle_method", "weight_decay", "sortagrad", "num_workers", ] for item in remove_train_list: try: remove_config_part(config, [item]) except: print ( item + " " +"can not be removed") # Save the config save(save_path, config) if __name__ == "__main__": parser = argparse.ArgumentParser( prog='Config merge', add_help=True) parser.add_argument( '--cfg_pth', type=str, default = 'conf/transformer.yaml', help='origin config file') parser.add_argument( '--pre_pth', type=str, default= "conf/preprocess.yaml", help='') parser.add_argument( '--dcd_pth', type=str, default= "conf/tuninig/decode.yaml", help='') parser.add_argument( '--vb_pth', type=str, default= "data/lang_char/vocab.txt", help='') parser.add_argument( '--cmvn_pth', type=str, default= "data/mean_std.json", help='') parser.add_argument( '--save_pth', type=str, default= "conf/transformer_infer.yaml", help='') parser_args = parser.parse_args() merge_configs( conf_path = parser_args.cfg_pth, decode_path = parser_args.dcd_pth, preprocess_path = parser_args.pre_pth, vocab_path = parser_args.vb_pth, cmvn_path = parser_args.cmvn_pth, save_path = parser_args.save_pth, )