pull/3088/head
th.zhang 2 years ago
parent 3a31163f1c
commit 35c75fe052

@ -51,11 +51,8 @@ class SSLExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
default=None,
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
default='wav2vec2',
choices=['wav2vec2', 'hubert'],
help='Choose model type of asr task.')
self.parser.add_argument(
'--task',
@ -67,7 +64,7 @@ class SSLExecutor(BaseExecutor):
'--lang',
type=str,
default='en',
help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]'
help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k, hubertASR_librispeech_100-en-16k]'
)
self.parser.add_argument(
"--sample_rate",
@ -137,13 +134,6 @@ class SSLExecutor(BaseExecutor):
logger.debug("start to init the model")
if model_type is None:
if lang == 'en':
model_type = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_type = 'wav2vec2ASR_aishell1'
else:
logger.error(
"invalid lang, please input --lang en or --lang zh")
logger.debug(
"Model type had not been specified, default {} was used.".
format(model_type))
@ -155,9 +145,20 @@ class SSLExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
if task == 'asr':
tag = model_type + '-' + lang + '-' + sample_rate_str
if model_type == 'wav2vec2':
if lang == 'en':
model_prefix = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_prefix = 'wav2vec2ASR_aishell1'
tag = model_prefix + '-' + lang + '-' + sample_rate_str
elif model_type == 'hubert':
if lang == 'en':
model_prefix = 'hubertASR_librispeech_100'
elif lang == 'zh':
logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
else:
tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
self.res_path = self.task_resource.res_dir
@ -191,7 +192,7 @@ class SSLExecutor(BaseExecutor):
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
else:
model_name = 'wav2vec2'
model_name = model_type
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
@ -204,9 +205,9 @@ class SSLExecutor(BaseExecutor):
if task == 'asr':
self.model.set_state_dict(model_dict)
else:
self.model.wav2vec2.set_state_dict(model_dict)
getattr(self.model, model_type).set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
def preprocess(self, input: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
@ -264,7 +265,7 @@ class SSLExecutor(BaseExecutor):
if task == 'asr':
cfg = self.config.decode
logger.debug(
f"we will use the wav2vec2ASR like model : {model_type}")
f"we will use the {model_type}ASR like model.")
try:
result_transcripts = self.model.decode(
audio,
@ -277,7 +278,7 @@ class SSLExecutor(BaseExecutor):
logger.exception(e)
else:
logger.debug(
"we will use the wav2vec2 like model to extract audio feature")
f"we will use the {model_type} like model to extract audio feature.")
try:
out_feature = self.model(audio[:, :, 0])
self._outputs["result"] = out_feature[0]

@ -23,6 +23,8 @@ model_alias = {
# ---------------------------------
"wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"],
"wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"],
"hubertASR": ["paddlespeech.s2t.models.hubert:HubertASR"],
"hubert": ["paddlespeech.s2t.models.hubert:HubertBase"],
# ---------------------------------
# -------------- ASR --------------

@ -119,6 +119,38 @@ ssl_dynamic_pretrained_models = {
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
"hubert-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'cfg_path':
'model.yaml',
'ckpt_path':
'hubert-large-lv60',
'model':
'hubert-large-lv60.pdparams',
'params':
'hubert-large-lv60.pdparams',
},
},
"hubertASR_librispeech_100-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
}
# ---------------------------------

@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public API of the hubert model package.
#
# The exported names mirror the model_alias entries that point at this
# package ("paddlespeech.s2t.models.hubert:HubertASR" and
# "paddlespeech.s2t.models.hubert:HubertBase"), so both classes must be
# importable from here. The previous version exported wav2vec2 names
# (copy-paste from the wav2vec2 package), and __all__ listed a name that was
# never imported, which makes `from ... import *` fail.
#
# NOTE(review): assumes HubertBase is defined next to HubertASR in
# hubert_ASR.py -- confirm against that module.
from .hubert_ASR import HubertASR
from .hubert_ASR import HubertBase

__all__ = ["HubertASR", "HubertBase"]
Loading…
Cancel
Save