From fc02cd0540fd9d706f9400b30dafde243c7f9753 Mon Sep 17 00:00:00 2001 From: Zth9730 <32243340+Zth9730@users.noreply.github.com> Date: Tue, 22 Nov 2022 17:37:33 +0800 Subject: [PATCH] [doc] update wav2vec2 demos README.md, test=doc (#2674) * fix wav2vec2 demos, test=doc * fix wav2vec2 demos, test=doc * fix enc_dropout and nor.py, test=asr --- demos/speech_ssl/README.md | 2 +- demos/speech_ssl/README_cn.md | 8 +- .../s2t/models/wav2vec2/modules/VanillaNN.py | 5 +- .../models/wav2vec2/modules/normalization.py | 97 +++++++++++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 paddlespeech/s2t/models/wav2vec2/modules/normalization.py diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md index fdef37e7..b98a7cc6 100644 --- a/demos/speech_ssl/README.md +++ b/demos/speech_ssl/README.md @@ -82,7 +82,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav Output: ```bash ASR Result: - 我认为跑步最重要的就是给我带来了身体健康 + i knocked at the door on the ancient side of the building Representation: Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md index 76ec2f1f..65961ce9 100644 --- a/demos/speech_ssl/README_cn.md +++ b/demos/speech_ssl/README_cn.md @@ -36,9 +36,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav ``` 参数: - `input`(必须输入):用于识别的音频文件。 - - `model`:ASR 任务的模型,默认值:`conformer_wenetspeech`。 + - `model`:ASR 任务的模型,默认值:`wav2vec2ASR_librispeech`。 - `task`:输出类别,默认值:`asr`。 - - `lang`:模型语言,默认值:`zh`。 + - `lang`:模型语言,默认值:`en`。 - `sample_rate`:音频采样率,默认值:`16000`。 - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。 @@ -83,8 +83,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav 输出: ```bash ASR Result: - 我认为跑步最重要的就是给我带来了身体健康 - + i knocked at the door on the ancient side of the building + Representation: Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py index 82313c33..9c88796b 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py @@ -46,7 +46,7 @@ class VanillaNN(containers.Sequential): dnn_neurons=512, activation=True, normalization=False, - dropout_rate=0.0): + dropout_rate=0.5): super().__init__(input_shape=[None, None, input_shape]) if not isinstance(dropout_rate, list): @@ -68,6 +68,5 @@ class VanillaNN(containers.Sequential): if activation: self.append(paddle.nn.LeakyReLU(), layer_name="act") self.append( - paddle.nn.Dropout(), - p=dropout_rate[block_index], + paddle.nn.Dropout(p=dropout_rate[block_index]), layer_name='dropout') diff --git a/paddlespeech/s2t/models/wav2vec2/modules/normalization.py b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py new file mode 100644 index 00000000..91298105 --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py @@ -0,0 +1,97 @@ +# Authors +# * Mirco Ravanelli 2020 +# * Guillermo Cámbara 2021 +# * Sarthak Yadav 2022 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/normalization.py)
import paddle.nn as nn

from paddlespeech.s2t.modules.align import BatchNorm1D


class BatchNorm1d(nn.Layer):
    """Applies 1d batch normalization to the input tensor.
    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to the standard deviation estimation to improve
        the numerical stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    combine_batch_time : bool
        When set to True, it combines the batch and time axes before
        normalization.
    skip_transpose : bool
        When set to True, the input is assumed to already be channel-first
        (batch, channels, time) and no transposition is applied.
    Example
    -------
    >>> input = paddle.randn([100, 10, 20])
    >>> norm = BatchNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    [100, 10, 20]
    """

    def __init__(
            self,
            input_shape=None,
            input_size=None,
            eps=1e-05,
            momentum=0.9,
            combine_batch_time=False,
            skip_transpose=False, ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None and skip_transpose:
            input_size = input_shape[1]
        elif input_size is None:
            input_size = input_shape[-1]

        self.norm = BatchNorm1D(input_size, momentum=momentum, epsilon=eps)

    def forward(self, x):
        """Returns the normalized input tensor.
        Arguments
        ---------
        x : paddle.Tensor (batch, time, [channels])
            Input to normalize. 3d tensors are expected by default;
            4d tensors can be used when ``combine_batch_time=True``.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
                              shape_or[2])

        elif not self.skip_transpose:
            x = x.transpose([0, 2, 1])

        x_n = self.norm(x)
        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose([0, 2, 1])

        return x_n
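For reference, a minimal usage sketch of the two code paths this patch touches: the new `BatchNorm1d` wrapper added in `normalization.py` and the corrected `paddle.nn.Dropout(p=...)` construction from the `VanillaNN.py` fix. This is illustrative only and not part of the diff; it assumes a working PaddlePaddle install, the module path introduced above, and made-up tensor shapes.

```python
# Illustrative sketch (not part of the patch); shapes are hypothetical.
import paddle
import paddle.nn as nn

from paddlespeech.s2t.models.wav2vec2.modules.normalization import BatchNorm1d

# BatchNorm1d expects (batch, time, channels) by default: it transposes to
# (batch, channels, time), applies paddle's BatchNorm1D, and transposes back.
feats = paddle.randn([8, 120, 1024])          # e.g. frame-level features
norm = BatchNorm1d(input_shape=feats.shape)   # channel count taken from last dim
normed = norm(feats)
print(normed.shape)                           # [8, 120, 1024]

# If the input is already channel-first, skip the transposition.
norm_cf = BatchNorm1d(input_size=1024, skip_transpose=True)
print(norm_cf(feats.transpose([0, 2, 1])).shape)   # [8, 1024, 120]

# The VanillaNN fix builds Dropout with the rate on the layer itself,
# i.e. paddle.nn.Dropout(p=rate), instead of passing `p` to Sequential.append.
drop = nn.Dropout(p=0.15)
print(drop(feats).shape)                      # [8, 120, 1024]
```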