|
|
|
@ -40,11 +40,11 @@ __all__ = ["STExecutor"]
|
|
|
|
|
pretrained_models = {
|
|
|
|
|
"fat_st_ted-en-zh": {
|
|
|
|
|
"url":
|
|
|
|
|
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz",
|
|
|
|
|
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
|
|
|
|
|
"md5":
|
|
|
|
|
"fa0a7425b91b4f8d259c70b2aca5ae67",
|
|
|
|
|
"d62063f35a16d91210a71081bd2dd557",
|
|
|
|
|
"cfg_path":
|
|
|
|
|
"conf/transformer_mtl_noam.yaml",
|
|
|
|
|
"model.yaml",
|
|
|
|
|
"ckpt_path":
|
|
|
|
|
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
|
|
|
|
|
}
|
|
|
|
@ -170,24 +170,19 @@ class STExecutor(BaseExecutor):
|
|
|
|
|
#Init body.
|
|
|
|
|
self.config = CfgNode(new_allowed=True)
|
|
|
|
|
self.config.merge_from_file(self.cfg_path)
|
|
|
|
|
self.config.decoding.decoding_method = "fullsentence"
|
|
|
|
|
self.config.decode.decoding_method = "fullsentence"
|
|
|
|
|
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
self.config.collator.vocab_filepath = os.path.join(
|
|
|
|
|
res_path, self.config.collator.vocab_filepath)
|
|
|
|
|
self.config.collator.cmvn_path = os.path.join(
|
|
|
|
|
res_path, self.config.collator.cmvn_path)
|
|
|
|
|
self.config.collator.spm_model_prefix = os.path.join(
|
|
|
|
|
res_path, self.config.collator.spm_model_prefix)
|
|
|
|
|
self.config.cmvn_path = os.path.join(
|
|
|
|
|
res_path, self.config.cmvn_path)
|
|
|
|
|
self.config.spm_model_prefix = os.path.join(
|
|
|
|
|
res_path, self.config.spm_model_prefix)
|
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.collator.unit_type,
|
|
|
|
|
vocab=self.config.collator.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.collator.spm_model_prefix)
|
|
|
|
|
self.config.model.input_dim = self.config.collator.feat_dim
|
|
|
|
|
self.config.model.output_dim = self.text_feature.vocab_size
|
|
|
|
|
|
|
|
|
|
model_conf = self.config.model
|
|
|
|
|
logger.info(model_conf)
|
|
|
|
|
unit_type=self.config.unit_type,
|
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
|
|
|
|
|
|
model_conf = self.config
|
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
|
|
model_class = dynamic_import(model_name, model_alias)
|
|
|
|
@ -218,7 +213,7 @@ class STExecutor(BaseExecutor):
|
|
|
|
|
logger.info("Preprocess audio_file:" + audio_file)
|
|
|
|
|
|
|
|
|
|
if "fat_st" in model_type:
|
|
|
|
|
cmvn = self.config.collator.cmvn_path
|
|
|
|
|
cmvn = self.config.cmvn_path
|
|
|
|
|
utt_name = "_tmp"
|
|
|
|
|
|
|
|
|
|
# Get the object for feature extraction
|
|
|
|
@ -284,7 +279,7 @@ class STExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
Model inference and result stored in self.output.
|
|
|
|
|
"""
|
|
|
|
|
cfg = self.config.decoding
|
|
|
|
|
cfg = self.config.decode
|
|
|
|
|
audio = self._inputs["audio"]
|
|
|
|
|
audio_len = self._inputs["audio_len"]
|
|
|
|
|
if model_type == "fat_st_ted":
|
|
|
|
|