From 35c75fe05284628e27d9627c1b44740e05421d8d Mon Sep 17 00:00:00 2001 From: "th.zhang" <15600919271@163.com> Date: Thu, 20 Apr 2023 01:07:09 +0800 Subject: [PATCH] hubert cli --- paddlespeech/cli/ssl/infer.py | 41 +++++++++++----------- paddlespeech/resource/model_alias.py | 2 ++ paddlespeech/resource/pretrained_models.py | 32 +++++++++++++++++ paddlespeech/s2t/models/hubert/__init__.py | 17 +++++++++ 4 files changed, 72 insertions(+), 20 deletions(-) diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py index dce7c7781..c1dd68f93 100644 --- a/paddlespeech/cli/ssl/infer.py +++ b/paddlespeech/cli/ssl/infer.py @@ -51,11 +51,8 @@ class SSLExecutor(BaseExecutor): self.parser.add_argument( '--model', type=str, - default=None, - choices=[ - tag[:tag.index('-')] - for tag in self.task_resource.pretrained_models.keys() - ], + default='wav2vec2', + choices=['wav2vec2', 'hubert'], help='Choose model type of asr task.') self.parser.add_argument( '--task', @@ -67,7 +64,7 @@ class SSLExecutor(BaseExecutor): '--lang', type=str, default='en', - help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]' + help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k, hubertASR_librispeech_100-en-16k]' ) self.parser.add_argument( "--sample_rate", @@ -137,13 +134,6 @@ class SSLExecutor(BaseExecutor): logger.debug("start to init the model") if model_type is None: - if lang == 'en': - model_type = 'wav2vec2ASR_librispeech' - elif lang == 'zh': - model_type = 'wav2vec2ASR_aishell1' - else: - logger.error( - "invalid lang, please input --lang en or --lang zh") logger.debug( "Model type had not been specified, default {} was used.". format(model_type)) @@ -155,9 +145,20 @@ class SSLExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' if task == 'asr': - tag = model_type + '-' + lang + '-' + sample_rate_str + if model_type == 'wav2vec2': + if lang == 'en': + model_prefix = 'wav2vec2ASR_librispeech' + elif lang == 'zh': + model_prefix = 'wav2vec2ASR_aishell1' + tag = model_prefix + '-' + lang + '-' + sample_rate_str + elif model_type == 'hubert': + if lang == 'en': + model_prefix = 'hubertASR_librispeech_100' + elif lang == 'zh': + logger.error("zh hubertASR is not supported yet") + tag = model_prefix + '-' + lang + '-' + sample_rate_str else: - tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str + tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(tag, version=None) self.res_path = self.task_resource.res_dir @@ -191,7 +192,7 @@ class SSLExecutor(BaseExecutor): model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} else: - model_name = 'wav2vec2' + model_name = model_type model_class = self.task_resource.get_model_class(model_name) model_conf = self.config @@ -204,9 +205,9 @@ class SSLExecutor(BaseExecutor): if task == 'asr': self.model.set_state_dict(model_dict) else: - self.model.wav2vec2.set_state_dict(model_dict) + getattr(self.model, model_type).set_state_dict(model_dict) - def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + def preprocess(self, input: Union[str, os.PathLike]): """ Input preprocess and return paddle.Tensor stored in self.input. Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). @@ -264,7 +265,7 @@ class SSLExecutor(BaseExecutor): if task == 'asr': cfg = self.config.decode logger.debug( - f"we will use the wav2vec2ASR like model : {model_type}") + f"we will use the {model_type}ASR like model.") try: result_transcripts = self.model.decode( audio, @@ -277,7 +278,7 @@ class SSLExecutor(BaseExecutor): logger.exception(e) else: logger.debug( - "we will use the wav2vec2 like model to extract audio feature") + f"we will use the {model_type} like model to extract audio feature.") try: out_feature = self.model(audio[:, :, 0]) self._outputs["result"] = out_feature[0] diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index ab0b1828c..04872c72e 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -23,6 +23,8 @@ model_alias = { # --------------------------------- "wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"], "wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"], + "hubertASR": ["paddlespeech.s2t.models.hubert:HubertASR"], + "hubert": ["paddlespeech.s2t.models.hubert:HubertBase"], # --------------------------------- # -------------- ASR -------------- diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 04df18623..a553b520f 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -119,6 +119,38 @@ ssl_dynamic_pretrained_models = { 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', }, }, + "hubert-en-16k": { + '1.4': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubert-large-lv60_ckpt_1.4.0.model.tar.gz', + 'md5': + '9f0bc943adb822789bf61e674b229d17', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'hubert-large-lv60', + 'model': + 'hubert-large-lv60.pdparams', + 'params': + 'hubert-large-lv60.pdparams', + }, + }, + "hubertASR_librispeech_100-en-16k": { + '1.4': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz', + 'md5': + '9f0bc943adb822789bf61e674b229d17', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/wav2vec2ASR/checkpoints/avg_1', + 'model': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + 'params': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + }, + }, } # --------------------------------- diff --git a/paddlespeech/s2t/models/hubert/__init__.py b/paddlespeech/s2t/models/hubert/__init__.py index e69de29bb..4df88bd34 100644 --- a/paddlespeech/s2t/models/hubert/__init__.py +++ b/paddlespeech/s2t/models/hubert/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .hubert_ASR import HubertASR +from .wav2vec2_ASR import Wav2vec2Base + +__all__ = ["Wav2vec2ASR", "Wav2vec2Base"]