From 6f3585b199de097e4bbc81926691e2dea1b3adb2 Mon Sep 17 00:00:00 2001
From: "th.zhang" <15600919271@163.com>
Date: Thu, 27 Apr 2023 02:20:04 +0800
Subject: [PATCH] support hubert cli

---
 examples/librispeech/asr4/README.md           |  4 +-
 paddlespeech/cli/ssl/infer.py                 | 17 ++---
 paddlespeech/resource/pretrained_models.py    | 64 +++++++++----------
 paddlespeech/s2t/models/hubert/hubert_ASR.py  | 26 ++++++--
 .../s2t/models/hubert/modules/__init__.py     | 13 ++++
 .../s2t/models/hubert/modules/hubert_model.py |  2 +-
 6 files changed, 76 insertions(+), 50 deletions(-)
 create mode 100644 paddlespeech/s2t/models/hubert/modules/__init__.py

diff --git a/examples/librispeech/asr4/README.md b/examples/librispeech/asr4/README.md
index d377a623a..064a7f16b 100644
--- a/examples/librispeech/asr4/README.md
+++ b/examples/librispeech/asr4/README.md
@@ -163,7 +163,7 @@ using the `tar` scripts to unpack the model and then you can use the script to t
 For example:
 ```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
 tar xzvf hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
 source path.sh
 # If you have process the data and get the manifest file, you can skip the following 2 steps
@@ -184,7 +184,7 @@ In some situations, you want to use the trained model to do the inference for th
 ```
 you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
 ```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
 tar xzvf hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
 ```
 You can download the audio demo:

diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py
index c1dd68f93..49fdbf39d 100644
--- a/paddlespeech/cli/ssl/infer.py
+++ b/paddlespeech/cli/ssl/infer.py
@@ -150,10 +150,10 @@ class SSLExecutor(BaseExecutor):
                     model_prefix = 'wav2vec2ASR_librispeech'
                 elif lang == 'zh':
                     model_prefix = 'wav2vec2ASR_aishell1'
-                tag = model_prefix + '-' + lang + '-' + sample_rate_str
+                tag = model_prefix + '-' + lang + '-' + sample_rate_str
             elif model_type == 'hubert':
                 if lang == 'en':
-                    model_prefix = 'hubertASR_librispeech_100'
+                    model_prefix = 'hubertASR_librispeech-100h'
                 elif lang == 'zh':
                     logger.error("zh hubertASR is not supported yet")
                 tag = model_prefix + '-' + lang + '-' + sample_rate_str
@@ -185,16 +185,17 @@ class SSLExecutor(BaseExecutor):
                 self.text_feature = TextFeaturizer(
                     unit_type=self.config.unit_type,
                     vocab=self.config.vocab_filepath)
+                self.config.output_dim = self.text_feature.vocab_size
             elif lang == 'zh':
                 self.text_feature = AutoTokenizer.from_pretrained(
                     self.config.tokenizer)
+                self.config.output_dim = self.text_feature.vocab_size
             self.config.decode.decoding_method = decode_method
-            model_name = model_type[:model_type.rindex(
+            model_name = model_prefix[:model_prefix.rindex(
                 '_')]  # model_type: {model_name}_{dataset}
         else:
             model_name = model_type
         model_class = self.task_resource.get_model_class(model_name)
-        model_conf = self.config
-        model = model_class.from_config(model_conf)
+        model = model_class.from_config(self.config)
         self.model = model
@@ -264,8 +265,7 @@ class SSLExecutor(BaseExecutor):
         audio = self._inputs["audio"]
         if task == 'asr':
             cfg = self.config.decode
-            logger.debug(
-                f"we will use the {model_type}ASR like model.")
+            logger.debug(f"we will use the {model_type}ASR like model.")
             try:
                 result_transcripts = self.model.decode(
                     audio,
@@ -278,7 +278,8 @@ class SSLExecutor(BaseExecutor):
                 logger.exception(e)
         else:
             logger.debug(
-                f"we will use the {model_type} like model to extract audio feature.")
+                f"we will use the {model_type} like model to extract audio feature."
+            )
             try:
                 out_feature = self.model(audio[:, :, 0])
                 self._outputs["result"] = out_feature[0]
@@ -455,7 +456,7 @@ class SSLExecutor(BaseExecutor):
         if rtf:
             k = self.__class__.__name__
             CLI_TIMER[k]['start'].append(time.time())
-        self.preprocess(model, audio_file)
+        self.preprocess(audio_file)
         self.infer(model, task)
         res = self.postprocess()  # Retrieve result of asr.

diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index 8beb7aad8..60b34e7e0 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -117,6 +117,38 @@ ssl_dynamic_pretrained_models = {
             'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
         },
     },
+    "hubert-en-16k": {
+        '1.4': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
+            'md5':
+            'efecfb87a8718aa9253b7459c1fe9b54',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'hubert-large-lv60',
+            'model':
+            'hubert-large-lv60.pdparams',
+            'params':
+            'hubert-large-lv60.pdparams',
+        },
+    },
+    "hubertASR_librispeech-100h-en-16k": {
+        '1.4': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
+            'md5':
+            '574cefd11aaef5737969ce22a7f33ea2',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/hubertASR/checkpoints/avg_1',
+            'model':
+            'exp/hubertASR/checkpoints/avg_1.pdparams',
+            'params':
+            'exp/hubertASR/checkpoints/avg_1.pdparams',
+        },
+    },
 }

 # ---------------------------------
@@ -521,38 +553,6 @@ asr_onnx_pretrained_models = {
             '29e02312deb2e59b3c8686c7966d4fe3'
         },
     },
-    "hubert-en-16k": {
-        '1.4': {
-            'url':
-            'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
-            'md5':
-            '9f0bc943adb822789bf61e674b229d17',
-            'cfg_path':
-            'model.yaml',
-            'ckpt_path':
-            'hubert-large-lv60',
-            'model':
-            'hubert-large-lv60.pdparams',
-            'params':
-            'hubert-large-lv60.pdparams',
-        },
-    },
-    "hubertASR_librispeech_100-en-16k": {
-        '1.4': {
-            'url':
-            'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
-            'md5':
-            '9f0bc943adb822789bf61e674b229d17',
-            'cfg_path':
-            'model.yaml',
-            'ckpt_path':
-            'exp/wav2vec2ASR/checkpoints/avg_1',
-            'model':
-            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
-            'params':
-            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
-        },
-    },
 }

 whisper_dynamic_pretrained_models = {

diff --git a/paddlespeech/s2t/models/hubert/hubert_ASR.py b/paddlespeech/s2t/models/hubert/hubert_ASR.py
index 00411029a..df3475897 100644
--- a/paddlespeech/s2t/models/hubert/hubert_ASR.py
+++ b/paddlespeech/s2t/models/hubert/hubert_ASR.py
@@ -44,13 +44,11 @@ class HubertASR(nn.Layer):
         init_type = config.get("init_type", None)
         with DefaultInitializerContext(init_type):
             self.config = config
-            with open(config.vocab_filepath) as f:
-                dicts = [symbol.strip() for symbol in f.readlines()]
             task_cfg = self.merge_with_parent(HubertPretrainingConfig,
                                               dict(self.config.task_cfg))
             model_cfg = self.merge_with_parent(HubertConfig,
                                                dict(self.config.model_cfg))
-            hubert = HubertModel(model_cfg, task_cfg, dicts)
+            hubert = HubertModel(model_cfg, task_cfg, [None])
         self.normalize_wav = config.normalize_wav
         self.output_norm = config.output_norm
@@ -329,13 +327,12 @@ class HubertBase(nn.Layer):

     def __init__(self, config: dict):
         super().__init__()
-        with open(config.vocab_filepath) as f:
-            dicts = [symbol.strip() for symbol in f.readlines()]
+        self.config = config
         task_cfg = self.merge_with_parent(HubertPretrainingConfig,
                                           dict(self.config.task_cfg))
         model_cfg = self.merge_with_parent(HubertConfig,
                                            dict(self.config.model_cfg))
-        hubert = HubertModel(model_cfg, task_cfg, dicts)
+        hubert = HubertModel(model_cfg, task_cfg, [None])
         self.hubert = hubert

     @classmethod
@@ -351,6 +348,21 @@ class HubertBase(nn.Layer):
         model = cls(configs)
         return model

+    def merge_with_parent(self, dc: dataclass, cfg: dict):
+        assert is_dataclass(dc)
+        assert type(cfg) == dict
+        cfg = deepcopy(cfg)
+
+        def fix_cfg(cfg):
+            target_keys = set(dc.__dataclass_fields__.keys())
+            for k in list(cfg.keys()):
+                if k not in target_keys:
+                    del cfg[k]
+
+        fix_cfg(cfg)
+        assert len(cfg) > 0
+        return dc(**cfg)
+
     def forward(self, wav):
-        out = self.hubert(wav)
+        out = self.hubert.extract_features(wav)
         return out

diff --git a/paddlespeech/s2t/models/hubert/modules/__init__.py b/paddlespeech/s2t/models/hubert/modules/__init__.py
new file mode 100644
index 000000000..595add0ae
--- /dev/null
+++ b/paddlespeech/s2t/models/hubert/modules/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

diff --git a/paddlespeech/s2t/models/hubert/modules/hubert_model.py b/paddlespeech/s2t/models/hubert/modules/hubert_model.py
index dc30d9ee6..46f4d9bc5 100644
--- a/paddlespeech/s2t/models/hubert/modules/hubert_model.py
+++ b/paddlespeech/s2t/models/hubert/modules/hubert_model.py
@@ -377,7 +377,7 @@ class HubertModel(nn.Layer):
                 min_space=self.mask_channel_min_space, )
             mask_channel_indices = (paddle.to_tensor(
                 mask_channel_indices, dtype='int64', place=x.place).unsqueeze(1)
-                .expand(-1, T, -1))
+                .expand([-1, T, -1]))
             x[mask_channel_indices] = 0
         return x, mask_indices
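
A quick smoke test for reviewers (not part of the patch): the snippet below exercises the new hubert path through the Python API. The keyword arguments mirror the wav2vec2 example in demos/speech_ssl, so treat the exact call signature for hubert as an assumption rather than a verified interface; `./en.wav` is the audio demo referenced in the README above.

```python
# Sketch: run ASR through the hubert model added by this patch, assuming
# SSLExecutor accepts the same keyword arguments as the wav2vec2 demo.
import paddle
from paddlespeech.cli.ssl import SSLExecutor

ssl_executor = SSLExecutor()
text = ssl_executor(
    model='hubertASR_librispeech-100h',  # prefix registered in pretrained_models.py above
    task='asr',
    lang='en',
    sample_rate=16000,
    config=None,  # None -> use the model.yaml bundled in the pretrained tarball
    ckpt_path=None,  # None -> use the downloaded avg_1 checkpoint
    audio_file='./en.wav',
    device=paddle.get_device())
print('ASR Result:\n{}'.format(text))
```

The shell equivalent should be `paddlespeech ssl --task asr --lang en --input ./en.wav`, assuming the executor's argument parser also exposes a way to select the hubert model over the wav2vec2 default.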