From fc02cd0540fd9d706f9400b30dafde243c7f9753 Mon Sep 17 00:00:00 2001 From: Zth9730 <32243340+Zth9730@users.noreply.github.com> Date: Tue, 22 Nov 2022 17:37:33 +0800 Subject: [PATCH 1/5] [doc] update wav2vec2 demos README.md, test=doc (#2674) * fix wav2vec2 demos, test=doc * fix wav2vec2 demos, test=doc * fix enc_dropout and nor.py, test=asr --- demos/speech_ssl/README.md | 2 +- demos/speech_ssl/README_cn.md | 8 +- .../s2t/models/wav2vec2/modules/VanillaNN.py | 5 +- .../models/wav2vec2/modules/normalization.py | 97 +++++++++++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 paddlespeech/s2t/models/wav2vec2/modules/normalization.py diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md index fdef37e7b..b98a7cc61 100644 --- a/demos/speech_ssl/README.md +++ b/demos/speech_ssl/README.md @@ -82,7 +82,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav Output: ```bash ASR Result: - 我认为跑步最重要的就是给我带来了身体健康 + i knocked at the door on the ancient side of the building Representation: Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md index 76ec2f1ff..65961ce90 100644 --- a/demos/speech_ssl/README_cn.md +++ b/demos/speech_ssl/README_cn.md @@ -36,9 +36,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav ``` 参数: - `input`(必须输入):用于识别的音频文件。 - - `model`:ASR 任务的模型,默认值:`conformer_wenetspeech`。 + - `model`:ASR 任务的模型,默认值:`wav2vec2ASR_librispeech`。 - `task`:输出类别,默认值:`asr`。 - - `lang`:模型语言,默认值:`zh`。 + - `lang`:模型语言,默认值:`en`。 - `sample_rate`:音频采样率,默认值:`16000`。 - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。 @@ -83,8 +83,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav 输出: ```bash ASR Result: - 我认为跑步最重要的就是给我带来了身体健康 - + i knocked at the door on the ancient side of the building + Representation: Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py index 82313c330..9c88796bb 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py @@ -46,7 +46,7 @@ class VanillaNN(containers.Sequential): dnn_neurons=512, activation=True, normalization=False, - dropout_rate=0.0): + dropout_rate=0.5): super().__init__(input_shape=[None, None, input_shape]) if not isinstance(dropout_rate, list): @@ -68,6 +68,5 @@ class VanillaNN(containers.Sequential): if activation: self.append(paddle.nn.LeakyReLU(), layer_name="act") self.append( - paddle.nn.Dropout(), - p=dropout_rate[block_index], + paddle.nn.Dropout(p=dropout_rate[block_index]), layer_name='dropout') diff --git a/paddlespeech/s2t/models/wav2vec2/modules/normalization.py b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py new file mode 100644 index 000000000..912981058 --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py @@ -0,0 +1,97 @@ +# Authors +# * Mirco Ravanelli 2020 +# * Guillermo Cámbara 2021 +# * Sarthak Yadav 2022 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/normalization.py) +import paddle.nn as nn + +from paddlespeech.s2t.modules.align import BatchNorm1D + + +class BatchNorm1d(nn.Layer): + """Applies 1d batch normalization to the input tensor. + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + affine : bool + When set to True, the affine parameters are learned. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + combine_batch_time : bool + When true, it combines batch an time axis. + Example + ------- + >>> input = paddle.randn([100, 10]) + >>> norm = BatchNorm1d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + Paddle.Shape([100, 10]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.9, + combine_batch_time=False, + skip_transpose=False, ): + super().__init__() + self.combine_batch_time = combine_batch_time + self.skip_transpose = skip_transpose + + if input_size is None and skip_transpose: + input_size = input_shape[1] + elif input_size is None: + input_size = input_shape[-1] + + self.norm = BatchNorm1D(input_size, momentum=momentum, epsilon=eps) + + def forward(self, x): + """Returns the normalized input tensor. + Arguments + --------- + x : paddle.Tensor (batch, time, [channels]) + input to normalize. 2d or 3d tensors are expected in input + 4d tensors can be used when combine_dims=True. + """ + shape_or = x.shape + if self.combine_batch_time: + if x.ndim == 3: + x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) + else: + x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], + shape_or[2]) + + elif not self.skip_transpose: + x = x.transpose([0, 2, 1]) + + x_n = self.norm(x) + if self.combine_batch_time: + x_n = x_n.reshape(shape_or) + elif not self.skip_transpose: + x_n = x_n.transpose([0, 2, 1]) + + return x_n From 45426846942f68cf43a23677d8d55f6d4ab93ab1 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 23 Nov 2022 11:06:49 +0800 Subject: [PATCH 2/5] [ASR] fix Whisper cli model download path error. test=asr (#2679) * add all whisper model size support * add choices in parser. * fix Whisper cli model download path error. * fix resource download path. 
* fix code style --- demos/whisper/README.md | 4 +- demos/whisper/README_cn.md | 4 +- paddlespeech/cli/whisper/infer.py | 14 +-- paddlespeech/resource/pretrained_models.py | 90 ++++++++------------ paddlespeech/s2t/exps/whisper/test_wav.py | 3 +- paddlespeech/s2t/models/whisper/tokenizer.py | 8 +- paddlespeech/s2t/models/whisper/whipser.py | 68 +++++++++------ 7 files changed, 97 insertions(+), 94 deletions(-) diff --git a/demos/whisper/README.md b/demos/whisper/README.md index 017eb93ef..9b12554e6 100644 --- a/demos/whisper/README.md +++ b/demos/whisper/README.md @@ -61,7 +61,7 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper # to recognize text text = whisper_executor( - model='whisper-large', + model='whisper', task='transcribe', sample_rate=16000, config=None, # Set `config` and `ckpt_path` to None to use pretrained model. @@ -72,7 +72,7 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper # to recognize text and translate to English feature = whisper_executor( - model='whisper-large', + model='whisper', task='translate', sample_rate=16000, config=None, # Set `config` and `ckpt_path` to None to use pretrained model. diff --git a/demos/whisper/README_cn.md b/demos/whisper/README_cn.md index 4da079553..6f7c35f04 100644 --- a/demos/whisper/README_cn.md +++ b/demos/whisper/README_cn.md @@ -61,7 +61,7 @@ Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper # 识别文本 text = whisper_executor( - model='whisper-large', + model='whisper', task='transcribe', sample_rate=16000, config=None, # Set `config` and `ckpt_path` to None to use pretrained model. @@ -72,7 +72,7 @@ Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper # 将语音翻译成英语 feature = whisper_executor( - model='whisper-large', + model='whisper', task='translate', sample_rate=16000, config=None, # Set `config` and `ckpt_path` to None to use pretrained model. 
diff --git a/paddlespeech/cli/whisper/infer.py b/paddlespeech/cli/whisper/infer.py index b6b461f62..c016b453a 100644 --- a/paddlespeech/cli/whisper/infer.py +++ b/paddlespeech/cli/whisper/infer.py @@ -27,6 +27,7 @@ import paddle import soundfile from yacs.config import CfgNode +from ...utils.env import DATA_HOME from ..download import get_path_from_url from ..executor import BaseExecutor from ..log import logger @@ -187,10 +188,12 @@ class WhisperExecutor(BaseExecutor): with UpdateConfig(self.config): if "whisper" in model_type: - resource_url = self.task_resource.res_dict['resuource_data'] - resource_md5 = self.task_resource.res_dict['resuource_data_md5'] - resuource_path = self.task_resource.res_dict['resuource_path'] - self.download_resource(resource_url, resuource_path, + resource_url = self.task_resource.res_dict['resource_data'] + resource_md5 = self.task_resource.res_dict['resource_data_md5'] + + self.resource_path = os.path.join( + DATA_HOME, self.task_resource.version, 'whisper') + self.download_resource(resource_url, self.resource_path, resource_md5) else: raise Exception("wrong type") @@ -249,7 +252,7 @@ class WhisperExecutor(BaseExecutor): logger.debug(f"audio shape: {audio.shape}") # fbank - audio = log_mel_spectrogram(audio) + audio = log_mel_spectrogram(audio, resource_path=self.resource_path) audio_len = paddle.to_tensor(audio.shape[0]) @@ -279,6 +282,7 @@ class WhisperExecutor(BaseExecutor): verbose=cfg.verbose, task=self.task, language=self.language, + resource_path=self.resource_path, temperature=temperature, compression_ratio_threshold=cfg.compression_ratio_threshold, logprob_threshold=cfg.logprob_threshold, diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 85b41e685..067246749 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -468,9 +468,9 @@ whisper_dynamic_pretrained_models = { "whisper-large-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/whisper-large-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-large-model.tar.gz', 'md5': - '364c4d670835e5ca489045e1c29d75fe', + 'cf1557af9d8ffa493fefad9cb08ae189', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -479,20 +479,18 @@ whisper_dynamic_pretrained_models = { 'whisper-large-model.pdparams', 'params': 'whisper-large-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-base-en-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-en-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-en-model.tar.gz', 'md5': - 'f5bb8cdff42c7031d9e4c0ea20f7ceee', + 'b156529aefde6beb7726d2ea98fd067a', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -501,20 +499,18 @@ whisper_dynamic_pretrained_models = { 'whisper-base-en-model.pdparams', 'params': 'whisper-base-en-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-base-16k": { '1.3': { 'url': - 
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-model.tar.gz', 'md5': - '46f254e89a01b71586af1a46d28d7ce9', + '6b012a5abd583db14398c3492e47120b', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -523,20 +519,18 @@ whisper_dynamic_pretrained_models = { 'whisper-base-model.pdparams', 'params': 'whisper-base-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-medium-en-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-en-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-en-model.tar.gz', 'md5': - '98228f3ba94636c2760b51e5f3d6885f', + 'c7f57d270bd20c7b170ba9dcf6c16f74', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -545,20 +539,18 @@ whisper_dynamic_pretrained_models = { 'whisper-medium-en-model.pdparams', 'params': 'whisper-medium-en-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-medium-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-model.tar.gz', 'md5': - '51ac154b264db75492ed1cc5280baebf', + '4c7dcd0df25f408199db4a4548336786', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -567,20 +559,18 @@ whisper_dynamic_pretrained_models = { 'whisper-medium-model.pdparams', 'params': 'whisper-medium-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-small-en-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-en-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-en-model.tar.gz', 'md5': - '973b784a335580a393e13a13995b110a', + '2b24efcb2e93f3275af7c0c7f598ff1c', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -589,20 +579,18 @@ whisper_dynamic_pretrained_models = { 'whisper-small-en-model.pdparams', 'params': 'whisper-small-en-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-small-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-model.tar.gz', 'md5': - '57a7530851cc98631c6fb29c606489c6', + '5a57911dd41651dd6ed78c5763912825', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -611,20 +599,18 @@ whisper_dynamic_pretrained_models = { 'whisper-small-model.pdparams', 'params': 'whisper-small-model.pdparams', - 'resuource_data': + 
'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-tiny-en-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-en-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-en-model.tar.gz', 'md5': - '3ef5c0777e0bd4a1a240895167b0eb0d', + '14969164a3f713fd58e56978c34188f6', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -633,20 +619,18 @@ whisper_dynamic_pretrained_models = { 'whisper-tiny-en-model.pdparams', 'params': 'whisper-tiny-en-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, "whisper-tiny-16k": { '1.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-model.tar.gz', 'md5': - 'ddf232cd16c85120e89c870a53451e53', + 'a5b82a1f2067a2ca400f17fabd62b81b', 'cfg_path': 'whisper.yaml', 'ckpt_path': @@ -655,12 +639,10 @@ whisper_dynamic_pretrained_models = { 'whisper-tiny-model.pdparams', 'params': 'whisper-tiny-model.pdparams', - 'resuource_data': + 'resource_data': 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', - 'resuource_data_md5': + 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', - 'resuource_path': - 'paddlespeech/s2t/models/whisper', }, }, } diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py index 63945b9eb..e04eec4f2 100644 --- a/paddlespeech/s2t/exps/whisper/test_wav.py +++ b/paddlespeech/s2t/exps/whisper/test_wav.py @@ -62,7 +62,8 @@ class WhisperInfer(): temperature = [temperature] #load audio - mel = log_mel_spectrogram(args.audio) + mel = log_mel_spectrogram( + args.audio_file, resource_path=config.resource_path) result = transcribe( self.model, mel, temperature=temperature, **config) diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py index 1c58c94c7..8bd85c914 100644 --- a/paddlespeech/s2t/models/whisper/tokenizer.py +++ b/paddlespeech/s2t/models/whisper/tokenizer.py @@ -298,9 +298,9 @@ class Tokenizer: @lru_cache(maxsize=None) -def build_tokenizer(name: str="gpt2"): +def build_tokenizer(resource_path: str, name: str="gpt2"): os.environ["TOKENIZERS_PARALLELISM"] = "false" - path = os.path.join(os.path.dirname(__file__), "assets", name) + path = os.path.join(resource_path, "assets", name) tokenizer = GPTTokenizer.from_pretrained(path) specials = [ @@ -321,6 +321,7 @@ def build_tokenizer(name: str="gpt2"): @lru_cache(maxsize=None) def get_tokenizer( multilingual: bool, + resource_path: str, *, task: Optional[str]=None, # Literal["transcribe", "translate", None] language: Optional[str]=None, ) -> Tokenizer: @@ -341,7 +342,8 @@ def get_tokenizer( task = None language = None - tokenizer = build_tokenizer(name=tokenizer_name) + tokenizer = build_tokenizer( + resource_path=resource_path, name=tokenizer_name) all_special_ids: List[int] = tokenizer.all_special_ids sot: int = all_special_ids[1] translate: int = all_special_ids[-6] diff --git a/paddlespeech/s2t/models/whisper/whipser.py 
b/paddlespeech/s2t/models/whisper/whipser.py index c01a09485..ba9983338 100644 --- a/paddlespeech/s2t/models/whisper/whipser.py +++ b/paddlespeech/s2t/models/whisper/whipser.py @@ -1,6 +1,6 @@ # MIT License, Copyright (c) 2022 OpenAI. # Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved. -# +# # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper) import os from dataclasses import dataclass @@ -265,7 +265,6 @@ class DecodingOptions: task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" language: Optional[ str] = None # language that the audio is in; uses detected language if None - # sampling-related options temperature: float = 0.0 sample_len: Optional[int] = None # maximum number of tokens to sample @@ -361,10 +360,11 @@ class WhisperInference(Inference): @paddle.no_grad() -def detect_language(model: "Whisper", - mel: paddle.Tensor, - tokenizer: Tokenizer=None - ) -> Tuple[paddle.Tensor, List[dict]]: +def detect_language( + model: "Whisper", + mel: paddle.Tensor, + resource_path: str, + tokenizer: Tokenizer=None) -> Tuple[paddle.Tensor, List[dict]]: """ Detect the spoken language in the audio, and return them as list of strings, along with the ids of the most probable language tokens and the probability distribution over all language tokens. @@ -378,7 +378,8 @@ def detect_language(model: "Whisper", list of dictionaries containing the probability distribution over all languages. """ if tokenizer is None: - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer( + model.is_multilingual, resource_path=resource_path) if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: raise ValueError( "This model doesn't have language tokens so it can't perform lang id" @@ -419,6 +420,7 @@ def detect_language(model: "Whisper", def transcribe( model: "Whisper", mel: paddle.Tensor, + resource_path: str, *, verbose: Optional[bool]=None, temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8, @@ -485,7 +487,7 @@ def transcribe( "Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language" ) segment = pad_or_trim(mel, N_FRAMES) - _, probs = model.detect_language(segment) + _, probs = model.detect_language(segment, resource_path) decode_options["language"] = max(probs, key=probs.get) if verbose is not None: print( @@ -495,7 +497,10 @@ def transcribe( language = decode_options["language"] task = decode_options.get("task", "transcribe") tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=task) + model.is_multilingual, + resource_path=resource_path, + language=language, + task=task) def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult: temperatures = [temperature] if isinstance(temperature, ( @@ -513,7 +518,7 @@ def transcribe( kwargs.pop("best_of", None) options = DecodingOptions(**kwargs, temperature=t) - decode_result = model.decode(segment, options) + decode_result = model.decode(segment, options, resource_path) needs_fallback = False if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: @@ -979,14 +984,21 @@ class DecodingTask: decoder: TokenDecoder logit_filters: List[LogitFilter] - def __init__(self, model: "Whisper", options: DecodingOptions): + def __init__(self, + model: "Whisper", + options: DecodingOptions, + resource_path: str): self.model = model language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=options.task) + model.is_multilingual, + resource_path=resource_path, + language=language, + task=options.task) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) + self.resource_path: str = resource_path self.beam_size: int = options.beam_size or options.best_of or 1 self.n_ctx: int = model.dims.n_text_ctx @@ -1112,13 +1124,14 @@ class DecodingTask: def _detect_language(self, audio_features: paddle.Tensor, - tokens: paddle.Tensor): + tokens: paddle.Tensor, + resource_path: str): languages = [self.options.language] * audio_features.shape[0] lang_probs = None if self.options.language is None or self.options.task == "lang_id": - lang_tokens, lang_probs = self.model.detect_language(audio_features, - self.tokenizer) + lang_tokens, lang_probs = self.model.detect_language( + audio_features, self.tokenizer, self.resource_path) languages = [max(probs, key=probs.get) for probs in lang_probs] if self.options.language is None: tokens[:, self.sot_index + @@ -1185,7 +1198,8 @@ class DecodingTask: # detect language if requested, overwriting the language token languages, language_probs = self._detect_language( - paddle.to_tensor(audio_features), paddle.to_tensor(tokens)) + paddle.to_tensor(audio_features), + paddle.to_tensor(tokens), self.resource_path) if self.options.task == "lang_id": return [ @@ -1254,10 +1268,11 @@ class DecodingTask: @paddle.no_grad() -def decode(model: "Whisper", - mel: paddle.Tensor, - options: DecodingOptions=DecodingOptions() - ) -> Union[DecodingResult, List[DecodingResult]]: +def decode( + model: "Whisper", + mel: paddle.Tensor, + options: DecodingOptions=DecodingOptions(), + resource_path=str, ) -> Union[DecodingResult, List[DecodingResult]]: """ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). 
@@ -1281,7 +1296,7 @@ def decode(model: "Whisper", if single: mel = mel.unsqueeze(0) - result = DecodingTask(model, options).run(mel) + result = DecodingTask(model, options, resource_path).run(mel) if single: result = result[0] @@ -1407,7 +1422,7 @@ def hann_window(n_fft: int=N_FFT): @lru_cache(maxsize=None) -def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor: +def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor: """ load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa dependency; saved using: @@ -1418,14 +1433,13 @@ def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor: ) """ assert n_mels == 80, f"Unsupported n_mels: {n_mels}" - with np.load( - os.path.join( - os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: + with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f: return paddle.to_tensor(f[f"mel_{n_mels}"]) def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor], - n_mels: int=N_MELS): + n_mels: int=N_MELS, + resource_path: str=None): """ Compute the log-Mel spectrogram of @@ -1454,7 +1468,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor], magnitudes = stft[:, :-1].abs()**2 - filters = mel_filters(audio, n_mels) + filters = mel_filters(resource_path, n_mels) mel_spec = filters @ magnitudes mel_spec = paddle.to_tensor(mel_spec.numpy().tolist()) From 58309aa9d716949150fdd8cebc3eabba3d1267ec Mon Sep 17 00:00:00 2001 From: heyudage <1143790582@qq.com> Date: Fri, 25 Nov 2022 21:09:58 +0800 Subject: [PATCH 3/5] update docs test=doc (#2688) --- README.md | 1 + README_cn.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 2321920de..32e1c23d8 100644 --- a/README.md +++ b/README.md @@ -981,6 +981,7 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P - Many thanks to [jerryuhoo](https://github.com/jerryuhoo)/[VTuberTalk](https://github.com/jerryuhoo/VTuberTalk) for developing a GUI tool based on PaddleSpeech TTS and code for making datasets from videos based on PaddleSpeech ASR. - Many thanks to [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) for developing a rasa chatbot,which is able to speak and listen thanks to PaddleSpeech. - Many thanks to [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) for the C++ inference implementation of PaddleSpeech ASR. +- Many thanks to [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) for the real-time voice typing tool implementation of PaddleSpeech ASR streaming services. Besides, PaddleSpeech depends on a lot of open source repositories. See [references](./docs/source/reference.md) for more information. 
diff --git a/README_cn.md b/README_cn.md index 8127c5570..427d59caf 100644 --- a/README_cn.md +++ b/README_cn.md @@ -987,6 +987,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 - 非常感谢 [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) 基于 PaddleSpeech 的 ASR 与 TTS 设计的可听、说对话机器人。 - 非常感谢 [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) 对 PaddleSpeech 的 ASR 进行 C++ 推理实现。 +- 非常感谢 [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) 基于 PaddleSpeech 的 ASR 流式服务实现的实时语音输入法工具。 此外,PaddleSpeech 依赖于许多开源存储库。有关更多信息,请参阅 [references](./docs/source/reference.md)。 From bd01bc155de267202588a821ccb0695952059e23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20An=20=EF=BC=88An=20Hongliang=EF=BC=89?= Date: Mon, 28 Nov 2022 14:54:22 +0800 Subject: [PATCH 4/5] add greek char and fix issue2571 (#2683) Co-authored-by: TianYuan --- .../zh_normalization/text_normlization.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py index 1942e6661..1250e96ca 100644 --- a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py +++ b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py @@ -65,7 +65,7 @@ class TextNormalizer(): if lang == "zh": text = text.replace(" ", "") # 过滤掉特殊字符 - text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text) + text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|…\\]', '', text) text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) text = text.strip() sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] @@ -85,7 +85,33 @@ class TextNormalizer(): sentence = sentence.replace('⑧', '八') sentence = sentence.replace('⑨', '九') sentence = sentence.replace('⑩', '十') - + sentence = sentence.replace('α', '阿尔法') + sentence = sentence.replace('β', '贝塔') + sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛') + sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔') + sentence = sentence.replace('ε', '艾普西龙') + sentence = sentence.replace('ζ', '捷塔') + sentence = sentence.replace('η', '依塔') + sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔') + sentence = sentence.replace('ι', '艾欧塔') + sentence = sentence.replace('κ', '喀帕') + sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达') + sentence = sentence.replace('μ', '缪') + sentence = sentence.replace('ν', '拗') + sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西') + sentence = sentence.replace('ο', '欧米克伦') + sentence = sentence.replace('π', '派').replace('Π', '派') + sentence = sentence.replace('ρ', '肉') + sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace( + 'σ', '西格玛') + sentence = sentence.replace('τ', '套') + sentence = sentence.replace('υ', '宇普西龙') + sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾') + sentence = sentence.replace('χ', '器') + sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛') + sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽') + # re filter special characters, have one more character "-" than line 68 + sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|…\\]', '', sentence) return sentence def normalize_sentence(self, sentence: str) -> str: @@ -124,6 +150,5 @@ class TextNormalizer(): def normalize(self, text: str) -> List[str]: sentences = self._split(text) - sentences = [self.normalize_sentence(sent) for sent in sentences] return sentences From a01c163dc359176fc2a71ea8a7e94db624c7f503 Mon Sep 17 
00:00:00 2001
From: Hui Zhang
Date: Tue, 29 Nov 2022 11:30:01 +0800
Subject: [PATCH 5/5] [speechx] more doc of speechx u2 and ds2 onnx (#2692)

* more doc of speechx u2 onnx
---
 speechx/examples/ds2_ol/onnx/README.md      | 23 ++++---
 .../examples/u2pp_ol/wenetspeech/README.md  | 62 +++++++++++++++++--
 speechx/examples/u2pp_ol/wenetspeech/run.sh |  3 +
 3 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md
index e6ab953c8..b98b74b6f 100644
--- a/speechx/examples/ds2_ol/onnx/README.md
+++ b/speechx/examples/ds2_ol/onnx/README.md
@@ -1,11 +1,8 @@
-# DeepSpeech2 to ONNX model
+# Convert DeepSpeech2 model to ONNX format

-1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
-2. check paddleinference and onnxruntime output equal.
-3. optimize onnx model
-4. check paddleinference and optimized onnxruntime output equal.
-5. quantize onnx model
-4. check paddleinference and optimized onnxruntime output equal.
+> We recommend using the U2/U2++ model instead of DS2; please see [here](../../u2pp_ol/wenetspeech/).
+
+This example demonstrates converting the DS2 model to ONNX format.

Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct.

@@ -25,18 +22,24 @@ onnxoptimizer 0.2.7
onnxruntime 1.11.0
```

+
## Using

```
bash run.sh --stage 0 --stop_stage 5
```

+1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
+2. check paddleinference and onnxruntime output equal.
+3. optimize onnx model
+4. check paddleinference and optimized onnxruntime output equal.
+5. quantize onnx model
+6. check paddleinference and optimized onnxruntime output equal.
+
For more details please see `run.sh`.

## Outputs

-The optimized onnx model is `exp/model.opt.onnx`, quanted model is `$exp/model.optset11.quant.onnx`.
-
-To show the graph, please using `local/netron.sh`.
+The optimized ONNX model is `exp/model.opt.onnx`; the quantized model is `exp/model.optset11.quant.onnx`.

## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)

diff --git a/speechx/examples/u2pp_ol/wenetspeech/README.md b/speechx/examples/u2pp_ol/wenetspeech/README.md
index b90b8e201..6ca8f6dd8 100644
--- a/speechx/examples/u2pp_ol/wenetspeech/README.md
+++ b/speechx/examples/u2pp_ol/wenetspeech/README.md
@@ -1,27 +1,77 @@
-# u2/u2pp Streaming ASR
+# U2/U2++ Streaming ASR
+
+A C++ deployment example for the `PaddleSpeech/examples/wenetspeech/asr1` recipe. The model is a static model produced by `export`; for how to export the model, please see [here](../../../../examples/wenetspeech/asr1/). If you want to use an already exported model, `run.sh` will download it; for the model link, please see `run.sh`.
+
+This example demonstrates how to use the U2/U2++ model to recognize `wav` files and compute the `CER`. We use AISHELL-1 as the test data.

## Testing with Aishell Test Data

-### Download wav and model
+### Source `path.sh` first
+
+```bash
+source path.sh
+```
+
+All binaries are under the `echo $SPEECHX_BUILD` dir.
+
+### Download dataset and model

```
./run.sh --stop_stage 0
```

-### compute feature
+### Process `cmvn` and compute the feature

-```
+```bash
./run.sh --stage 1 --stop_stage 1
```

-### decoding using feature
+If you only want to convert the `cmvn` file format, you can use this command:
+
+```bash
+./local/feat.sh --stage 1 --stop_stage 1
+```
+
+### Decoding using `feature` input

```
./run.sh --stage 2 --stop_stage 2
```

-### decoding using wav
+### Decoding using `wav` input

```
./run.sh --stage 3 --stop_stage 3
```
+
+This stage uses `u2_recognizer_main` to recognize wav files.
+
+The input is an `scp` file which looks like this:
+```text
+# head data/split1/1/aishell_test.scp
+BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
+BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
+...
+BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
+```
+
+If you want to recognize a single wav, you can make an `scp` file like this:
+```text
+key path/to/wav/file
+```
+
+Then specify the `--wav_rspecifier=` param for the `u2_recognizer_main` bin. For the meaning of the other flags, please see the help:
+```bash
+u2_recognizer_main --help
+```
+
+For an example of using the `u2_recognizer_main` bin, please see `local/recognizer.sh`.
+
+### Decoding with `wav` using the quant model
+
+`local/recognizer_quant.sh` is the same as `local/recognizer.sh`, but uses the quantized model.
+
+
+## Results
+
+Please see [here](./RESULTS.md).
diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/speechx/examples/u2pp_ol/wenetspeech/run.sh
index 870c5deeb..711d68083 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/run.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/run.sh
@@ -72,13 +72,16 @@ fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # process cmvn and compute fbank feat
    ./local/feat.sh
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # decode with fbank feat input
    ./local/decode.sh
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # decode with wav input
    ./loca/recognizer.sh
fi
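
A quick way to exercise the text-normalization change from PATCH 4/5 is sketched below. This snippet is illustrative only and is not part of the patch series: the import path and function names are taken from the diff above, while the no-argument `TextNormalizer()` constructor, the sample sentence, and the expected reading are assumptions based on the replacement table the patch adds.

```python
# Hedged sketch: exercise the Greek-letter handling added in PATCH 4/5.
# Assumes a PaddleSpeech checkout that already contains this patch.
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer

tn = TextNormalizer()
# After this patch, 'α' should be verbalized as '阿尔法' and 'β' as '贝塔',
# while characters such as '——' and '【】' are filtered out instead of being read.
print(tn.normalize("【重要】α粒子与β射线——对比"))
```

If the patch is applied, the Greek letters should appear in the normalized sentence list as their Chinese readings, so the TTS frontend no longer fails on them (the issue this change targets).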