diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml new file mode 100644 index 000000000..691d90461 --- /dev/null +++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml @@ -0,0 +1,98 @@ +############################################ +# Network Architecture # +############################################ +cmvn_file: +cmvn_file_type: "json" +# encoder related +encoder: conformer +encoder_conf: + encoder_dim: 256 # dimension of attention + output_size: 256 # dimension of output + attention_heads: 4 + num_blocks: 12 # the number of encoder blocks + reduce_idx: 5 + recover_idx: 11 + feed_forward_expansion_factor: 4 + input_dropout_rate: 0.1 + feed_forward_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + adaptive_scale: true + cnn_module_kernel: 31 + normalize_before: false + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + time_reduction_layer_type: 'conv2d' + causal: true + use_dynamic_chunk: true + use_dynamic_left_chunk: false + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 # sublayer output dropout + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + init_type: 'kaiming_uniform' # !Warning: need to convergence + +########################################### +# Data # +########################################### + +train_manifest: data/manifest.train +dev_manifest: data/manifest.dev +test_manifest: data/manifest.test + + +########################################### +# Dataloader # +########################################### + +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml +feat_dim: 80 +stride_ms: 10.0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 32 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 2 +subsampling_factor: 1 +num_encs: 1 + +########################################### +# Training # +########################################### +n_epoch: 240 +accum_grad: 1 +global_grad_clip: 5.0 +dist_sampler: True +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 1.0e-6 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +log_interval: 100 +checkpoint: + kbest_n: 50 + latest_n: 5 diff --git a/examples/aishell/asr1/conf/squeezeformer.yaml b/examples/aishell/asr1/conf/squeezeformer.yaml new file mode 100644 index 000000000..db8ef7c2d --- /dev/null +++ b/examples/aishell/asr1/conf/squeezeformer.yaml @@ -0,0 +1,93 @@ +############################################ +# Network Architecture # +############################################ +cmvn_file: +cmvn_file_type: "json" +# encoder related +encoder: squeezeformer +encoder_conf: + encoder_dim: 256 # dimension of attention + output_size: 256 # dimension of output + attention_heads: 4 + num_blocks: 12 # the number of encoder blocks + reduce_idx: 5 + recover_idx: 11 + feed_forward_expansion_factor: 4 + input_dropout_rate: 0.1 + feed_forward_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + adaptive_scale: true + cnn_module_kernel: 31 + normalize_before: false + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + time_reduction_layer_type: 'conv2d' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + init_type: 'kaiming_uniform' # !Warning: need to convergence + +########################################### +# Data # +########################################### +train_manifest: data/manifest.train +dev_manifest: data/manifest.dev +test_manifest: data/manifest.test + +########################################### +# Dataloader # +########################################### +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml +feat_dim: 80 +stride_ms: 10.0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 32 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 2 +subsampling_factor: 1 +num_encs: 1 + +########################################### +# Training # +########################################### +n_epoch: 150 +accum_grad: 8 +global_grad_clip: 5.0 +dist_sampler: False +optim: adam +optim_conf: + lr: 0.002 + weight_decay: 1.0e-6 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +log_interval: 100 +checkpoint: + kbest_n: 50 + latest_n: 5 diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 544c1e836..7e72a01fd 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -42,7 +42,7 @@ from paddlespeech.s2t.modules.cmvn import GlobalCMVN from paddlespeech.s2t.modules.ctc import CTCDecoderBase from paddlespeech.s2t.modules.decoder import BiTransformerDecoder from paddlespeech.s2t.modules.decoder import TransformerDecoder -from paddlespeech.s2t.modules.encoder import ConformerEncoder +from paddlespeech.s2t.modules.encoder import ConformerEncoder, SqueezeformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder from paddlespeech.s2t.modules.initializer import DefaultInitializerContext from paddlespeech.s2t.modules.loss import LabelSmoothingLoss @@ -905,6 +905,9 @@ class U2Model(U2DecodeModel): elif encoder_type == 'conformer': encoder = ConformerEncoder( input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'squeezeformer': + encoder = SqueezeformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) else: raise ValueError(f"not support encoder type:{encoder_type}") diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index d9568dcc9..e149c5041 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -330,3 +330,136 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): self.d_k) # (batch, head, time1, time2) return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention2(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, n_head, n_feat, dropout_rate, do_rel_shift=False, adaptive_scale=False, init_weights=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = Linear(n_feat, n_feat) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.do_rel_shift = do_rel_shift + pos_bias_u = self.create_parameter([self.h, self.d_k], default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_u', pos_bias_u) + pos_bias_v = self.create_parameter([self.h, self.d_k], default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_v', pos_bias_v) + self.adaptive_scale = adaptive_scale + ada_scale = self.create_parameter([1, 1, n_feat], default_initializer=I.Constant(1.0)) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter([1, 1, n_feat], default_initializer=I.Constant(0.0)) + self.add_parameter('ada_bias', ada_bias) + if init_weights: + self.init_weights() + + def init_weights(self): + input_max = (self.h * self.d_k) ** -0.5 + self.linear_q._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_q._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_k._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_k._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_v._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_v._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_pos._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_pos._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_out._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + self.linear_out._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) + + def rel_shift(self, x, zero_triu: bool = False): + """Compute relative positinal encoding. + Args: + x (paddle.Tensor): Input tensor (batch, head, time1, time1). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + paddle.Tensor: Output tensor. (batch, head, time1, time1) + """ + zero_pad = paddle.zeros([x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype) + x_padded = paddle.concat([zero_pad, x], axis=-1) + + x_padded = x_padded.reshape([x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) + x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1] + + if zero_triu: + ones = paddle.ones((x.shape[2], x.shape[3])) + x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :] + + return x + + def forward(self, query: paddle.Tensor, + key: paddle.Tensor, value: paddle.Tensor, + mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + pos_emb: paddle.Tensor = paddle.empty([0]), + cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (paddle.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + paddle.Tensor: Output tensor (#batch, time1, d_model). + paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + if self.adaptive_scale: + query = self.ada_scale * query + self.ada_bias + key = self.ada_scale * key + self.ada_bias + value = self.ada_scale * value + self.ada_bias + + q, k, v = self.forward_qkv(query, key, value) + if cache.shape[0] > 0: + key_cache, value_cache = paddle.split(cache, 2, axis=-1) + k = paddle.concat([key_cache, k], axis=2) + v = paddle.concat([value_cache, v], axis=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = paddle.concat((k, v), axis=-1) + + n_batch_pos = pos_emb.shape[0] + p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k]) + p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) + q_with_bias_u = q + self.pos_bias_u.unsqueeze(1) + # (batch, head, time1, d_k) + # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) + q_with_bias_v = q + self.pos_bias_v.unsqueeze(1) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) + matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) + matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. + if self.do_rel_shift: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/paddlespeech/s2t/modules/conv2d.py b/paddlespeech/s2t/modules/conv2d.py new file mode 100644 index 000000000..4b41d80a4 --- /dev/null +++ b/paddlespeech/s2t/modules/conv2d.py @@ -0,0 +1,56 @@ +from typing import Union, Optional + +import paddle +import paddle.nn.functional as F +from paddle.nn.layer.conv import _ConvNd + +__all__ = ['Conv2DValid'] + + +class Conv2DValid(_ConvNd): + """ + Conv2d operator for VALID mode padding. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + padding_mode: str = 'zeros', + weight_attr=None, + bias_attr=None, + data_format="NCHW", + valid_trigx: bool = False, + valid_trigy: bool = False + ) -> None: + super(Conv2DValid, self).__init__(in_channels, + out_channels, + kernel_size, + False, + 2, + stride=stride, + padding=padding, + padding_mode=padding_mode, + dilation=dilation, + groups=groups, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format) + self.valid_trigx = valid_trigx + self.valid_trigy = valid_trigy + + def _conv_forward(self, input: paddle.Tensor, weight: paddle.Tensor, bias: Optional[paddle.Tensor]): + validx, validy = 0, 0 + if self.valid_trigx: + validx = (input.shape[-2] * (self._stride[-2] - 1) - 1 + self._kernel_size[-2]) // 2 + if self.valid_trigy: + validy = (input.shape[-1] * (self._stride[-1] - 1) - 1 + self._kernel_size[-1]) // 2 + return F.conv2d(input, weight, bias, self._stride, (validx, validy), self._dilation, self._groups) + + def forward(self, input: paddle.Tensor) -> paddle.Tensor: + return self._conv_forward(input, self.weight, self.bias) diff --git a/paddlespeech/s2t/modules/convolution.py b/paddlespeech/s2t/modules/convolution.py new file mode 100644 index 000000000..0b5ab0356 --- /dev/null +++ b/paddlespeech/s2t/modules/convolution.py @@ -0,0 +1,172 @@ +from typing import Tuple + +import paddle +from paddle import nn +from paddle.nn import initializer as I +from typeguard import check_argument_types + +__all__ = ['ConvolutionModule'] + +from paddlespeech.s2t import masked_fill +from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm + + +class ConvolutionModule2(nn.Layer): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Layer = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + adaptive_scale: bool = False, + init_weights: bool = False): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + assert check_argument_types() + super().__init__() + self.bias = bias + self.channels = channels + self.kernel_size = kernel_size + self.adaptive_scale = adaptive_scale + ada_scale = self.create_parameter([1, 1, channels], default_initializer=I.Constant(1.0)) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter([1, 1, channels], default_initializer=I.Constant(0.0)) + self.add_parameter('ada_bias', ada_bias) + + self.pointwise_conv1 = Conv1D( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias_attr=None + if bias else False, # None for True, using bias as default config + ) + + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = Conv1D( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias_attr=None + if bias else False, # None for True, using bias as default config + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = BatchNorm1D(channels) + else: + self.use_layer_norm = True + self.norm = LayerNorm(channels) + + self.pointwise_conv2 = Conv1D( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias_attr=None + if bias else False, # None for True, using bias as default config + ) + self.activation = activation + + if init_weights: + self.init_weights() + + def init_weights(self): + pw_max = self.channels ** -0.5 + dw_max = self.kernel_size ** -0.5 + self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + if self.bias: + self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + if self.bias: + self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + if self.bias: + self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward( + self, + x: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + cache: paddle.Tensor = paddle.zeros([0, 0, 0]), + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + if self.adaptive_scale: + x = self.ada_scale * x + self.ada_bias + + # exchange the temporal dimension and the feature dimension + x = x.transpose([0, 2, 1]) # [B, C, T] + + # mask batch padding + if mask_pad.shape[2] > 0: # time > 0 + x = masked_fill(x, mask_pad, 0.0) + + if self.lorder > 0: + if cache.shape[2] == 0: # cache_t == 0 + x = nn.functional.pad(x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') + else: + assert cache.shape[0] == x.shape[0] # B + assert cache.shape[1] == x.shape[1] # C + x = paddle.concat((cache, x), axis=2) + + assert (x.shape[2] > self.lorder) + new_cache = x[:, :, -self.lorder:] # [B, C, T] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, axis=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose([0, 2, 1]) # [B, T, C] + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose([0, 2, 1]) # [B, C, T] + x = self.pointwise_conv2(x) + + # mask batch padding + if mask_pad.shape[2] > 0: # time > 0 + x = masked_fill(x, mask_pad, 0.0) + + x = x.transpose([0, 2, 1]) # [B, T, C] + return x, new_cache diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index fd7bd7b9a..654c1607a 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -14,26 +14,28 @@ # limitations under the License. # Modified from wenet(https://github.com/wenet-e2e/wenet) """Encoder definition.""" -from typing import Tuple +from typing import Tuple, Union, Optional, List import paddle from paddle import nn from typeguard import check_argument_types from paddlespeech.s2t.modules.activation import get_activation -from paddlespeech.s2t.modules.align import LayerNorm -from paddlespeech.s2t.modules.attention import MultiHeadedAttention +from paddlespeech.s2t.modules.align import LayerNorm, Linear +from paddlespeech.s2t.modules.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention2 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule +from paddlespeech.s2t.modules.convolution import ConvolutionModule2 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.modules.embedding import RelPositionalEncoding -from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer +from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer, SqueezeformerEncoderLayer from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer from paddlespeech.s2t.modules.mask import add_optional_chunk_mask from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward -from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 +from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward, PositionwiseFeedForward2 +from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4, TimeReductionLayerStream, TimeReductionLayer1D, \ + DepthwiseConv2DSubsampling4, TimeReductionLayer2D from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8 from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling @@ -487,3 +489,372 @@ class ConformerEncoder(BaseEncoder): normalize_before=normalize_before, concat_after=concat_after) for _ in range(num_blocks) ]) + + +class SqueezeformerEncoder(nn.Layer): + def __init__( + self, + input_size: int, + encoder_dim: int = 256, + output_size: int = 256, + attention_heads: int = 4, + num_blocks: int = 12, + reduce_idx: Optional[Union[int, List[int]]] = 5, + recover_idx: Optional[Union[int, List[int]]] = 11, + feed_forward_expansion_factor: int = 4, + dw_stride: bool = False, + input_dropout_rate: float = 0.1, + pos_enc_layer_type: str = "rel_pos", + time_reduction_layer_type: str = "conv1d", + do_rel_shift: bool = True, + feed_forward_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + cnn_module_kernel: int = 31, + cnn_norm_type: str = "layer_norm", + dropout: float = 0.1, + causal: bool = False, + adaptive_scale: bool = True, + activation_type: str = "swish", + init_weights: bool = True, + global_cmvn: paddle.nn.Layer = None, + normalize_before: bool = False, + use_dynamic_chunk: bool = False, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_left_chunk: bool = False + ): + """Construct SqueezeformerEncoder + + Args: + input_size to use_dynamic_chunk, see in Transformer BaseEncoder. + encoder_dim (int): The hidden dimension of encoder layer. + output_size (int): The output dimension of final projection layer. + attention_heads (int): Num of attention head in attention module. + num_blocks (int): Num of encoder layers. + reduce_idx Optional[Union[int, List[int]]]: + reduce layer index, from 40ms to 80ms per frame. + recover_idx Optional[Union[int, List[int]]]: + recover layer index, from 80ms to 40ms per frame. + feed_forward_expansion_factor (int): Enlarge coefficient of FFN. + dw_stride (bool): Whether do depthwise convolution + on subsampling module. + input_dropout_rate (float): Dropout rate of input projection layer. + pos_enc_layer_type (str): Self attention type. + time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. + do_rel_shift (bool): Whether to do relative shift + operation on rel-attention module. + cnn_module_kernel (int): Kernel size of CNN module. + activation_type (str): Encoder activation function type. + cnn_module_kernel (int): Kernel size of convolution module. + adaptive_scale (bool): Whether to use adaptive scale. + init_weights (bool): Whether to initialize weights. + causal (bool): whether to use causal convolution or not. + """ + assert check_argument_types() + super().__init__() + self.global_cmvn = global_cmvn + self.reduce_idx: Optional[Union[int, List[int]]] = [reduce_idx] \ + if type(reduce_idx) == int else reduce_idx + self.recover_idx: Optional[Union[int, List[int]]] = [recover_idx] \ + if type(recover_idx) == int else recover_idx + self.check_ascending_list() + if reduce_idx is None: + self.time_reduce = None + else: + if recover_idx is None: + self.time_reduce = 'normal' # no recovery at the end + else: + self.time_reduce = 'recover' # recovery at the end + assert len(self.reduce_idx) == len(self.recover_idx) + self.reduce_stride = 2 + self._output_size = output_size + self.normalize_before = normalize_before + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + activation = get_activation(activation_type) + + # self-attention module definition + if pos_enc_layer_type != "rel_pos": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, + output_size, + attention_dropout_rate) + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention2 + encoder_selfattn_layer_args = (attention_heads, + encoder_dim, + attention_dropout_rate, + do_rel_shift, + adaptive_scale, + init_weights) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward2 + positionwise_layer_args = (encoder_dim, + encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, + activation, + adaptive_scale, + init_weights) + + # convolution module definition + convolution_layer = ConvolutionModule2 + convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, True, adaptive_scale, init_weights) + + self.embed = DepthwiseConv2DSubsampling4(1, encoder_dim, + RelPositionalEncoding(encoder_dim, dropout_rate=0.1), + dw_stride, + input_size, + input_dropout_rate, + init_weights) + + self.preln = LayerNorm(encoder_dim) + self.encoders = paddle.nn.LayerList([SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after) for _ in range(num_blocks) + ]) + if time_reduction_layer_type == 'conv1d': + time_reduction_layer = TimeReductionLayer1D + time_reduction_layer_args = { + 'channel': encoder_dim, + 'out_dim': encoder_dim, + } + elif time_reduction_layer_type == 'stream': + time_reduction_layer = TimeReductionLayerStream + time_reduction_layer_args = { + 'channel': encoder_dim, + 'out_dim': encoder_dim, + } + else: + time_reduction_layer = TimeReductionLayer2D + time_reduction_layer_args = {'encoder_dim': encoder_dim} + + self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) + self.time_recover_layer = Linear(encoder_dim, encoder_dim) + self.final_proj = None + if output_size != encoder_dim: + self.final_proj = Linear(encoder_dim, output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: paddle.Tensor, + xs_lens: paddle.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Embed positions in tensor. + Args: + xs: padded input tensor (B, L, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor, lens and mask + """ + masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L) + + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = ~masks + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + xs_lens = chunk_masks.squeeze(1).sum(1) + xs = self.preln(xs) + recover_activations: \ + List[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]] = [] + index = 0 + for i, layer in enumerate(self.encoders): + if self.reduce_idx is not None: + if self.time_reduce is not None and i in self.reduce_idx: + recover_activations.append((xs, chunk_masks, pos_emb, mask_pad)) + xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) + pos_emb = pos_emb[:, ::2, :] + index += 1 + + if self.recover_idx is not None: + if self.time_reduce == 'recover' and i in self.recover_idx: + index -= 1 + recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[index] + # recover output length for ctc decode + xs = paddle.repeat_interleave(xs, repeats=2, axis=1) + xs = self.time_recover_layer(xs) + recoverd_t = recover_tensor.shape[1] + xs = recover_tensor + xs[:, :recoverd_t, :] + chunk_masks = recover_chunk_masks + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + + if self.final_proj is not None: + xs = self.final_proj(xs) + return xs, masks + + def check_ascending_list(self): + if self.reduce_idx is not None: + assert self.reduce_idx == sorted(self.reduce_idx), \ + "reduce_idx should be int or ascending list" + if self.recover_idx is not None: + assert self.recover_idx == sorted(self.recover_idx), \ + "recover_idx should be int or ascending list" + + def calculate_downsampling_factor(self, i: int) -> int: + if self.reduce_idx is None: + return 1 + else: + reduce_exp, recover_exp = 0, 0 + for exp, rd_idx in enumerate(self.reduce_idx): + if i >= rd_idx: + reduce_exp = exp + 1 + if self.recover_idx is not None: + for exp, rc_idx in enumerate(self.recover_idx): + if i >= rc_idx: + recover_exp = exp + 1 + return int(2 ** (reduce_exp - recover_exp)) + + def forward_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + att_mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ Forward just one chunk + + Args: + xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (paddle.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + paddle.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + paddle.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + paddle.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + """ + assert xs.shape[0] == 1 # batch size must be one + + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + # tmp_masks is just for interface compatibility, [B=1, C=1, T] + tmp_masks = paddle.ones([1, 1, xs.shape[1]], dtype=paddle.bool) + # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) + + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.shape[0], att_cache.shape[2] + chunk_size = xs.shape[1] + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + + r_att_cache = [] + r_cnn_cache = [] + + mask_pad = paddle.ones([1, xs.shape[1]], dtype=paddle.bool) + mask_pad = mask_pad.unsqueeze(1) + max_att_len: int = 0 + recover_activations: \ + List[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]] = [] + index = 0 + xs_lens = paddle.to_tensor([xs.shape[1]], dtype=paddle.int32) + xs = self.preln(xs) + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + if self.reduce_idx is not None: + if self.time_reduce is not None and i in self.reduce_idx: + recover_activations.append((xs, att_mask, pos_emb, mask_pad)) + xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(xs, xs_lens, att_mask, mask_pad) + pos_emb = pos_emb[:, ::2, :] + index += 1 + + if self.recover_idx is not None: + if self.time_reduce == 'recover' and i in self.recover_idx: + index -= 1 + recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[index] + # recover output length for ctc decode + xs = paddle.repeat_interleave(xs, repeats=2, axis=1) + xs = self.time_recover_layer(xs) + recoverd_t = recover_tensor.shape[1] + xs = recover_tensor + xs[:, :recoverd_t, :] + att_mask = recover_att_mask + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + factor = self.calculate_downsampling_factor(i) + att_cache1 = att_cache[i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[1], :] + cnn_cache1 = cnn_cache[i] if cnn_cache.shape[0] > 0 else cnn_cache + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + att_cache=att_cache1, + cnn_cache=cnn_cache1) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + cached_att = new_att_cache[:, :, next_cache_start // factor:, :] + cached_cnn = new_cnn_cache.unsqueeze(0) + cached_att = cached_att.repeat_interleave(repeats=factor, axis=2) + if i == 0: + # record length for the first block as max length + max_att_len = cached_att.shape[2] + r_att_cache.append(cached_att[:, :, :max_att_len, :]) + r_cnn_cache.append(cached_cnn) + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = paddle.concat(r_att_cache, axis=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) + + if self.final_proj is not None: + xs = self.final_proj(xs) + return xs, r_att_cache, r_cnn_cache diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index dac62bce3..971afcd8e 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -276,3 +276,126 @@ class ConformerEncoderLayer(nn.Layer): x = self.norm_final(x) return x, mask, new_att_cache, new_cnn_cache + + +class SqueezeformerEncoderLayer(nn.Layer): + """Encoder layer module.""" + + def __init__( + self, + size: int, + self_attn: paddle.nn.Layer, + feed_forward1: Optional[nn.Layer] = None, + conv_module: Optional[nn.Layer] = None, + feed_forward2: Optional[nn.Layer] = None, + normalize_before: bool = False, + dropout_rate: float = 0.1, + concat_after: bool = False): + """Construct an EncoderLayer object. + + Args: + size (int): Input dimension. + self_attn (paddle.nn.Layer): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward1 (paddle.nn.Layer): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (paddle.nn.Layer): Convolution module instance. + `ConvlutionLayer` instance can be used as the argument. + feed_forward2 (paddle.nn.Layer): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. + """ + super().__init__() + self.size = size + self.self_attn = self_attn + self.layer_norm1 = LayerNorm(size) + self.ffn1 = feed_forward1 + self.layer_norm2 = LayerNorm(size) + self.conv_module = conv_module + self.layer_norm3 = LayerNorm(size) + self.ffn2 = feed_forward2 + self.layer_norm4 = LayerNorm(size) + self.normalize_before = normalize_before + self.dropout = nn.Dropout(dropout_rate) + self.concat_after = concat_after + if concat_after: + self.concat_linear = Linear(size + size, size) + else: + self.concat_linear = nn.Identity() + + def forward( + self, + x: paddle.Tensor, + mask: paddle.Tensor, + pos_emb: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Compute encoded features. + Args: + x (paddle.Tensor): Input tensor (#batch, time, size). + mask (paddle.Tensor): Mask tensor for the input (#batch, time, time). + (0,0,0) means fake mask. + pos_emb (paddle.Tensor): postional encoding, must not be None + for ConformerEncoderLayer + mask_pad (paddle.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (paddle.Tensor): Convolution cache in conformer layer + (1, #batch=1, size, cache_t2). First dim will not be used, just + for dy2st. + Returns: + paddle.Tensor: Output tensor (#batch, time, size). + paddle.Tensor: Mask tensor (#batch, time, time). + paddle.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + paddle.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + # self attention module + residual = x + if self.normalize_before: + x = self.layer_norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) + if self.concat_after: + x_concat = paddle.concat((x, x_att), axis=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.layer_norm1(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm2(x) + x = self.ffn1(x) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm2(x) + + # conv module + residual = x + if self.normalize_before: + x = self.layer_norm3(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm3(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm4(x) + x = self.ffn2(x) + # we do not use dropout here since it is inside feed forward function + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm4(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index c2725dc5c..336199a58 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -17,6 +17,7 @@ import paddle from paddle import nn +from paddle.nn import initializer as I from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log @@ -32,7 +33,7 @@ class PositionwiseFeedForward(nn.Layer): idim: int, hidden_units: int, dropout_rate: float, - activation: nn.Layer=nn.ReLU()): + activation: nn.Layer = nn.ReLU()): """Construct a PositionwiseFeedForward object. FeedForward are appied on each position of the sequence. @@ -58,3 +59,61 @@ class PositionwiseFeedForward(nn.Layer): output tensor, (B, Lmax, D) """ return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class PositionwiseFeedForward2(paddle.nn.Layer): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (paddle.nn.Layer): Activation function + """ + + def __init__(self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: paddle.nn.Layer = paddle.nn.ReLU(), + adaptive_scale: bool = False, + init_weights: bool = False): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward2, self).__init__() + self.idim = idim + self.hidden_units = hidden_units + self.w_1 = Linear(idim, hidden_units) + self.activation = activation + self.dropout = paddle.nn.Dropout(dropout_rate) + self.w_2 = Linear(hidden_units, idim) + self.adaptive_scale = adaptive_scale + ada_scale = self.create_parameter([1, 1, idim], default_initializer=I.XavierUniform()) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter([1, 1, idim], default_initializer=I.XavierUniform()) + self.add_parameter('ada_bias', ada_bias) + + if init_weights: + self.init_weights() + + def init_weights(self): + ffn1_max = self.idim ** -0.5 + ffn2_max = self.hidden_units ** -0.5 + self.w_1._param_attr = paddle.nn.initializer.Uniform(low=-ffn1_max, high=ffn1_max) + self.w_1._bias_attr = paddle.nn.initializer.Uniform(low=-ffn1_max, high=ffn1_max) + self.w_2._param_attr = paddle.nn.initializer.Uniform(low=-ffn2_max, high=ffn2_max) + self.w_2._bias_attr = paddle.nn.initializer.Uniform(low=-ffn2_max, high=ffn2_max) + + def forward(self, xs: paddle.Tensor) -> paddle.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + if self.adaptive_scale: + xs = self.ada_scale * xs + self.ada_bias + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 782a437ee..09f92acca 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -17,11 +17,14 @@ from typing import Tuple import paddle +import paddle.nn.functional as F from paddle import nn -from paddlespeech.s2t.modules.align import Conv2D +from paddlespeech.s2t import masked_fill +from paddlespeech.s2t.modules.align import Conv2D, Conv1D from paddlespeech.s2t.modules.align import LayerNorm from paddlespeech.s2t.modules.align import Linear +from paddlespeech.s2t.modules.conv2d import Conv2DValid from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.utils.log import Log @@ -249,3 +252,257 @@ class Conv2dSubsampling8(Conv2dSubsampling): x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] + + +class DepthwiseConv2DSubsampling4(BaseSubsampling): + """Depthwise Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + pos_enc_class (nn.Layer): position encoding class. + dw_stride (int): Whether do depthwise convolution. + input_size (int): filter bank dimension. + + """ + + def __init__( + self, idim: int, odim: int, + pos_enc_class: nn.Layer, + dw_stride: bool = False, + input_size: int = 80, + input_dropout_rate: float = 0.1, + init_weights: bool = True): + super(DepthwiseConv2DSubsampling4, self).__init__() + self.idim = idim + self.odim = odim + self.pw_conv = Conv2D(in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.act1 = nn.ReLU() + self.dw_conv = Conv2D(in_channels=odim, out_channels=odim, kernel_size=3, stride=2, + groups=odim if dw_stride else 1) + self.act2 = nn.ReLU() + self.pos_enc = pos_enc_class + self.input_proj = nn.Sequential( + Linear(odim * (((input_size - 1) // 2 - 1) // 2), odim), + nn.Dropout(p=input_dropout_rate)) + if init_weights: + linear_max = (odim * input_size / 4) ** -0.5 + self.input_proj.state_dict()['0.weight'] = paddle.nn.initializer.Uniform(low=-linear_max, high=linear_max) + self.input_proj.state_dict()['0.bias'] = paddle.nn.initializer.Uniform(low=-linear_max, high=linear_max) + + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + offset: int = 0 + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.pw_conv(x) + x = self.act1(x) + x = self.dw_conv(x) + x = self.act2(x) + b, c, t, f = x.shape + x = x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]) + x, pos_emb = self.pos_enc(x, offset) + x = self.input_proj(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + +class TimeReductionLayer1D(nn.Layer): + """ + Modified NeMo, + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. + """ + + def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2): + super(TimeReductionLayer1D, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + self.padding = max(0, self.kernel_size - self.stride) + + self.dw_conv = Conv1D( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + groups=channel, + ) + + self.pw_conv = Conv1D( + in_channels=channel, out_channels=out_dim, + kernel_size=1, stride=1, padding=0, groups=1, + ) + + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.channel ** -0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward(self, xs, xs_lens: paddle.Tensor, + mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + ): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayer2D(nn.Layer): + def __init__(self, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): + super(TimeReductionLayer2D, self).__init__() + self.encoder_dim = encoder_dim + self.kernel_size = kernel_size + self.dw_conv = Conv2DValid(in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True) + self.pw_conv = Conv2DValid(in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=1, + stride=1, + valid_trigx=False, + valid_trigy=False) + + self.kernel_size = kernel_size + self.stride = stride + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.encoder_dim ** -0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward( + self, xs: paddle.Tensor, xs_lens: paddle.Tensor, + mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0) + xs = xs.unsqueeze(1) + padding1 = self.kernel_size - self.stride + xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) + xs = self.dw_conv(xs.transpose([0, 3, 2, 1])) + xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1) + tmp_length = xs.shape[1] + xs_lens = (xs_lens + 1) // 2 + padding2 = max(0, (xs_lens.max() - tmp_length).item()) + batch_size, hidden = xs.shape[0], xs.shape[-1] + dummy_pad = paddle.zeros([batch_size, padding2, hidden], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + mask = mask[:, ::2, ::2] + mask_pad = mask_pad[:, :, ::2] + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayerStream(nn.Layer): + """ + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. + """ + + def __init__(self, channel: int, out_dim: int, + kernel_size: int = 1, stride: int = 2): + super(TimeReductionLayerStream, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + + self.dw_conv = Conv1D(in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=0, + groups=channel) + + self.pw_conv = Conv1D(in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.channel ** -0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward(self, xs, xs_lens: paddle.Tensor, + mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool)): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index fdd8c0292..fe2fa9167 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -130,11 +130,19 @@ def get_subsample(config): Returns: int: subsample rate. """ - input_layer = config["encoder_conf"]["input_layer"] - assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if config['encoder'] == 'squeezeformer': + input_layer = config["encoder_conf"]["time_reduction_layer_type"] + assert input_layer in ["conv2d", "conv1d", "stream"] + else: + input_layer = config["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] if input_layer == "conv2d": return 4 elif input_layer == "conv2d6": return 6 elif input_layer == "conv2d8": return 8 + elif input_layer == "conv1d": + return 6 + elif input_layer == "stream": + return 8