refactor the cli/st, test=st

pull/1312/head
huangyuxin 3 years ago
parent 62ac429044
commit fe1dc9d211

@ -40,11 +40,11 @@ __all__ = ["STExecutor"]
pretrained_models = { pretrained_models = {
"fat_st_ted-en-zh": { "fat_st_ted-en-zh": {
"url": "url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz", "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5": "md5":
"fa0a7425b91b4f8d259c70b2aca5ae67", "d62063f35a16d91210a71081bd2dd557",
"cfg_path": "cfg_path":
"conf/transformer_mtl_noam.yaml", "model.yaml",
"ckpt_path": "ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
} }
@ -170,24 +170,19 @@ class STExecutor(BaseExecutor):
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
self.config.decoding.decoding_method = "fullsentence" self.config.decode.decoding_method = "fullsentence"
with UpdateConfig(self.config): with UpdateConfig(self.config):
self.config.collator.vocab_filepath = os.path.join( self.config.cmvn_path = os.path.join(
res_path, self.config.collator.vocab_filepath) res_path, self.config.cmvn_path)
self.config.collator.cmvn_path = os.path.join( self.config.spm_model_prefix = os.path.join(
res_path, self.config.collator.cmvn_path) res_path, self.config.spm_model_prefix)
self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.unit_type,
vocab=self.config.collator.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = self.text_feature.vocab_size model_conf = self.config
model_conf = self.config.model
logger.info(model_conf)
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias) model_class = dynamic_import(model_name, model_alias)
@ -218,7 +213,7 @@ class STExecutor(BaseExecutor):
logger.info("Preprocess audio_file:" + audio_file) logger.info("Preprocess audio_file:" + audio_file)
if "fat_st" in model_type: if "fat_st" in model_type:
cmvn = self.config.collator.cmvn_path cmvn = self.config.cmvn_path
utt_name = "_tmp" utt_name = "_tmp"
# Get the object for feature extraction # Get the object for feature extraction
@ -284,7 +279,7 @@ class STExecutor(BaseExecutor):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
cfg = self.config.decoding cfg = self.config.decode
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
if model_type == "fat_st_ted": if model_type == "fat_st_ted":

@ -67,21 +67,26 @@ def merge_configs(
config = load(conf_path) config = load(conf_path)
decode_config = load(decode_path) decode_config = load(decode_path)
vocab_list = load_dict(vocab_path) vocab_list = load_dict(vocab_path)
cmvn_stats = load_json(cmvn_path)
if os.path.exists(preprocess_path):
preprocess_config = load(preprocess_path)
for idx, process in enumerate(preprocess_config["process"]):
if process['type'] == "cmvn_json":
preprocess_config["process"][idx][
"cmvn_path"] = cmvn_stats
break
config.preprocess_config = preprocess_config
else:
cmvn_stats = load_cmvn_from_json(cmvn_stats)
config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}]
config.augmentation_config = ''
# If use the kaldi feature, do not load the cmvn file
if cmvn_path.split(".")[-1] == 'json':
cmvn_stats = load_json(cmvn_path)
if os.path.exists(preprocess_path):
preprocess_config = load(preprocess_path)
for idx, process in enumerate(preprocess_config["process"]):
if process['type'] == "cmvn_json":
preprocess_config["process"][idx][
"cmvn_path"] = cmvn_stats
break
config.preprocess_config = preprocess_config
else:
cmvn_stats = load_cmvn_from_json(cmvn_stats)
config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}]
config.augmentation_config = ''
# the cmvn file is end with .ark
else:
config.cmvn_path = cmvn_path
# Updata the config # Updata the config
config.vocab_filepath = vocab_list config.vocab_filepath = vocab_list
config.input_dim = config.feat_dim config.input_dim = config.feat_dim

Loading…
Cancel
Save