|
|
|
@ -1,21 +1,19 @@
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
Merge training configs into a single inference config.
|
|
|
|
|
The single inference config is for the CLI, which accepts only one config file for inference.
|
|
|
|
|
The training configs include: model config, preprocess config, decode config, vocab file and cmvn file.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import argparse
|
|
|
|
|
import json
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
from contextlib import redirect_stdout
|
|
|
|
|
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.frontend.utility import load_dict
|
|
|
|
|
from contextlib import redirect_stdout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save(save_path, config):
|
|
|
|
@ -29,18 +27,21 @@ def load(save_path):
|
|
|
|
|
config.merge_from_file(save_path)
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_json(json_path):
    """Read a JSON file and return its parsed content.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` for the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_config_part(config, key_list):
    """Remove the entry addressed by a path of nested keys, in place.

    Args:
        config: Nested container supporting ``[]`` lookup and ``pop``
            (for example a dict).
        key_list: Sequence of keys. Every key except the last selects a
            nested container; the last key names the entry to remove.
            An empty sequence is a no-op.

    Raises:
        Whatever the container raises for a missing key (``KeyError`` for
        a dict); the error is not suppressed here.
    """
    if not key_list:
        return
    # Walk down to the container that directly holds the entry to delete.
    for key in key_list[:-1]:
        config = config[key]
    config.pop(key_list[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_cmvn_from_json(cmvn_stats):
|
|
|
|
|
means = cmvn_stats['mean_stat']
|
|
|
|
|
variance = cmvn_stats['var_stat']
|
|
|
|
@ -51,17 +52,17 @@ def load_cmvn_from_json(cmvn_stats):
|
|
|
|
|
if variance[i] < 1.0e-20:
|
|
|
|
|
variance[i] = 1.0e-20
|
|
|
|
|
variance[i] = 1.0 / math.sqrt(variance[i])
|
|
|
|
|
cmvn_stats = {"mean":means, "istd":variance}
|
|
|
|
|
cmvn_stats = {"mean": means, "istd": variance}
|
|
|
|
|
return cmvn_stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def merge_configs(
|
|
|
|
|
conf_path = "conf/conformer.yaml",
|
|
|
|
|
preprocess_path = "conf/preprocess.yaml",
|
|
|
|
|
decode_path = "conf/tuning/decode.yaml",
|
|
|
|
|
vocab_path = "data/vocab.txt",
|
|
|
|
|
cmvn_path = "data/mean_std.json",
|
|
|
|
|
save_path = "conf/conformer_infer.yaml",
|
|
|
|
|
):
|
|
|
|
|
conf_path="conf/conformer.yaml",
|
|
|
|
|
preprocess_path="conf/preprocess.yaml",
|
|
|
|
|
decode_path="conf/tuning/decode.yaml",
|
|
|
|
|
vocab_path="data/vocab.txt",
|
|
|
|
|
cmvn_path="data/mean_std.json",
|
|
|
|
|
save_path="conf/conformer_infer.yaml", ):
|
|
|
|
|
|
|
|
|
|
# Load the configs
|
|
|
|
|
config = load(conf_path)
|
|
|
|
@ -75,14 +76,13 @@ def merge_configs(
|
|
|
|
|
preprocess_config = load(preprocess_path)
|
|
|
|
|
for idx, process in enumerate(preprocess_config["process"]):
|
|
|
|
|
if process['type'] == "cmvn_json":
|
|
|
|
|
preprocess_config["process"][idx][
|
|
|
|
|
"cmvn_path"] = cmvn_stats
|
|
|
|
|
preprocess_config["process"][idx]["cmvn_path"] = cmvn_stats
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
config.preprocess_config = preprocess_config
|
|
|
|
|
else:
|
|
|
|
|
cmvn_stats = load_cmvn_from_json(cmvn_stats)
|
|
|
|
|
config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}]
|
|
|
|
|
config.mean_std_filepath = [{"cmvn_stats": cmvn_stats}]
|
|
|
|
|
config.augmentation_config = ''
|
|
|
|
|
# the cmvn file ends with .ark
|
|
|
|
|
else:
|
|
|
|
@ -95,7 +95,8 @@ def merge_configs(
|
|
|
|
|
# Remove some parts of the config
|
|
|
|
|
|
|
|
|
|
if os.path.exists(preprocess_path):
|
|
|
|
|
remove_train_list = ["train_manifest",
|
|
|
|
|
remove_train_list = [
|
|
|
|
|
"train_manifest",
|
|
|
|
|
"dev_manifest",
|
|
|
|
|
"test_manifest",
|
|
|
|
|
"n_epoch",
|
|
|
|
@ -126,7 +127,8 @@ def merge_configs(
|
|
|
|
|
"maxlen_out",
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
remove_train_list = ["train_manifest",
|
|
|
|
|
remove_train_list = [
|
|
|
|
|
"train_manifest",
|
|
|
|
|
"dev_manifest",
|
|
|
|
|
"test_manifest",
|
|
|
|
|
"n_epoch",
|
|
|
|
@ -147,37 +149,35 @@ def merge_configs(
|
|
|
|
|
try:
|
|
|
|
|
remove_config_part(config, [item])
|
|
|
|
|
except:
|
|
|
|
|
print ( item + " " +"can not be removed")
|
|
|
|
|
print(item + " " + "can not be removed")
|
|
|
|
|
|
|
|
|
|
# Save the config
|
|
|
|
|
save(save_path, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect the paths of the individual training
    # artifacts and hand them to merge_configs, which writes the merged
    # inference config to --save_pth.
    parser = argparse.ArgumentParser(prog='Config merge', add_help=True)
    parser.add_argument(
        '--cfg_pth',
        type=str,
        default='conf/transformer.yaml',
        help='origin config file')
    parser.add_argument(
        '--pre_pth', type=str, default="conf/preprocess.yaml", help='')
    # Default spelled "tuning" to agree with merge_configs' own default
    # decode path; the previous "tuninig" spelling was a typo.
    parser.add_argument(
        '--dcd_pth', type=str, default="conf/tuning/decode.yaml", help='')
    parser.add_argument(
        '--vb_pth', type=str, default="data/lang_char/vocab.txt", help='')
    parser.add_argument(
        '--cmvn_pth', type=str, default="data/mean_std.json", help='')
    parser.add_argument(
        '--save_pth', type=str, default="conf/transformer_infer.yaml", help='')
    parser_args = parser.parse_args()

    merge_configs(
        conf_path=parser_args.cfg_pth,
        decode_path=parser_args.dcd_pth,
        preprocess_path=parser_args.pre_pth,
        vocab_path=parser_args.vb_pth,
        cmvn_path=parser_args.cmvn_pth,
        save_path=parser_args.save_pth, )
|
|
|
|
|