support hubert cli

pull/3088/head
th.zhang 2 years ago
parent 7658e54c47
commit 6f3585b199

@ -163,7 +163,7 @@ using the `tar` scripts to unpack the model and then you can use the script to t
For example:
```bash
wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
wget https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
tar xzvf hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
source path.sh
# If you have processed the data and obtained the manifest file, you can skip the following 2 steps
@ -184,7 +184,7 @@ In some situations, you want to use the trained model to do the inference for th
```
You can train the model yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
```bash
wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
wget https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
tar xzvf hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
```
You can download the audio demo:

@ -150,10 +150,10 @@ class SSLExecutor(BaseExecutor):
model_prefix = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_prefix = 'wav2vec2ASR_aishell1'
tag = model_prefix + '-' + lang + '-' + sample_rate_str
tag = model_prefix + '-' + lang + '-' + sample_rate_str
elif model_type == 'hubert':
if lang == 'en':
model_prefix = 'hubertASR_librispeech_100'
model_prefix = 'hubertASR_librispeech-100h'
elif lang == 'zh':
logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
@ -185,16 +185,17 @@ class SSLExecutor(BaseExecutor):
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath)
self.config.output_dim = len(self.config.vocab_filepath)
elif lang == 'zh':
self.text_feature = AutoTokenizer.from_pretrained(
self.config.tokenizer)
self.config.output_dim = self.text_feature.vocab_size
self.config.decode.decoding_method = decode_method
model_name = model_type[:model_type.rindex(
model_name = model_prefix[:model_prefix.rindex(
'_')] # model_type: {model_name}_{dataset}
else:
model_name = model_type
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
@ -264,8 +265,7 @@ class SSLExecutor(BaseExecutor):
audio = self._inputs["audio"]
if task == 'asr':
cfg = self.config.decode
logger.debug(
f"we will use the {model_type}ASR like model.")
logger.debug(f"we will use the {model_type}ASR like model.")
try:
result_transcripts = self.model.decode(
audio,
@ -278,7 +278,8 @@ class SSLExecutor(BaseExecutor):
logger.exception(e)
else:
logger.debug(
f"we will use the {model_type} like model to extract audio feature.")
f"we will use the {model_type} like model to extract audio feature."
)
try:
out_feature = self.model(audio[:, :, 0])
self._outputs["result"] = out_feature[0]
@ -455,7 +456,7 @@ class SSLExecutor(BaseExecutor):
if rtf:
k = self.__class__.__name__
CLI_TIMER[k]['start'].append(time.time())
self.preprocess(model, audio_file)
self.preprocess(audio_file)
self.infer(model, task)
res = self.postprocess() # Retrieve result of asr.

@ -117,6 +117,38 @@ ssl_dynamic_pretrained_models = {
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
"hubert-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
'md5':
'efecfb87a8718aa9253b7459c1fe9b54',
'cfg_path':
'model.yaml',
'ckpt_path':
'hubert-large-lv60',
'model':
'hubert-large-lv60.pdparams',
'params':
'hubert-large-lv60.pdparams',
},
},
"hubertASR_librispeech-100h-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
'md5':
'574cefd11aaef5737969ce22a7f33ea2',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/hubertASR/checkpoints/avg_1',
'model':
'exp/hubertASR/checkpoints/avg_1.pdparams',
'params':
'exp/hubertASR/checkpoints/avg_1.pdparams',
},
},
}
# ---------------------------------
@ -521,38 +553,6 @@ asr_onnx_pretrained_models = {
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"hubert-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'cfg_path':
'model.yaml',
'ckpt_path':
'hubert-large-lv60',
'model':
'hubert-large-lv60.pdparams',
'params':
'hubert-large-lv60.pdparams',
},
},
"hubertASR_librispeech_100-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
}
whisper_dynamic_pretrained_models = {

@ -44,13 +44,11 @@ class HubertASR(nn.Layer):
init_type = config.get("init_type", None)
with DefaultInitializerContext(init_type):
self.config = config
with open(config.vocab_filepath) as f:
dicts = [symbol.strip() for symbol in f.readlines()]
task_cfg = self.merge_with_parent(HubertPretrainingConfig,
dict(self.config.task_cfg))
model_cfg = self.merge_with_parent(HubertConfig,
dict(self.config.model_cfg))
hubert = HubertModel(model_cfg, task_cfg, dicts)
hubert = HubertModel(model_cfg, task_cfg, [None])
self.normalize_wav = config.normalize_wav
self.output_norm = config.output_norm
@ -329,13 +327,12 @@ class HubertBase(nn.Layer):
# Constructor: stores the config, derives the task and model configs from it
# via merge_with_parent, builds a HubertModel and keeps it as self.hubert.
# `config` is annotated as dict but is read with attribute access
# (config.vocab_filepath, config.task_cfg, config.model_cfg).
# NOTE(review): this listing looks like a flattened diff view, so removed and
# added lines may both be present here -- confirm against the real file.
def __init__(self, config: dict):
super().__init__()
# Reads the vocabulary file, one stripped symbol per line.
with open(config.vocab_filepath) as f:
dicts = [symbol.strip() for symbol in f.readlines()]
self.config = config
# Keep only the keys that are fields of the target dataclass.
task_cfg = self.merge_with_parent(HubertPretrainingConfig,
dict(self.config.task_cfg))
model_cfg = self.merge_with_parent(HubertConfig,
dict(self.config.model_cfg))
# NOTE(review): the two assignments below are presumably the before/after
# versions of the same statement (vocabulary symbols vs. the [None]
# placeholder); as listed, the second one wins -- TODO confirm.
hubert = HubertModel(model_cfg, task_cfg, dicts)
hubert = HubertModel(model_cfg, task_cfg, [None])
self.hubert = hubert
@classmethod
@ -351,6 +348,21 @@ class HubertBase(nn.Layer):
model = cls(configs)
return model
def merge_with_parent(self, dc: dataclass, cfg: dict):
    """Build an instance of dataclass ``dc`` from the matching entries of ``cfg``.

    Keys of ``cfg`` that are not fields of ``dc`` are dropped. The caller's
    mapping is never modified because the filtering works on a deep copy.

    Args:
        dc: Dataclass to instantiate.
        cfg: Dict (or dict subclass) holding candidate field values.

    Returns:
        ``dc(**kept)``, where ``kept`` is the subset of ``cfg`` whose keys
        are fields of ``dc``.

    Raises:
        AssertionError: If ``dc`` is not a dataclass, ``cfg`` is not a dict,
            or no key of ``cfg`` names a field of ``dc``.
    """
    assert is_dataclass(dc)
    # isinstance (rather than an exact type comparison) also admits dict
    # subclasses such as OrderedDict.
    assert isinstance(cfg, dict)

    field_names = set(dc.__dataclass_fields__)
    kept = {
        key: value
        for key, value in deepcopy(cfg).items() if key in field_names
    }
    assert len(kept) > 0
    return dc(**kept)
# Forward pass: feeds `wav` to the wrapped self.hubert model and returns the
# result unchanged.
# NOTE(review): the two assignments to `out` are presumably the before/after
# versions of one statement in a flattened diff view; as listed, the
# extract_features() result is the one returned -- TODO confirm.
def forward(self, wav):
out = self.hubert(wav)
out = self.hubert.extract_features(wav)
return out

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -377,7 +377,7 @@ class HubertModel(nn.Layer):
min_space=self.mask_channel_min_space, )
mask_channel_indices = (paddle.to_tensor(
mask_channel_indices, dtype='int64', place=x.place).unsqueeze(1)
.expand(-1, T, -1))
.expand([-1, T, -1]))
x[mask_channel_indices] = 0
return x, mask_indices

Loading…
Cancel
Save