refactor the cli/st, test=st

pull/1312/head
huangyuxin 3 years ago
parent 62ac429044
commit fe1dc9d211

@ -40,11 +40,11 @@ __all__ = ["STExecutor"]
pretrained_models = {
"fat_st_ted-en-zh": {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz",
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"fa0a7425b91b4f8d259c70b2aca5ae67",
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"conf/transformer_mtl_noam.yaml",
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
}
@ -170,24 +170,19 @@ class STExecutor(BaseExecutor):
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
self.config.decoding.decoding_method = "fullsentence"
self.config.decode.decoding_method = "fullsentence"
with UpdateConfig(self.config):
self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath)
self.config.collator.cmvn_path = os.path.join(
res_path, self.config.collator.cmvn_path)
self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix)
self.config.cmvn_path = os.path.join(
res_path, self.config.cmvn_path)
self.config.spm_model_prefix = os.path.join(
res_path, self.config.spm_model_prefix)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = self.text_feature.vocab_size
model_conf = self.config.model
logger.info(model_conf)
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
model_conf = self.config
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias)
@ -218,7 +213,7 @@ class STExecutor(BaseExecutor):
logger.info("Preprocess audio_file:" + audio_file)
if "fat_st" in model_type:
cmvn = self.config.collator.cmvn_path
cmvn = self.config.cmvn_path
utt_name = "_tmp"
# Get the object for feature extraction
@ -284,7 +279,7 @@ class STExecutor(BaseExecutor):
"""
Model inference and result stored in self.output.
"""
cfg = self.config.decoding
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
if model_type == "fat_st_ted":

@ -67,21 +67,26 @@ def merge_configs(
config = load(conf_path)
decode_config = load(decode_path)
vocab_list = load_dict(vocab_path)
cmvn_stats = load_json(cmvn_path)
if os.path.exists(preprocess_path):
preprocess_config = load(preprocess_path)
for idx, process in enumerate(preprocess_config["process"]):
if process['type'] == "cmvn_json":
preprocess_config["process"][idx][
"cmvn_path"] = cmvn_stats
break
config.preprocess_config = preprocess_config
else:
cmvn_stats = load_cmvn_from_json(cmvn_stats)
config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}]
config.augmentation_config = ''
# If using the Kaldi feature, do not load the cmvn file
if cmvn_path.split(".")[-1] == 'json':
cmvn_stats = load_json(cmvn_path)
if os.path.exists(preprocess_path):
preprocess_config = load(preprocess_path)
for idx, process in enumerate(preprocess_config["process"]):
if process['type'] == "cmvn_json":
preprocess_config["process"][idx][
"cmvn_path"] = cmvn_stats
break
config.preprocess_config = preprocess_config
else:
cmvn_stats = load_cmvn_from_json(cmvn_stats)
config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}]
config.augmentation_config = ''
# The cmvn file ends with .ark
else:
config.cmvn_path = cmvn_path
# Update the config
config.vocab_filepath = vocab_list
config.input_dim = config.feat_dim

Loading…
Cancel
Save