diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index d6bd6304..1276424c 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -40,11 +40,11 @@ __all__ = ["STExecutor"] pretrained_models = { "fat_st_ted-en-zh": { "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz", + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", "md5": - "fa0a7425b91b4f8d259c70b2aca5ae67", + "d62063f35a16d91210a71081bd2dd557", "cfg_path": - "conf/transformer_mtl_noam.yaml", + "model.yaml", "ckpt_path": "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", } @@ -170,24 +170,19 @@ class STExecutor(BaseExecutor): #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) - self.config.decoding.decoding_method = "fullsentence" + self.config.decode.decoding_method = "fullsentence" with UpdateConfig(self.config): - self.config.collator.vocab_filepath = os.path.join( - res_path, self.config.collator.vocab_filepath) - self.config.collator.cmvn_path = os.path.join( - res_path, self.config.collator.cmvn_path) - self.config.collator.spm_model_prefix = os.path.join( - res_path, self.config.collator.spm_model_prefix) + self.config.cmvn_path = os.path.join( + res_path, self.config.cmvn_path) + self.config.spm_model_prefix = os.path.join( + res_path, self.config.spm_model_prefix) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - self.config.model.input_dim = self.config.collator.feat_dim - self.config.model.output_dim = self.text_feature.vocab_size - - model_conf = self.config.model - logger.info(model_conf) + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + + model_conf = self.config model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} model_class = dynamic_import(model_name, model_alias) @@ -218,7 +213,7 @@ class STExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) if "fat_st" in model_type: - cmvn = self.config.collator.cmvn_path + cmvn = self.config.cmvn_path utt_name = "_tmp" # Get the object for feature extraction @@ -284,7 +279,7 @@ class STExecutor(BaseExecutor): """ Model inference and result stored in self.output. """ - cfg = self.config.decoding + cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] if model_type == "fat_st_ted": diff --git a/utils/generate_infer_yaml.py b/utils/generate_infer_yaml.py index d2a6777c..a2eb28c7 100644 --- a/utils/generate_infer_yaml.py +++ b/utils/generate_infer_yaml.py @@ -67,21 +67,26 @@ def merge_configs( config = load(conf_path) decode_config = load(decode_path) vocab_list = load_dict(vocab_path) - cmvn_stats = load_json(cmvn_path) - if os.path.exists(preprocess_path): - preprocess_config = load(preprocess_path) - for idx, process in enumerate(preprocess_config["process"]): - if process['type'] == "cmvn_json": - preprocess_config["process"][idx][ - "cmvn_path"] = cmvn_stats - break - - config.preprocess_config = preprocess_config - else: - cmvn_stats = load_cmvn_from_json(cmvn_stats) - config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] - config.augmentation_config = '' + # If use the kaldi feature, do not load the cmvn file + if cmvn_path.split(".")[-1] == 'json': + cmvn_stats = load_json(cmvn_path) + if os.path.exists(preprocess_path): + preprocess_config = load(preprocess_path) + for idx, process in enumerate(preprocess_config["process"]): + if process['type'] == "cmvn_json": + preprocess_config["process"][idx][ + "cmvn_path"] = cmvn_stats + break + + config.preprocess_config = preprocess_config + else: + cmvn_stats = load_cmvn_from_json(cmvn_stats) + config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] + config.augmentation_config = '' + # the cmvn file is end with .ark + else: + config.cmvn_path = cmvn_path # Updata the config config.vocab_filepath = vocab_list config.input_dim = config.feat_dim