diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml
index 2533eacfc..45a2ac965 100644
--- a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml
+++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml
@@ -12,7 +12,7 @@ encoder_conf:
     num_blocks: 12    # the number of encoder blocks
     reduce_idx: 5
     recover_idx: 11
-    feed_forward_expansion_factor: 4
+    feed_forward_expansion_factor: 8
     input_dropout_rate: 0.1
     feed_forward_dropout_rate: 0.1
     attention_dropout_rate: 0.1
diff --git a/examples/aishell/asr1/conf/squeezeformer.yaml b/examples/aishell/asr1/conf/squeezeformer.yaml
index db8ef7c2d..49a837a82 100644
--- a/examples/aishell/asr1/conf/squeezeformer.yaml
+++ b/examples/aishell/asr1/conf/squeezeformer.yaml
@@ -12,7 +12,7 @@ encoder_conf:
     num_blocks: 12    # the number of encoder blocks
     reduce_idx: 5
     recover_idx: 11
-    feed_forward_expansion_factor: 4
+    feed_forward_expansion_factor: 8
     input_dropout_rate: 0.1
     feed_forward_dropout_rate: 0.1
     attention_dropout_rate: 0.1
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index d133735b2..7be192575 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -28,7 +28,6 @@ from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
-from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
@@ -44,9 +43,9 @@ from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6
 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8
 from paddlespeech.s2t.modules.subsampling import DepthwiseConv2DSubsampling4
 from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling
-from paddlespeech.s2t.modules.subsampling import TimeReductionLayer1D
-from paddlespeech.s2t.modules.subsampling import TimeReductionLayer2D
-from paddlespeech.s2t.modules.subsampling import TimeReductionLayerStream
+from paddlespeech.s2t.modules.time_reduction import TimeReductionLayer1D
+from paddlespeech.s2t.modules.time_reduction import TimeReductionLayer2D
+from paddlespeech.s2t.modules.time_reduction import TimeReductionLayerStream
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()
diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py
index 51322d324..ef60bdf0a 100644
--- a/paddlespeech/s2t/modules/subsampling.py
+++ b/paddlespeech/s2t/modules/subsampling.py
@@ -17,15 +17,11 @@
 from typing import Tuple

 import paddle
-import paddle.nn.functional as F
 from paddle import nn

-from paddlespeech.s2t import masked_fill
-from paddlespeech.s2t.modules.align import Conv1D
 from paddlespeech.s2t.modules.align import Conv2D
 from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
-from paddlespeech.s2t.modules.conv2d import Conv2DValid
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
 from paddlespeech.s2t.utils.log import Log

@@ -33,8 +29,7 @@ logger = Log(__name__).getlog()

 __all__ = [
     "LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
-    "Conv2dSubsampling8", "TimeReductionLayerStream", "TimeReductionLayer1D",
"TimeReductionLayer1D", - "TimeReductionLayer2D", "DepthwiseConv2DSubsampling4" + "Conv2dSubsampling8", "DepthwiseConv2DSubsampling4" ] @@ -318,234 +313,3 @@ class DepthwiseConv2DSubsampling4(BaseSubsampling): x, pos_emb = self.pos_enc(x, offset) x = self.input_proj(x) return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] - - -class TimeReductionLayer1D(nn.Layer): - """ - Modified NeMo, - Squeezeformer Time Reduction procedure. - Downsamples the audio by `stride` in the time dimension. - Args: - channel (int): input dimension of - MultiheadAttentionMechanism and PositionwiseFeedForward - out_dim (int): Output dimension of the module. - kernel_size (int): Conv kernel size for - depthwise convolution in convolution module - stride (int): Downsampling factor in time dimension. - """ - - def __init__(self, - channel: int, - out_dim: int, - kernel_size: int=5, - stride: int=2): - super(TimeReductionLayer1D, self).__init__() - - self.channel = channel - self.out_dim = out_dim - self.kernel_size = kernel_size - self.stride = stride - self.padding = max(0, self.kernel_size - self.stride) - - self.dw_conv = Conv1D( - in_channels=channel, - out_channels=channel, - kernel_size=kernel_size, - stride=stride, - padding=self.padding, - groups=channel, ) - - self.pw_conv = Conv1D( - in_channels=channel, - out_channels=out_dim, - kernel_size=1, - stride=1, - padding=0, - groups=1, ) - - self.init_weights() - - def init_weights(self): - dw_max = self.kernel_size**-0.5 - pw_max = self.channel**-0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - xs, - xs_lens: paddle.Tensor, - mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), - mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), - dtype=paddle.bool), ): - xs = xs.transpose([0, 2, 1]) # [B, C, T] - xs = masked_fill(xs, mask_pad.equal(0), 0.0) - - xs = self.dw_conv(xs) - xs = self.pw_conv(xs) - - xs = xs.transpose([0, 2, 1]) # [B, T, C] - - B, T, D = xs.shape - mask = mask[:, ::self.stride, ::self.stride] - mask_pad = mask_pad[:, :, ::self.stride] - L = mask_pad.shape[-1] - # For JIT exporting, we remove F.pad operator. 
-        if L - T < 0:
-            xs = xs[:, :L - T, :]
-        else:
-            dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32)
-            xs = paddle.concat([xs, dummy_pad], axis=1)
-
-        xs_lens = (xs_lens + 1) // 2
-        return xs, xs_lens, mask, mask_pad
-
-
-class TimeReductionLayer2D(nn.Layer):
-    def __init__(self, kernel_size: int=5, stride: int=2, encoder_dim: int=256):
-        super(TimeReductionLayer2D, self).__init__()
-        self.encoder_dim = encoder_dim
-        self.kernel_size = kernel_size
-        self.dw_conv = Conv2DValid(
-            in_channels=encoder_dim,
-            out_channels=encoder_dim,
-            kernel_size=(kernel_size, 1),
-            stride=stride,
-            valid_trigy=True)
-        self.pw_conv = Conv2DValid(
-            in_channels=encoder_dim,
-            out_channels=encoder_dim,
-            kernel_size=1,
-            stride=1,
-            valid_trigx=False,
-            valid_trigy=False)
-
-        self.kernel_size = kernel_size
-        self.stride = stride
-        self.init_weights()
-
-    def init_weights(self):
-        dw_max = self.kernel_size**-0.5
-        pw_max = self.encoder_dim**-0.5
-        self.dw_conv._param_attr = paddle.nn.initializer.Uniform(
-            low=-dw_max, high=dw_max)
-        self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(
-            low=-dw_max, high=dw_max)
-        self.pw_conv._param_attr = paddle.nn.initializer.Uniform(
-            low=-pw_max, high=pw_max)
-        self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(
-            low=-pw_max, high=pw_max)
-
-    def forward(
-            self,
-            xs: paddle.Tensor,
-            xs_lens: paddle.Tensor,
-            mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
-            mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-        xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0)
-        xs = xs.unsqueeze(1)
-        padding1 = self.kernel_size - self.stride
-        xs = F.pad(
-            xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.)
-        xs = self.dw_conv(xs.transpose([0, 3, 2, 1]))
-        xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1)
-        tmp_length = xs.shape[1]
-        xs_lens = (xs_lens + 1) // 2
-        padding2 = max(0, (xs_lens.max() - tmp_length).item())
-        batch_size, hidden = xs.shape[0], xs.shape[-1]
-        dummy_pad = paddle.zeros(
-            [batch_size, padding2, hidden], dtype=paddle.float32)
-        xs = paddle.concat([xs, dummy_pad], axis=1)
-        mask = mask[:, ::2, ::2]
-        mask_pad = mask_pad[:, :, ::2]
-        return xs, xs_lens, mask, mask_pad
-
-
-class TimeReductionLayerStream(nn.Layer):
-    """
-    Squeezeformer Time Reduction procedure.
-    Downsamples the audio by `stride` in the time dimension.
-    Args:
-        channel (int): input dimension of
-            MultiheadAttentionMechanism and PositionwiseFeedForward
-        out_dim (int): Output dimension of the module.
-        kernel_size (int): Conv kernel size for
-            depthwise convolution in convolution module
-        stride (int): Downsampling factor in time dimension.
- """ - - def __init__(self, - channel: int, - out_dim: int, - kernel_size: int=1, - stride: int=2): - super(TimeReductionLayerStream, self).__init__() - - self.channel = channel - self.out_dim = out_dim - self.kernel_size = kernel_size - self.stride = stride - - self.dw_conv = Conv1D( - in_channels=channel, - out_channels=channel, - kernel_size=kernel_size, - stride=stride, - padding=0, - groups=channel) - - self.pw_conv = Conv1D( - in_channels=channel, - out_channels=out_dim, - kernel_size=1, - stride=1, - padding=0, - groups=1) - self.init_weights() - - def init_weights(self): - dw_max = self.kernel_size**-0.5 - pw_max = self.channel**-0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - xs, - xs_lens: paddle.Tensor, - mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), - mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)): - xs = xs.transpose([0, 2, 1]) # [B, C, T] - xs = masked_fill(xs, mask_pad.equal(0), 0.0) - - xs = self.dw_conv(xs) - xs = self.pw_conv(xs) - - xs = xs.transpose([0, 2, 1]) # [B, T, C] - - B, T, D = xs.shape - mask = mask[:, ::self.stride, ::self.stride] - mask_pad = mask_pad[:, :, ::self.stride] - L = mask_pad.shape[-1] - # For JIT exporting, we remove F.pad operator. - if L - T < 0: - xs = xs[:, :L - T, :] - else: - dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) - xs = paddle.concat([xs, dummy_pad], axis=1) - - xs_lens = (xs_lens + 1) // 2 - return xs, xs_lens, mask, mask_pad diff --git a/paddlespeech/s2t/modules/time_reduction.py b/paddlespeech/s2t/modules/time_reduction.py new file mode 100644 index 000000000..d3393f108 --- /dev/null +++ b/paddlespeech/s2t/modules/time_reduction.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from wenet(https://github.com/wenet-e2e/wenet) +"""Subsampling layer definition.""" +from typing import Tuple + +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.s2t import masked_fill +from paddlespeech.s2t.modules.align import Conv1D +from paddlespeech.s2t.modules.conv2d import Conv2DValid +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = [ + "TimeReductionLayerStream", "TimeReductionLayer1D", "TimeReductionLayer2D" +] + + +class TimeReductionLayer1D(nn.Layer): + """ + Modified NeMo, + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. 
+        kernel_size (int): Conv kernel size for
+            depthwise convolution in convolution module
+        stride (int): Downsampling factor in time dimension.
+    """
+
+    def __init__(self,
+                 channel: int,
+                 out_dim: int,
+                 kernel_size: int=5,
+                 stride: int=2):
+        super(TimeReductionLayer1D, self).__init__()
+
+        self.channel = channel
+        self.out_dim = out_dim
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = max(0, self.kernel_size - self.stride)
+
+        self.dw_conv = Conv1D(
+            in_channels=channel,
+            out_channels=channel,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.padding,
+            groups=channel, )
+
+        self.pw_conv = Conv1D(
+            in_channels=channel,
+            out_channels=out_dim,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1, )
+
+        self.init_weights()
+
+    def init_weights(self):
+        dw_max = self.kernel_size**-0.5
+        pw_max = self.channel**-0.5
+        self.dw_conv._param_attr = paddle.nn.initializer.Uniform(
+            low=-dw_max, high=dw_max)
+        self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(
+            low=-dw_max, high=dw_max)
+        self.pw_conv._param_attr = paddle.nn.initializer.Uniform(
+            low=-pw_max, high=pw_max)
+        self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(
+            low=-pw_max, high=pw_max)
+
+    def forward(
+            self,
+            xs,
+            xs_lens: paddle.Tensor,
+            mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
+            mask_pad: paddle.Tensor=paddle.ones((0, 0, 0),
+                                                dtype=paddle.bool), ):
+        xs = xs.transpose([0, 2, 1])  # [B, C, T]
+        xs = masked_fill(xs, mask_pad.equal(0), 0.0)
+
+        xs = self.dw_conv(xs)
+        xs = self.pw_conv(xs)
+
+        xs = xs.transpose([0, 2, 1])  # [B, T, C]
+
+        B, T, D = xs.shape
+        mask = mask[:, ::self.stride, ::self.stride]
+        mask_pad = mask_pad[:, :, ::self.stride]
+        L = mask_pad.shape[-1]
+        # For JIT exporting, we remove F.pad operator.
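+        # Instead, the conv output length T is aligned to the strided mask
+        # length L below: extra frames are trimmed when T > L, otherwise a
+        # zero-filled dummy pad is concatenated so xs and mask_pad stay in sync.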
+        if L - T < 0:
+            xs = xs[:, :L - T, :]
+        else:
+            dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32)
+            xs = paddle.concat([xs, dummy_pad], axis=1)
+
+        xs_lens = (xs_lens + 1) // 2
+        return xs, xs_lens, mask, mask_pad
+
+
+class TimeReductionLayer2D(nn.Layer):
+    def __init__(self, kernel_size: int=5, stride: int=2, encoder_dim: int=256):
+        super(TimeReductionLayer2D, self).__init__()
+        self.encoder_dim = encoder_dim
+        self.kernel_size = kernel_size
+        self.dw_conv = Conv2DValid(
+            in_channels=encoder_dim,
+            out_channels=encoder_dim,
+            kernel_size=(kernel_size, 1),
+            stride=stride,
+            valid_trigy=True)
+        self.pw_conv = Conv2DValid(
+            in_channels=encoder_dim,
+            out_channels=encoder_dim,
+            kernel_size=1,
+            stride=1,
+            valid_trigx=False,
+            valid_trigy=False)
+
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.init_weights()
+
+    def init_weights(self):
+        dw_max = self.kernel_size**-0.5
+        pw_max = self.encoder_dim**-0.5
+        self.dw_conv._param_attr = paddle.nn.initializer.Uniform(
+            low=-dw_max, high=dw_max)
+        self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(
+            low=-dw_max, high=dw_max)
+        self.pw_conv._param_attr = paddle.nn.initializer.Uniform(
+            low=-pw_max, high=pw_max)
+        self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(
+            low=-pw_max, high=pw_max)
+
+    def forward(
+            self,
+            xs: paddle.Tensor,
+            xs_lens: paddle.Tensor,
+            mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
+            mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0)
+        xs = xs.unsqueeze(1)
+        padding1 = self.kernel_size - self.stride
+        xs = F.pad(
+            xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.)
+        xs = self.dw_conv(xs.transpose([0, 3, 2, 1]))
+        xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1)
+        tmp_length = xs.shape[1]
+        xs_lens = (xs_lens + 1) // 2
+        padding2 = max(0, (xs_lens.max() - tmp_length).item())
+        batch_size, hidden = xs.shape[0], xs.shape[-1]
+        dummy_pad = paddle.zeros(
+            [batch_size, padding2, hidden], dtype=paddle.float32)
+        xs = paddle.concat([xs, dummy_pad], axis=1)
+        mask = mask[:, ::2, ::2]
+        mask_pad = mask_pad[:, :, ::2]
+        return xs, xs_lens, mask, mask_pad
+
+
+class TimeReductionLayerStream(nn.Layer):
+    """
+    Squeezeformer Time Reduction procedure.
+    Downsamples the audio by `stride` in the time dimension.
+    Args:
+        channel (int): input dimension of
+            MultiheadAttentionMechanism and PositionwiseFeedForward
+        out_dim (int): Output dimension of the module.
+        kernel_size (int): Conv kernel size for
+            depthwise convolution in convolution module
+        stride (int): Downsampling factor in time dimension.
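+
+        Note: with the default kernel_size of 1, each output frame depends
+            on a single input frame, so no cross-chunk convolution cache is
+            needed for streaming inference.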
+ """ + + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int=1, + stride: int=2): + super(TimeReductionLayerStream, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + + self.dw_conv = Conv1D( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=0, + groups=channel) + + self.pw_conv = Conv1D( + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + + def forward( + self, + xs, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad