diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index de4d66218..205d61f92 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -42,19 +42,19 @@ from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.vector.modules.sid_model import SpeakerIdetification pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "ecapa_tdnn-16k": { + # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". + # e.g. "ecapa_tdnn-voxceleb12-16k". + # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: + # "paddlespeech vector --task spk --model ecapa_tdnn-voxceleb12 --sr 16000 --input ./input.wav" + "ecapa_tdnn-voxceleb12-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_0.tar.gz', 'md5': - '76cb19ed857e6623856b7cd7ebbfeda4', + '85ff08ce0ef406b8c6d7b5ffc5b2b48f', 'cfg_path': - 'model.yaml', + 'conf/model.yaml', 'ckpt_path': - 'exp/conformer/checkpoints/wenetspeech', + 'model/model', }, } @@ -202,6 +202,14 @@ class VectorExecutor(BaseExecutor): The support models includes \n\t\t{}'.format(tag, "\n\t\t".join(support_models)) res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path def _init_from_path(self, model_type: str='ecapa_tdnn-voxceleb12', @@ -216,7 +224,12 @@ class 
VectorExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None: sample_rate_str = "16k" if sample_rate == 16000 else "8k" tag = model_type + "-" + sample_rate_str + logger.info(f"load the pretrained model: {tag}") res_path = self._get_pretrained_path(tag) + self.res_path = res_path + + self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) + self.ckpt_path = os.path.join(res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams') else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") @@ -226,7 +239,7 @@ class VectorExecutor(BaseExecutor): logger.info(f"start to read the ckpt from {self.ckpt_path}") logger.info(f"read the config from {self.cfg_path}") logger.info(f"get the res path {self.res_path}") - + # stage 2: read and config and init the model body self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path)