diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 30e4bb9c1..1d235201d 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -39,7 +39,11 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] pretrained_models = { - "wenetspeech_zh_16k": { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". + # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "conformer_wenetspeech-zh-16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', 'md5': @@ -49,7 +53,7 @@ pretrained_models = { 'ckpt_path': 'exp/conformer/checkpoints/wenetspeech', }, - "transformer_zh_16k": { + "transformer_aishell-zh-16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz', 'md5': @@ -83,7 +87,7 @@ class ASRExecutor(BaseExecutor): self.parser.add_argument( '--model', type=str, - default='wenetspeech', + default='conformer_wenetspeech', help='Choose model type of asr task.') self.parser.add_argument( '--lang', @@ -143,7 +147,7 @@ class ASRExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' - tag = model_type + '_' + lang + '_' + sample_rate_str + tag = model_type + '-' + lang + '-' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path self.cfg_path = os.path.join(res_path, @@ -165,7 +169,7 @@ class ASRExecutor(BaseExecutor): self.config.decoding.decoding_method = "attention_rescoring" with UpdateConfig(self.config): - if model_type == "ds2_online" or model_type == "ds2_offline": + if "ds2_online" in model_type or "ds2_offline" in model_type: from paddlespeech.s2t.io.collator import SpeechCollator self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) @@ -178,7 +182,7 @@ class ASRExecutor(BaseExecutor): spm_model_prefix=self.config.collator.spm_model_prefix) self.config.model.input_dim = self.collate_fn_test.feature_size self.config.model.output_dim = text_feature.vocab_size - elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) self.config.collator.augmentation_config = os.path.join( @@ -196,7 +200,9 @@ class ASRExecutor(BaseExecutor): raise Exception("wrong type") # Enter the path of model root - model_class = dynamic_import(model_type, model_alias) + model_name = ''.join( + model_type.split('_')[:-1]) # model_type: {model_name}_{dataset} + model_class = dynamic_import(model_name, model_alias) model_conf = self.config.model logger.info(model_conf) model = model_class.from_config(model_conf) @@ -217,7 +223,7 @@ class ASRExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) # Get the object for feature extraction - if model_type == "ds2_online" or model_type == "ds2_offline": + if "ds2_online" in model_type or "ds2_offline" in model_type: audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] @@ -229,7 +235,7 @@ class ASRExecutor(BaseExecutor): self._inputs["audio_len"] = audio_len logger.info(f"audio feat shape: {audio.shape}") - elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: logger.info("get the preprocess conf") preprocess_conf_file = self.config.collator.augmentation_config # redirect the cmvn path @@ -293,7 +299,7 @@ class ASRExecutor(BaseExecutor): cfg = self.config.decoding audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] - if model_type == "ds2_online" or model_type == "ds2_offline": + if "ds2_online" in model_type or "ds2_offline" in model_type: result_transcripts = self.model.decode( audio, audio_len, @@ -308,7 +314,7 @@ class ASRExecutor(BaseExecutor): num_processes=cfg.num_proc_bsearch) self._outputs["result"] = result_transcripts[0] - elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: result_transcripts = self.model.decode( audio, audio_len, diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index c4206f7e5..dc976cb43 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -33,21 +33,25 @@ from paddlespeech.s2t.utils.dynamic_import import dynamic_import __all__ = ['CLSExecutor'] pretrained_models = { - "panns_cnn6": { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". + # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "panns_cnn6-32k": { 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', - 'md5': '051b30c56bcb9a3dd67bc205cc12ffd2', + 'md5': '4cf09194a95df024fd12f84712cf0f9c', 'cfg_path': 'panns.yaml', 'ckpt_path': 'cnn6.pdparams', 'label_file': 'audioset_labels.txt', }, - "panns_cnn10": { + "panns_cnn10-32k": { 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', - 'md5': '97c6f25587685379b1ebcd4c1f624927', + 'md5': 'cb8427b22176cc2116367d14847f5413', 'cfg_path': 'panns.yaml', 'ckpt_path': 'cnn10.pdparams', 'label_file': 'audioset_labels.txt', }, - "panns_cnn14": { + "panns_cnn14-32k": { 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', 'md5': 'e3b9b5614a1595001161d0ab95edee97', 'cfg_path': 'panns.yaml', @@ -76,7 +80,7 @@ class CLSExecutor(BaseExecutor): self.parser.add_argument( '--model', type=str, - default='panns_cnn14', + default='panns_cnn10', help='Choose model type of cls task.') self.parser.add_argument( '--config', @@ -133,13 +137,14 @@ class CLSExecutor(BaseExecutor): return if label_file is None or ckpt_path is None: - self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 - self.cfg_path = os.path.join( - self.res_path, pretrained_models[model_type]['cfg_path']) - self.label_file = os.path.join( - self.res_path, pretrained_models[model_type]['label_file']) - self.ckpt_path = os.path.join( - self.res_path, pretrained_models[model_type]['ckpt_path']) + tag = model_type + '-' + '32k' # panns_cnn14-32k + self.res_path = self._get_pretrained_path(tag) + self.cfg_path = os.path.join(self.res_path, + pretrained_models[tag]['cfg_path']) + self.label_file = os.path.join(self.res_path, + pretrained_models[tag]['label_file']) + self.ckpt_path = os.path.join(self.res_path, + pretrained_models[tag]['ckpt_path']) else: self.cfg_path = os.path.abspath(cfg_path) self.label_file = os.path.abspath(label_file)