PaddleSpeech/utils/generate_infer_yaml.py

#!/usr/bin/env python3
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
'''
Merge training configs into a single inference config.
The single inference config is for the CLI, which only takes a single config to do inference.
The training configs include: the model config, preprocess config, decode config, vocab file and cmvn file.
'''
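# Example invocation (a sketch; the paths are illustrative, the flags are the
# argparse options defined under __main__ below):
#
#   python utils/generate_infer_yaml.py \
#       --cfg_pth conf/conformer.yaml \
#       --pre_pth conf/preprocess.yaml \
#       --dcd_pth conf/tuning/decode.yaml \
#       --vb_pth data/lang_char/vocab.txt \
#       --cmvn_pth data/mean_std.json \
#       --save_pth conf/conformer_infer.yaml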
import argparse
import json
import math
import os
from contextlib import redirect_stdout
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.utility import load_dict


def save(save_path, config):
    with open(save_path, 'w') as fp:
        with redirect_stdout(fp):
            print(config.dump())


def load(save_path):
    config = CfgNode(new_allowed=True)
    config.merge_from_file(save_path)
    return config


def load_json(json_path):
    with open(json_path) as f:
        json_content = json.load(f)
    return json_content


def remove_config_part(config, key_list):
    if len(key_list) == 0:
        return
    for i in range(len(key_list) - 1):
        config = config[key_list[i]]
    config.pop(key_list[-1])


def load_cmvn_from_json(cmvn_stats):
    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn_stats = {"mean": means, "istd": variance}
    return cmvn_stats
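

# Worked example for load_cmvn_from_json (hypothetical single-dimension stats):
# {"mean_stat": [10.0], "var_stat": [60.0], "frame_num": 10} yields
# mean = 10.0 / 10 = 1.0, var = 60.0 / 10 - 1.0 ** 2 = 5.0, and
# istd = 1 / math.sqrt(5.0) ≈ 0.447, i.e. {"mean": [1.0], "istd": [0.447...]}.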


def merge_configs(
        conf_path="conf/conformer.yaml",
        preprocess_path="conf/preprocess.yaml",
        decode_path="conf/tuning/decode.yaml",
        vocab_path="data/vocab.txt",
        cmvn_path="data/mean_std.json",
        save_path="conf/conformer_infer.yaml", ):
    # Load the configs
    config = load(conf_path)
    decode_config = load(decode_path)
    vocab_list = load_dict(vocab_path)

    # If the kaldi feature is used, do not load the cmvn file
    if cmvn_path.split(".")[-1] == 'json':
        cmvn_stats = load_json(cmvn_path)
        if os.path.exists(preprocess_path):
            preprocess_config = load(preprocess_path)
            for idx, process in enumerate(preprocess_config["process"]):
                if process['type'] == "cmvn_json":
                    preprocess_config["process"][idx]["cmvn_path"] = cmvn_stats
                    break

            config.preprocess_config = preprocess_config
        else:
            cmvn_stats = load_cmvn_from_json(cmvn_stats)
            config.mean_std_filepath = [{"cmvn_stats": cmvn_stats}]
            config.augmentation_config = ''
    # the cmvn file ends with .ark
    else:
        config.cmvn_path = cmvn_path

    # Update the config
    config.vocab_filepath = vocab_list
    config.input_dim = config.feat_dim
    config.output_dim = len(config.vocab_filepath)
    config.decode = decode_config
    # Remove some parts of the config
    if os.path.exists(preprocess_path):
        remove_train_list = [
            "train_manifest",
            "dev_manifest",
            "test_manifest",
            "n_epoch",
            "accum_grad",
            "global_grad_clip",
            "optim",
            "optim_conf",
            "scheduler",
            "scheduler_conf",
            "log_interval",
            "checkpoint",
            "shuffle_method",
            "weight_decay",
            "ctc_grad_norm_type",
            "minibatches",
            "subsampling_factor",
            "batch_bins",
            "batch_count",
            "batch_frames_in",
            "batch_frames_inout",
            "batch_frames_out",
            "sortagrad",
            "feat_dim",
            "stride_ms",
            "window_ms",
            "batch_size",
            "maxlen_in",
            "maxlen_out",
        ]
    else:
        remove_train_list = [
            "train_manifest",
            "dev_manifest",
            "test_manifest",
            "n_epoch",
            "accum_grad",
            "global_grad_clip",
            "log_interval",
            "checkpoint",
            "lr",
            "lr_decay",
            "batch_size",
            "shuffle_method",
            "weight_decay",
            "sortagrad",
            "num_workers",
        ]

    for item in remove_train_list:
        try:
            remove_config_part(config, [item])
        except Exception:
            print(item + " cannot be removed")

    # Save the config
    save(save_path, config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='Config merge', add_help=True)
    parser.add_argument(
        '--cfg_pth',
        type=str,
        default='conf/transformer.yaml',
        help='original config file')
    parser.add_argument(
        '--pre_pth', type=str, default="conf/preprocess.yaml",
        help='preprocess config file')
    parser.add_argument(
        '--dcd_pth', type=str, default="conf/tuning/decode.yaml",
        help='decode config file')
    parser.add_argument(
        '--vb_pth', type=str, default="data/lang_char/vocab.txt",
        help='vocab file')
    parser.add_argument(
        '--cmvn_pth', type=str, default="data/mean_std.json",
        help='cmvn stats file')
    parser.add_argument(
        '--save_pth', type=str, default="conf/transformer_infer.yaml",
        help='path to save the merged inference config')
    parser_args = parser.parse_args()

    merge_configs(
        conf_path=parser_args.cfg_pth,
        decode_path=parser_args.dcd_pth,
        preprocess_path=parser_args.pre_pth,
        vocab_path=parser_args.vb_pth,
        cmvn_path=parser_args.cmvn_pth,
        save_path=parser_args.save_pth, )