From c1df5b7985eb8a28328e2b7b4685e984c77fcb2f Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Wed, 4 Jan 2023 16:37:18 +0800 Subject: [PATCH] merge classes as required, test=asr --- .../asr1/conf/chunk_squeezeformer.yaml | 5 +- paddlespeech/s2t/modules/attention.py | 194 ++++-------------- .../s2t/modules/conformer_convolution.py | 43 +++- paddlespeech/s2t/modules/convolution.py | 181 ---------------- paddlespeech/s2t/modules/encoder.py | 10 +- .../s2t/modules/positionwise_feed_forward.py | 52 +---- 6 files changed, 95 insertions(+), 390 deletions(-) delete mode 100644 paddlespeech/s2t/modules/convolution.py diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml index 691d90461..2533eacfc 100644 --- a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml +++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml @@ -4,7 +4,7 @@ cmvn_file: cmvn_file_type: "json" # encoder related -encoder: conformer +encoder: squeezeformer encoder_conf: encoder_dim: 256 # dimension of attention output_size: 256 # dimension of output @@ -21,7 +21,8 @@ encoder_conf: normalize_before: false activation_type: 'swish' pos_enc_layer_type: 'rel_pos' - time_reduction_layer_type: 'conv2d' + do_rel_shift: false + time_reduction_layer_type: 'stream' causal: true use_dynamic_chunk: true use_dynamic_left_chunk: false diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 6347bdb12..b2184dbc7 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -203,7 +203,10 @@ class MultiHeadedAttention(nn.Layer): class RelPositionMultiHeadedAttention(MultiHeadedAttention): """Multi-Head Attention layer with relative position encoding.""" - def __init__(self, n_head, n_feat, dropout_rate): + def __init__(self, n_head, n_feat, dropout_rate, + do_rel_shift=False, + adaptive_scale=False, + init_weights=False): """Construct an RelPositionMultiHeadedAttention object. Paper: https://arxiv.org/abs/1901.02860 Args: @@ -226,151 +229,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): pos_bias_v = self.create_parameter( (self.h, self.d_k), default_initializer=I.XavierUniform()) self.add_parameter('pos_bias_v', pos_bias_v) - - def rel_shift(self, x, zero_triu: bool=False): - """Compute relative positinal encoding. - Args: - x (paddle.Tensor): Input tensor (batch, head, time1, time1). - zero_triu (bool): If true, return the lower triangular part of - the matrix. - Returns: - paddle.Tensor: Output tensor. (batch, head, time1, time1) - """ - zero_pad = paddle.zeros( - (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) - x_padded = paddle.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, - x.shape[2]) - x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] - - if zero_triu: - ones = paddle.ones((x.shape[2], x.shape[3])) - x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :] - - return x - - def forward(self, - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), - pos_emb: paddle.Tensor=paddle.empty([0]), - cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - Args: - query (paddle.Tensor): Query tensor (#batch, time1, size). - key (paddle.Tensor): Key tensor (#batch, time2, size). 
- value (paddle.Tensor): Value tensor (#batch, time2, size). - mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2), (0, 0, 0) means fake mask. - pos_emb (paddle.Tensor): Positional embedding tensor - (#batch, time2, size). - cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2), - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - Returns: - paddle.Tensor: Output tensor (#batch, time1, d_model). - paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - """ - q, k, v = self.forward_qkv(query, key, value) - # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) - - # when export onnx model, for 1st chunk, we feed - # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) - # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). - # In all modes, `if cache.size(0) > 0` will alwayse be `True` - # and we will always do splitting and - # concatnation(this will simplify onnx export). Note that - # it's OK to concat & split zero-shaped tensors(see code below). - # when export jit model, for 1st chunk, we always feed - # cache(0, 0, 0, 0) since jit supports dynamic if-branch. - # >>> a = torch.ones((1, 2, 0, 4)) - # >>> b = torch.ones((1, 2, 3, 4)) - # >>> c = torch.cat((a, b), dim=2) - # >>> torch.equal(b, c) # True - # >>> d = torch.split(a, 2, dim=-1) - # >>> torch.equal(d[0], d[1]) # True - if cache.shape[0] > 0: - # last dim `d_k * 2` for (key, val) - key_cache, value_cache = paddle.split(cache, 2, axis=-1) - k = paddle.concat([key_cache, k], axis=2) - v = paddle.concat([value_cache, v], axis=2) - # We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = paddle.concat((k, v), axis=-1) - - n_batch_pos = pos_emb.shape[0] - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) - - # (batch, head, time1, d_k) - # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) - q_with_bias_u = q + self.pos_bias_u.unsqueeze(1) - # (batch, head, time1, d_k) - # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) - q_with_bias_v = q + self.pos_bias_v.unsqueeze(1) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) - matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True) - - # compute matrix b and matrix d - # (batch, head, time1, time2) - # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) - matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True) - # Remove rel_shift since it is useless in speech recognition, - # and it requires special attention for streaming. - # matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask), new_cache - - -class RelPositionMultiHeadedAttention2(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding. - Paper: https://arxiv.org/abs/1901.02860 - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. 
- """ - - def __init__(self, - n_head, - n_feat, - dropout_rate, - do_rel_shift=False, - adaptive_scale=False, - init_weights=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - # linear transformation for positional encoding - self.linear_pos = Linear(n_feat, n_feat) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 self.do_rel_shift = do_rel_shift - pos_bias_u = self.create_parameter( - [self.h, self.d_k], default_initializer=I.XavierUniform()) - self.add_parameter('pos_bias_u', pos_bias_u) - pos_bias_v = self.create_parameter( - [self.h, self.d_k], default_initializer=I.XavierUniform()) - self.add_parameter('pos_bias_v', pos_bias_v) self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter( - [1, 1, n_feat], default_initializer=I.Constant(1.0)) - self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter( - [1, 1, n_feat], default_initializer=I.Constant(0.0)) - self.add_parameter('ada_bias', ada_bias) + if self.adaptive_scale: + ada_scale = self.create_parameter( + [1, 1, n_feat], default_initializer=I.Constant(1.0)) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter( + [1, 1, n_feat], default_initializer=I.Constant(0.0)) + self.add_parameter('ada_bias', ada_bias) if init_weights: self.init_weights() @@ -407,12 +274,12 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention): paddle.Tensor: Output tensor. (batch, head, time1, time1) """ zero_pad = paddle.zeros( - [x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype) - x_padded = paddle.concat([zero_pad, x], axis=-1) + (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) + x_padded = paddle.cat([zero_pad, x], dim=-1) - x_padded = x_padded.reshape( - [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) - x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1] + x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, + x.shape[2]) + x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] if zero_triu: ones = paddle.ones((x.shape[2], x.shape[3])) @@ -424,10 +291,10 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention): query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, - mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), pos_emb: paddle.Tensor=paddle.empty([0]), - cache: paddle.Tensor=paddle.zeros( - (0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]: + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (paddle.Tensor): Query tensor (#batch, time1, size). @@ -452,17 +319,34 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention): value = self.ada_scale * value + self.ada_bias q, k, v = self.forward_qkv(query, key, value) + # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) + + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). 
+ # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True if cache.shape[0] > 0: + # last dim `d_k * 2` for (key, val) key_cache, value_cache = paddle.split(cache, 2, axis=-1) k = paddle.concat([key_cache, k], axis=2) v = paddle.concat([value_cache, v], axis=2) - # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. new_cache = paddle.concat((k, v), axis=-1) n_batch_pos = pos_emb.shape[0] - p = self.linear_pos(pos_emb).reshape( - [n_batch_pos, -1, self.h, self.d_k]) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) # (batch, head, time1, d_k) diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index 09d903eee..e4196e3d4 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -18,6 +18,7 @@ from typing import Tuple import paddle from paddle import nn +from paddle.nn import initializer as I from typeguard import check_argument_types from paddlespeech.s2t.modules.align import BatchNorm1D @@ -39,7 +40,9 @@ class ConvolutionModule(nn.Layer): activation: nn.Layer=nn.ReLU(), norm: str="batch_norm", causal: bool=False, - bias: bool=True): + bias: bool=True, + adaptive_scale: bool=False, + init_weights: bool=False): """Construct an ConvolutionModule object. Args: channels (int): The number of channels of conv layers. @@ -51,6 +54,19 @@ class ConvolutionModule(nn.Layer): """ assert check_argument_types() super().__init__() + self.bias = bias + self.channels = channels + self.kernel_size = kernel_size + self.adaptive_scale = adaptive_scale + if self.adaptive_scale: + ada_scale = self.create_parameter( + [1, 1, channels], default_initializer=I.Constant(1.0)) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter( + [1, 1, channels], default_initializer=I.Constant(0.0)) + self.add_parameter('ada_bias', ada_bias) + + self.pointwise_conv1 = Conv1D( channels, 2 * channels, @@ -105,6 +121,28 @@ class ConvolutionModule(nn.Layer): ) self.activation = activation + if init_weights: + self.init_weights() + + def init_weights(self): + pw_max = self.channels**-0.5 + dw_max = self.kernel_size**-0.5 + self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + if self.bias: + self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + if self.bias: + self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + if self.bias: + self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + def forward( self, x: paddle.Tensor, @@ -123,6 +161,9 @@ class ConvolutionModule(nn.Layer): paddle.Tensor: Output tensor (#batch, time, channels). 
paddle.Tensor: Output cache tensor (#batch, channels, time') """ + if self.adaptive_scale: + x = self.ada_scale * x + self.ada_bias + # exchange the temporal dimension and the feature dimension x = x.transpose([0, 2, 1]) # [B, C, T] diff --git a/paddlespeech/s2t/modules/convolution.py b/paddlespeech/s2t/modules/convolution.py deleted file mode 100644 index caaa98566..000000000 --- a/paddlespeech/s2t/modules/convolution.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import Tuple - -import paddle -from paddle import nn -from paddle.nn import initializer as I -from typeguard import check_argument_types - -__all__ = ['ConvolutionModule2'] - -from paddlespeech.s2t import masked_fill -from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm - - -class ConvolutionModule2(nn.Layer): - """ConvolutionModule in Conformer model.""" - - def __init__(self, - channels: int, - kernel_size: int=15, - activation: nn.Layer=nn.ReLU(), - norm: str="batch_norm", - causal: bool=False, - bias: bool=True, - adaptive_scale: bool=False, - init_weights: bool=False): - """Construct an ConvolutionModule object. - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernel size of conv layers. - causal (int): Whether use causal convolution or not - """ - assert check_argument_types() - super().__init__() - self.bias = bias - self.channels = channels - self.kernel_size = kernel_size - self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter( - [1, 1, channels], default_initializer=I.Constant(1.0)) - self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter( - [1, 1, channels], default_initializer=I.Constant(0.0)) - self.add_parameter('ada_bias', ada_bias) - - self.pointwise_conv1 = Conv1D( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias_attr=None - if bias else False, # None for True, using bias as default config - ) - - # self.lorder is used to distinguish if it's a causal convolution, - # if self.lorder > 0: it's a causal convolution, the input will be - # padded with self.lorder frames on the left in forward. 
- # else: it's a symmetrical convolution - if causal: - padding = 0 - self.lorder = kernel_size - 1 - else: - # kernel_size should be an odd number for none causal convolution - assert (kernel_size - 1) % 2 == 0 - padding = (kernel_size - 1) // 2 - self.lorder = 0 - self.depthwise_conv = Conv1D( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias_attr=None - if bias else False, # None for True, using bias as default config - ) - - assert norm in ['batch_norm', 'layer_norm'] - if norm == "batch_norm": - self.use_layer_norm = False - self.norm = BatchNorm1D(channels) - else: - self.use_layer_norm = True - self.norm = LayerNorm(channels) - - self.pointwise_conv2 = Conv1D( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias_attr=None - if bias else False, # None for True, using bias as default config - ) - self.activation = activation - - if init_weights: - self.init_weights() - - def init_weights(self): - pw_max = self.channels**-0.5 - dw_max = self.kernel_size**-0.5 - self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - if self.bias: - self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - if self.bias: - self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - if self.bias: - self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - x: paddle.Tensor, - mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), - cache: paddle.Tensor=paddle.zeros([0, 0, 0]), - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Compute convolution module. - Args: - x (torch.Tensor): Input tensor (#batch, time, channels). - mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), - (0, 0, 0) means fake mask. - cache (torch.Tensor): left context cache, it is only - used in causal convolution (#batch, channels, cache_t), - (0, 0, 0) meas fake cache. - Returns: - torch.Tensor: Output tensor (#batch, time, channels). - """ - if self.adaptive_scale: - x = self.ada_scale * x + self.ada_bias - - # exchange the temporal dimension and the feature dimension - x = x.transpose([0, 2, 1]) # [B, C, T] - - # mask batch padding - if mask_pad.shape[2] > 0: # time > 0 - x = masked_fill(x, mask_pad, 0.0) - - if self.lorder > 0: - if cache.shape[2] == 0: # cache_t == 0 - x = nn.functional.pad( - x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') - else: - assert cache.shape[0] == x.shape[0] # B - assert cache.shape[1] == x.shape[1] # C - x = paddle.concat((cache, x), axis=2) - - assert (x.shape[2] > self.lorder) - new_cache = x[:, :, -self.lorder:] # [B, C, T] - else: - # It's better we just return None if no cache is required, - # However, for JIT export, here we just fake one tensor instead of - # None. 
- new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channel, dim) - x = nn.functional.glu(x, axis=1) # (batch, channel, dim) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - if self.use_layer_norm: - x = x.transpose([0, 2, 1]) # [B, T, C] - x = self.activation(self.norm(x)) - if self.use_layer_norm: - x = x.transpose([0, 2, 1]) # [B, C, T] - x = self.pointwise_conv2(x) - - # mask batch padding - if mask_pad.shape[2] > 0: # time > 0 - x = masked_fill(x, mask_pad, 0.0) - - x = x.transpose([0, 2, 1]) # [B, T, C] - return x, new_cache diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index f19ecfe41..d133735b2 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -30,7 +30,6 @@ from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule -from paddlespeech.s2t.modules.convolution import ConvolutionModule2 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.modules.embedding import RelPositionalEncoding @@ -40,7 +39,6 @@ from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer from paddlespeech.s2t.modules.mask import add_optional_chunk_mask from paddlespeech.s2t.modules.mask import make_non_pad_mask from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward -from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward2 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8 @@ -591,19 +589,19 @@ class SqueezeformerEncoder(nn.Layer): encoder_selfattn_layer_args = (attention_heads, output_size, attention_dropout_rate) else: - encoder_selfattn_layer = RelPositionMultiHeadedAttention2 + encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, encoder_dim, attention_dropout_rate, do_rel_shift, adaptive_scale, init_weights) # feed-forward module definition - positionwise_layer = PositionwiseFeedForward2 + positionwise_layer = PositionwiseFeedForward positionwise_layer_args = ( encoder_dim, encoder_dim * feed_forward_expansion_factor, feed_forward_dropout_rate, activation, adaptive_scale, init_weights) # convolution module definition - convolution_layer = ConvolutionModule2 + convolution_layer = ConvolutionModule convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, cnn_norm_type, causal, True, adaptive_scale, init_weights) @@ -676,7 +674,7 @@ class SqueezeformerEncoder(nn.Layer): if self.global_cmvn is not None: xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = ~masks + mask_pad = masks chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index 39d8b1893..b5395f049 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ 
-23,7 +23,7 @@ from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["PositionwiseFeedForward", "PositionwiseFeedForward2"] +__all__ = ["PositionwiseFeedForward"] class PositionwiseFeedForward(nn.Layer): @@ -33,7 +33,9 @@ class PositionwiseFeedForward(nn.Layer): idim: int, hidden_units: int, dropout_rate: float, - activation: nn.Layer=nn.ReLU()): + activation: nn.Layer=nn.ReLU(), + adaptive_scale: bool=False, + init_weights: bool=False): """Construct a PositionwiseFeedForward object. FeedForward are appied on each position of the sequence. @@ -46,48 +48,11 @@ class PositionwiseFeedForward(nn.Layer): activation (paddle.nn.Layer): Activation function """ super().__init__() - self.w_1 = Linear(idim, hidden_units) - self.activation = activation - self.dropout = nn.Dropout(dropout_rate) - self.w_2 = Linear(hidden_units, idim) - - def forward(self, xs: paddle.Tensor) -> paddle.Tensor: - """Forward function. - Args: - xs: input tensor (B, Lmax, D) - Returns: - output tensor, (B, Lmax, D) - """ - return self.w_2(self.dropout(self.activation(self.w_1(xs)))) - - -class PositionwiseFeedForward2(paddle.nn.Layer): - """Positionwise feed forward layer. - - FeedForward are appied on each position of the sequence. - The output dim is same with the input dim. - - Args: - idim (int): Input dimenstion. - hidden_units (int): The number of hidden units. - dropout_rate (float): Dropout rate. - activation (paddle.nn.Layer): Activation function - """ - - def __init__(self, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: paddle.nn.Layer=paddle.nn.ReLU(), - adaptive_scale: bool=False, - init_weights: bool=False): - """Construct a PositionwiseFeedForward object.""" - super(PositionwiseFeedForward2, self).__init__() self.idim = idim self.hidden_units = hidden_units self.w_1 = Linear(idim, hidden_units) self.activation = activation - self.dropout = paddle.nn.Dropout(dropout_rate) + self.dropout = nn.Dropout(dropout_rate) self.w_2 = Linear(hidden_units, idim) self.adaptive_scale = adaptive_scale ada_scale = self.create_parameter( @@ -114,12 +79,9 @@ class PositionwiseFeedForward2(paddle.nn.Layer): def forward(self, xs: paddle.Tensor) -> paddle.Tensor: """Forward function. - Args: - xs: input tensor (B, L, D) + xs: input tensor (B, Lmax, D) Returns: - output tensor, (B, L, D) + output tensor, (B, Lmax, D) """ - if self.adaptive_scale: - xs = self.ada_scale * xs + self.ada_bias return self.w_2(self.dropout(self.activation(self.w_1(xs))))
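
Usage note (not part of the patch): the sketch below exercises the merged module classes using only the constructor and forward signatures visible in the hunks above. The tensor sizes, variable names, and specific keyword values are illustrative assumptions, and the snippet assumes an environment where the patched paddlespeech package is importable.

    import paddle

    from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
    from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
    from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward

    # Illustrative sizes: batch 2, 10 frames, 256-dim features, 4 heads.
    B, T, D, H = 2, 10, 256, 4
    xs = paddle.randn([B, T, D])
    pos_emb = paddle.randn([1, T, D])  # relative positional embedding, (1, time, size)

    # Squeezeformer-style arguments now exposed on the single attention class;
    # do_rel_shift=False mirrors the chunk_squeezeformer.yaml change above.
    attn = RelPositionMultiHeadedAttention(
        n_head=H, n_feat=D, dropout_rate=0.1,
        do_rel_shift=False, adaptive_scale=True)
    out, attn_cache = attn(xs, xs, xs, pos_emb=pos_emb)

    # Feed-forward and convolution modules take the same two extra flags;
    # leaving them at their defaults keeps the original Conformer behaviour.
    ffn = PositionwiseFeedForward(idim=D, hidden_units=4 * D, dropout_rate=0.1,
                                  adaptive_scale=True)
    ys = ffn(out)

    conv = ConvolutionModule(channels=D, kernel_size=15, norm='batch_norm',
                             causal=True, adaptive_scale=True)
    zs, conv_cache = conv(ys)

    print(out.shape, ys.shape, zs.shape)  # each [2, 10, 256]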
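
The comment kept in the merged attention forward(), "last dim `d_k * 2` for (key, val)", describes how the streaming cache packs keys and values into a single tensor. The standalone sketch below, with made-up sizes, walks through that bookkeeping with plain Paddle ops; it illustrates the cache layout and is not code taken from the patch.

    import paddle

    # Hypothetical sizes: 4 heads, d_k = 64, 16 cached frames, an 8-frame chunk.
    head, d_k, cache_t, chunk_t = 4, 64, 16, 8

    cache = paddle.randn([1, head, cache_t, d_k * 2])  # packed (key | value) cache
    k_new = paddle.randn([1, head, chunk_t, d_k])      # keys of the current chunk
    v_new = paddle.randn([1, head, chunk_t, d_k])      # values of the current chunk

    # Unpack the cache: the last dim holds d_k key features then d_k value features.
    key_cache, value_cache = paddle.split(cache, 2, axis=-1)

    # Prepend the cached left context along the time axis, as forward() does.
    k = paddle.concat([key_cache, k_new], axis=2)
    v = paddle.concat([value_cache, v_new], axis=2)

    # Re-pack keys and values so the next chunk receives a single cache tensor.
    new_cache = paddle.concat([k, v], axis=-1)
    print(new_cache.shape)  # [1, 4, 24, 128]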