Update tags of pretrained_models.

pull/1085/head
KP 3 years ago
parent f8204c984a
commit 0b7e0d1e2e

@ -39,7 +39,11 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
pretrained_models = {
"wenetspeech_zh_16k": {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
'md5':
@ -49,7 +53,7 @@ pretrained_models = {
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"transformer_zh_16k": {
"transformer_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
'md5':
@ -83,7 +87,7 @@ class ASRExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
default='wenetspeech',
default='conformer_wenetspeech',
help='Choose model type of asr task.')
self.parser.add_argument(
'--lang',
@ -143,7 +147,7 @@ class ASRExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str
tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
@ -165,7 +169,7 @@ class ASRExecutor(BaseExecutor):
self.config.decoding.decoding_method = "attention_rescoring"
with UpdateConfig(self.config):
if model_type == "ds2_online" or model_type == "ds2_offline":
if "ds2_online" in model_type or "ds2_offline" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath)
@ -178,7 +182,7 @@ class ASRExecutor(BaseExecutor):
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath)
self.config.collator.augmentation_config = os.path.join(
@ -196,7 +200,9 @@ class ASRExecutor(BaseExecutor):
raise Exception("wrong type")
# Enter the path of model root
model_class = dynamic_import(model_type, model_alias)
model_name = ''.join(
model_type.split('_')[:-1]) # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias)
model_conf = self.config.model
logger.info(model_conf)
model = model_class.from_config(model_conf)
@ -217,7 +223,7 @@ class ASRExecutor(BaseExecutor):
logger.info("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction
if model_type == "ds2_online" or model_type == "ds2_offline":
if "ds2_online" in model_type or "ds2_offline" in model_type:
audio, _ = self.collate_fn_test.process_utterance(
audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0]
@ -229,7 +235,7 @@ class ASRExecutor(BaseExecutor):
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
logger.info("get the preprocess conf")
preprocess_conf_file = self.config.collator.augmentation_config
# redirect the cmvn path
@ -293,7 +299,7 @@ class ASRExecutor(BaseExecutor):
cfg = self.config.decoding
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
if model_type == "ds2_online" or model_type == "ds2_offline":
if "ds2_online" in model_type or "ds2_offline" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,
@ -308,7 +314,7 @@ class ASRExecutor(BaseExecutor):
num_processes=cfg.num_proc_bsearch)
self._outputs["result"] = result_transcripts[0]
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,

@ -33,21 +33,25 @@ from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor']
pretrained_models = {
"panns_cnn6": {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '051b30c56bcb9a3dd67bc205cc12ffd2',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn10": {
"panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': '97c6f25587685379b1ebcd4c1f624927',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn14": {
"panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
@ -76,7 +80,7 @@ class CLSExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
default='panns_cnn14',
default='panns_cnn10',
help='Choose model type of cls task.')
self.parser.add_argument(
'--config',
@ -133,13 +137,14 @@ class CLSExecutor(BaseExecutor):
return
if label_file is None or ckpt_path is None:
self.res_path = self._get_pretrained_path(model_type) # panns_cnn14
self.cfg_path = os.path.join(
self.res_path, pretrained_models[model_type]['cfg_path'])
self.label_file = os.path.join(
self.res_path, pretrained_models[model_type]['label_file'])
self.ckpt_path = os.path.join(
self.res_path, pretrained_models[model_type]['ckpt_path'])
tag = model_type + '-' + '32k' # panns_cnn14-32k
self.res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(self.res_path,
pretrained_models[tag]['cfg_path'])
self.label_file = os.path.join(self.res_path,
pretrained_models[tag]['label_file'])
self.ckpt_path = os.path.join(self.res_path,
pretrained_models[tag]['ckpt_path'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.label_file = os.path.abspath(label_file)

Loading…
Cancel
Save