pull/3088/head
th.zhang 2 years ago
parent 3a31163f1c
commit 35c75fe052

@ -51,11 +51,8 @@ class SSLExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
default=None, default='wav2vec2',
choices=[ choices=['wav2vec2', 'hubert'],
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help='Choose model type of asr task.') help='Choose model type of asr task.')
self.parser.add_argument( self.parser.add_argument(
'--task', '--task',
@ -67,7 +64,7 @@ class SSLExecutor(BaseExecutor):
'--lang', '--lang',
type=str, type=str,
default='en', default='en',
help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]' help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k, hubertASR_librispeech_100-en-16k]'
) )
self.parser.add_argument( self.parser.add_argument(
"--sample_rate", "--sample_rate",
@ -137,13 +134,6 @@ class SSLExecutor(BaseExecutor):
logger.debug("start to init the model") logger.debug("start to init the model")
if model_type is None: if model_type is None:
if lang == 'en':
model_type = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_type = 'wav2vec2ASR_aishell1'
else:
logger.error(
"invalid lang, please input --lang en or --lang zh")
logger.debug( logger.debug(
"Model type had not been specified, default {} was used.". "Model type had not been specified, default {} was used.".
format(model_type)) format(model_type))
@ -155,9 +145,20 @@ class SSLExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
if task == 'asr': if task == 'asr':
tag = model_type + '-' + lang + '-' + sample_rate_str if model_type == 'wav2vec2':
if lang == 'en':
model_prefix = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_prefix = 'wav2vec2ASR_aishell1'
tag = model_prefix + '-' + lang + '-' + sample_rate_str
elif model_type == 'hubert':
if lang == 'en':
model_prefix = 'hubertASR_librispeech_100'
elif lang == 'zh':
logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
else: else:
tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None) self.task_resource.set_task_model(tag, version=None)
self.res_path = self.task_resource.res_dir self.res_path = self.task_resource.res_dir
@ -191,7 +192,7 @@ class SSLExecutor(BaseExecutor):
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
else: else:
model_name = 'wav2vec2' model_name = model_type
model_class = self.task_resource.get_model_class(model_name) model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config model_conf = self.config
@ -204,9 +205,9 @@ class SSLExecutor(BaseExecutor):
if task == 'asr': if task == 'asr':
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
else: else:
self.model.wav2vec2.set_state_dict(model_dict) getattr(self.model, model_type).set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, input: Union[str, os.PathLike]):
""" """
Input preprocess and return paddle.Tensor stored in self.input. Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
@ -264,7 +265,7 @@ class SSLExecutor(BaseExecutor):
if task == 'asr': if task == 'asr':
cfg = self.config.decode cfg = self.config.decode
logger.debug( logger.debug(
f"we will use the wav2vec2ASR like model : {model_type}") f"we will use the {model_type}ASR like model.")
try: try:
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
@ -277,7 +278,7 @@ class SSLExecutor(BaseExecutor):
logger.exception(e) logger.exception(e)
else: else:
logger.debug( logger.debug(
"we will use the wav2vec2 like model to extract audio feature") f"we will use the {model_type} like model to extract audio feature.")
try: try:
out_feature = self.model(audio[:, :, 0]) out_feature = self.model(audio[:, :, 0])
self._outputs["result"] = out_feature[0] self._outputs["result"] = out_feature[0]

@ -23,6 +23,8 @@ model_alias = {
# --------------------------------- # ---------------------------------
"wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"], "wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"],
"wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"], "wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"],
"hubertASR": ["paddlespeech.s2t.models.hubert:HubertASR"],
"hubert": ["paddlespeech.s2t.models.hubert:HubertBase"],
# --------------------------------- # ---------------------------------
# -------------- ASR -------------- # -------------- ASR --------------

@ -119,6 +119,38 @@ ssl_dynamic_pretrained_models = {
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
}, },
}, },
"hubert-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'cfg_path':
'model.yaml',
'ckpt_path':
'hubert-large-lv60',
'model':
'hubert-large-lv60.pdparams',
'params':
'hubert-large-lv60.pdparams',
},
},
"hubertASR_librispeech_100-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
} }
# --------------------------------- # ---------------------------------

@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public entry points of the hubert model package.
# NOTE(review): the model alias table resolves "hubert" to
# `paddlespeech.s2t.models.hubert:HubertBase` and "hubertASR" to
# `paddlespeech.s2t.models.hubert:HubertASR`, so both classes must be
# importable from this package. HubertBase is presumably defined next to
# HubertASR in hubert_ASR -- confirm against that module.
from .hubert_ASR import HubertASR
from .hubert_ASR import HubertBase

# Every name listed in __all__ must be bound above; otherwise
# `from <package> import *` raises AttributeError for the missing name.
__all__ = ["HubertASR", "HubertBase"]
Loading…
Cancel
Save