Merge pull request #1085 from KPatr1ck/cls_cli

[CLI] Update tags of pretrained_models.
pull/1088/head
Hui Zhang 3 years ago committed by GitHub
commit cca681f6d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -39,7 +39,11 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor'] __all__ = ['ASRExecutor']
pretrained_models = { pretrained_models = {
"wenetspeech_zh_16k": { # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
'url': 'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
'md5': 'md5':
@ -49,7 +53,7 @@ pretrained_models = {
'ckpt_path': 'ckpt_path':
'exp/conformer/checkpoints/wenetspeech', 'exp/conformer/checkpoints/wenetspeech',
}, },
"transformer_zh_16k": { "transformer_aishell-zh-16k": {
'url': 'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz', 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
'md5': 'md5':
@ -83,7 +87,7 @@ class ASRExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
default='wenetspeech', default='conformer_wenetspeech',
help='Choose model type of asr task.') help='Choose model type of asr task.')
self.parser.add_argument( self.parser.add_argument(
'--lang', '--lang',
@ -143,7 +147,7 @@ class ASRExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path self.res_path = res_path
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
@ -165,7 +169,7 @@ class ASRExecutor(BaseExecutor):
self.config.decoding.decoding_method = "attention_rescoring" self.config.decoding.decoding_method = "attention_rescoring"
with UpdateConfig(self.config): with UpdateConfig(self.config):
if model_type == "ds2_online" or model_type == "ds2_offline": if "ds2_online" in model_type or "ds2_offline" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath) res_path, self.config.collator.vocab_filepath)
@ -178,7 +182,7 @@ class ASRExecutor(BaseExecutor):
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.collate_fn_test.feature_size self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath) res_path, self.config.collator.vocab_filepath)
self.config.collator.augmentation_config = os.path.join( self.config.collator.augmentation_config = os.path.join(
@ -196,7 +200,9 @@ class ASRExecutor(BaseExecutor):
raise Exception("wrong type") raise Exception("wrong type")
# Enter the path of model root # Enter the path of model root
model_class = dynamic_import(model_type, model_alias) model_name = ''.join(
model_type.split('_')[:-1]) # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias)
model_conf = self.config.model model_conf = self.config.model
logger.info(model_conf) logger.info(model_conf)
model = model_class.from_config(model_conf) model = model_class.from_config(model_conf)
@ -217,7 +223,7 @@ class ASRExecutor(BaseExecutor):
logger.info("Preprocess audio_file:" + audio_file) logger.info("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction # Get the object for feature extraction
if model_type == "ds2_online" or model_type == "ds2_offline": if "ds2_online" in model_type or "ds2_offline" in model_type:
audio, _ = self.collate_fn_test.process_utterance( audio, _ = self.collate_fn_test.process_utterance(
audio_file=audio_file, transcript=" ") audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0] audio_len = audio.shape[0]
@ -229,7 +235,7 @@ class ASRExecutor(BaseExecutor):
self._inputs["audio_len"] = audio_len self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}") logger.info(f"audio feat shape: {audio.shape}")
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
logger.info("get the preprocess conf") logger.info("get the preprocess conf")
preprocess_conf_file = self.config.collator.augmentation_config preprocess_conf_file = self.config.collator.augmentation_config
# redirect the cmvn path # redirect the cmvn path
@ -293,7 +299,7 @@ class ASRExecutor(BaseExecutor):
cfg = self.config.decoding cfg = self.config.decoding
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
if model_type == "ds2_online" or model_type == "ds2_offline": if "ds2_online" in model_type or "ds2_offline" in model_type:
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
audio_len, audio_len,
@ -308,7 +314,7 @@ class ASRExecutor(BaseExecutor):
num_processes=cfg.num_proc_bsearch) num_processes=cfg.num_proc_bsearch)
self._outputs["result"] = result_transcripts[0] self._outputs["result"] = result_transcripts[0]
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
audio_len, audio_len,

@ -33,21 +33,25 @@ from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor'] __all__ = ['CLSExecutor']
pretrained_models = { pretrained_models = {
"panns_cnn6": { # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '051b30c56bcb9a3dd67bc205cc12ffd2', 'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml', 'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams', 'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt', 'label_file': 'audioset_labels.txt',
}, },
"panns_cnn10": { "panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': '97c6f25587685379b1ebcd4c1f624927', 'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml', 'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams', 'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt', 'label_file': 'audioset_labels.txt',
}, },
"panns_cnn14": { "panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97', 'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml', 'cfg_path': 'panns.yaml',
@ -123,8 +127,8 @@ class CLSExecutor(BaseExecutor):
def _init_from_path(self, def _init_from_path(self,
model_type: str='panns_cnn14', model_type: str='panns_cnn14',
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
label_file: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None): label_file: Optional[os.PathLike]=None):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
@ -133,13 +137,14 @@ class CLSExecutor(BaseExecutor):
return return
if label_file is None or ckpt_path is None: if label_file is None or ckpt_path is None:
self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 tag = model_type + '-' + '32k' # panns_cnn14-32k
self.cfg_path = os.path.join( self.res_path = self._get_pretrained_path(tag)
self.res_path, pretrained_models[model_type]['cfg_path']) self.cfg_path = os.path.join(self.res_path,
self.label_file = os.path.join( pretrained_models[tag]['cfg_path'])
self.res_path, pretrained_models[model_type]['label_file']) self.label_file = os.path.join(self.res_path,
self.ckpt_path = os.path.join( pretrained_models[tag]['label_file'])
self.res_path, pretrained_models[model_type]['ckpt_path']) self.ckpt_path = os.path.join(self.res_path,
pretrained_models[tag]['ckpt_path'])
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.label_file = os.path.abspath(label_file) self.label_file = os.path.abspath(label_file)
@ -239,15 +244,15 @@ class CLSExecutor(BaseExecutor):
logger.exception(e) logger.exception(e)
return False return False
def __call__(self, model_type, cfg_path, label_file, ckpt_path, audio_file, def __call__(self, model, config, ckpt_path, label_file, audio_file, topk,
topk, device): device):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
# self._check(audio_file, sample_rate) # self._check(audio_file, sample_rate)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model_type, cfg_path, label_file, ckpt_path) self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file) self.preprocess(audio_file)
self.infer() self.infer()
res = self.postprocess(topk) # Retrieve result of cls. res = self.postprocess(topk) # Retrieve result of cls.

Loading…
Cancel
Save