diff --git a/examples/aishell/asr1/conf/chunk_roformer.yaml b/examples/aishell/asr1/conf/chunk_roformer.yaml
new file mode 100644
index 00000000..1b752f87
--- /dev/null
+++ b/examples/aishell/asr1/conf/chunk_roformer.yaml
@@ -0,0 +1,98 @@
+############################################
+# Network Architecture                     #
+############################################
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: True
+    cnn_module_kernel: 15
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
+    selfattention_layer_type: 'rel_selfattn' # unused
+    causal: true
+    use_dynamic_chunk: true
+    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
+    use_dynamic_left_chunk: false
+# decoder related
+decoder: transformer # transformer, bitransformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    r_num_blocks: 3     # only for bitransformer
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    reverse_weight: 0.3 # only for bitransformer
+    length_normalized_loss: false
+    init_type: 'kaiming_uniform' # !Warning: needed for convergence
+
+###########################################
+# Data                                    #
+###########################################
+
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test
+
+
+###########################################
+# Dataloader                              #
+###########################################
+
+vocab_filepath: data/lang_char/vocab.txt
+spm_model_prefix: ''
+unit_type: 'char'
+preprocess_config: conf/preprocess.yaml
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 32
+maxlen_in: 512  # if input length > maxlen-in, batch size is automatically reduced
+maxlen_out: 150 # if output length > maxlen-out, batch size is automatically reduced
+minibatches: 0  # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 2
+subsampling_factor: 1
+num_encs: 1
+
+###########################################
+# Training                                #
+###########################################
+n_epoch: 240
+accum_grad: 1
+global_grad_clip: 5.0
+dist_sampler: True
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+log_interval: 100
+checkpoint:
+    kbest_n: 50
+    latest_n: 5
diff --git a/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml b/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
new file mode 100644
index 00000000..8bf81fa0
--- /dev/null
+++ b/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
@@ -0,0 +1,98 @@
+############################################
+# Network Architecture                     #
+############################################
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: True
+    cnn_module_kernel: 15
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
+    selfattention_layer_type: 'rel_selfattn' # unused
+    causal: true
+    use_dynamic_chunk: true
+    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
+    use_dynamic_left_chunk: false
+# decoder related
+decoder: bitransformer # transformer, bitransformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 3
+    r_num_blocks: 3     # only for bitransformer
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    reverse_weight: 0.3 # only for bitransformer
+    length_normalized_loss: false
+    init_type: 'kaiming_uniform' # !Warning: needed for convergence
+
+###########################################
+# Data                                    #
+###########################################
+
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test
+
+
+###########################################
+# Dataloader                              #
+###########################################
+
+vocab_filepath: data/lang_char/vocab.txt
+spm_model_prefix: ''
+unit_type: 'char'
+preprocess_config: conf/preprocess.yaml
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 32
+maxlen_in: 512  # if input length > maxlen-in, batch size is automatically reduced
+maxlen_out: 150 # if output length > maxlen-out, batch size is automatically reduced
+minibatches: 0  # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 2
+subsampling_factor: 1
+num_encs: 1
+
+###########################################
+# Training                                #
+###########################################
+n_epoch: 240
+accum_grad: 1
+global_grad_clip: 5.0
+dist_sampler: True
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+log_interval: 100
+checkpoint:
+    kbest_n: 50
+    latest_n: 5
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 14336c03..386977cd 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -26,7 +26,10 @@ from paddlespeech.s2t.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
+__all__ = [
+    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
+    "RoPERelPositionMultiHeadedAttention"
+]
 
 # Relative Positional Encodings
 # https://www.jianshu.com/p/c0608efcc26f
@@ -165,6 +168,7 @@ class MultiHeadedAttention(nn.Layer):
                 and `head * d_k == size`
         """
+        # (B,T,D) -> (B,H,T,D/H)
         q, k, v = self.forward_qkv(query, key, value)
 
         # when export onnx model, for 1st chunk, we feed
@@ -373,3 +377,131 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                            self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask), new_cache
+
+
+class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with RoPE relative position encoding."""
+
+    def __init__(self,
+                 n_head,
+                 n_feat,
+                 dropout_rate,
+                 adaptive_scale=False,
+                 init_weights=False):
+        """Construct a RoPERelPositionMultiHeadedAttention object.
+        Paper: RoFormer, https://arxiv.org/abs/2104.09864
+        Args:
+            n_head (int): The number of heads.
+            n_feat (int): The number of features.
+            dropout_rate (float): Dropout rate.
+            adaptive_scale, init_weights: accepted for interface
+                compatibility with other attention layers; unused here.
+        """
+        super().__init__(n_head, n_feat, dropout_rate)
+
+    def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
+        """Realign a tensor, i.e. a batched version of expand_dims.
+        Args:
+            axes: dim i of the input is placed at dim axes[i] of the output.
+            ndim: number of dims of the output tensor.
+        """
+        assert len(axes) == tensor.dim()
+        assert ndim or min(axes) >= 0
+
+        ndim = ndim or max(axes) + 1
+
+        # a[0, None, 1] = a[0, np.newaxis, 1]
+        indices = [None] * ndim
+        for i in axes:
+            # slice(None) keeps the whole axis, a[0, slice(None), 1] = a[0, :, 1]
+            indices[i] = slice(None)
+
+        return tensor[indices]
+
+    def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
+        """Apply RoPE to the given tensors.
+        sinusoidal.shape = [B, T, D]; every tensor in `tensors` has shape
+        [B, T, ..., D], e.g. (B, T, H, D/H), and all of them share one shape.
+        """
+        assert len(tensors) > 0, 'at least one input tensor'
+        assert all(
+            [tensor.shape == tensors[0].shape
+             for tensor in tensors[1:]]), 'all tensors must have the same shape'
+
+        ndim = tensors[0].dim()
+
+        # align sinusoidal's rank with tensors[0]
+        # [B,T,D] -> [B,T,1,D]
+        sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+
+        # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
+        # like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
+        # [B,T, ..., D/2] -> [B,T, ..., D]
+        cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
+        sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
+
+        outputs = []
+        for tensor in tensors:
+            # tensor2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
+            tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
+
+            # Eq. 34 in the RoFormer paper: out = x * cos_pos + x2 * sin_pos
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def forward(self,
+                query: paddle.Tensor,
+                key: paddle.Tensor,
+                value: paddle.Tensor,
+                mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor=paddle.empty([0]),
+                cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rotary positional encoding.
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+
+        # when exporting an onnx model, for the 1st chunk we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        #   when exporting a jit model, for the 1st chunk we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.shape[0] > 0:
+            # last dim `d_k * 2` holds (key, val)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
+        # f_{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is the position index
+        q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
+        # dot(q, k)
+        scores = paddle.matmul(q, k,
+                               transpose_y=True) / math.sqrt(self.d_k)
+
+        return self.forward_attention(v, scores, mask), new_cache
diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py
index f41a7b5d..8ff2e663 100644
--- a/paddlespeech/s2t/modules/embedding.py
+++ b/paddlespeech/s2t/modules/embedding.py
@@ -89,14 +89,17 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
+        self.base = 10000.0
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
 
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
+        # base^{-2(i-1)/d}, i in (1, 2, ..., d/2)
         div_term = paddle.exp(
-            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            -(math.log(10000.0) / self.d_model))
+            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            (math.log(self.base) / self.d_model))
 
+        # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
         self.pe[:, :, 1::2] = paddle.cos(position * div_term)
 
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index d90d69d7..2c3b8c39 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
+from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "rope_pos":
+            pos_enc_class = RelPositionalEncoding
         elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
             chunk_size = xs.shape[1]
             attention_key_size = cache_t1 + chunk_size
 
-        # only used when using `RelPositionMultiHeadedAttention`
+        # only used when using `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
             offset=offset - cache_t1, size=attention_key_size)
 
@@ -474,9 +477,22 @@ class ConformerEncoder(BaseEncoder):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        encoder_selfattn_layer = RelPositionMultiHeadedAttention
-        encoder_selfattn_layer_args = (attention_heads, output_size,
-                                       attention_dropout_rate)
+        if pos_enc_layer_type == "abs_pos":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, output_size,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rel_pos":
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, output_size,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, output_size,
+                                           attention_dropout_rate)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
+
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
         positionwise_layer_args = (output_size, linear_units, dropout_rate,
@@ -580,15 +596,23 @@ class SqueezeformerEncoder(nn.Layer):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        if pos_enc_layer_type != "rel_pos":
+        if pos_enc_layer_type == "abs_pos":
             encoder_selfattn_layer = MultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, output_size,
                                            attention_dropout_rate)
-        else:
+        elif pos_enc_layer_type == "rel_pos":
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate,
                                            adaptive_scale, init_weights)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate,
+                                           adaptive_scale, init_weights)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py
index ecba95e8..0499e742 100644
--- a/paddlespeech/s2t/modules/encoder_layer.py
+++ b/paddlespeech/s2t/modules/encoder_layer.py
@@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
         Args:
             size (int): Input dimension.
             self_attn (nn.Layer): Self-attention module instance.
-                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
                 instance can be used as the argument.
             feed_forward (nn.Layer): Feed-forward module instance.
                 `PositionwiseFeedForward`, instance can be used as the argument.
@@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
         Args:
             size (int): Input dimension.
             self_attn (nn.Layer): Self-attention module instance.
-                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
                 instance can be used as the argument.
             feed_forward (nn.Layer): Feed-forward module instance.
                 `PositionwiseFeedForward` instance can be used as the argument.
@@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
         Args:
             size (int): Input dimension.
             self_attn (paddle.nn.Layer): Self-attention module instance.
-                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
                 instance can be used as the argument.
             feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
                 `PositionwiseFeedForward` instance can be used as the argument.
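
The core of the patch is the interleaved rotation in `apply_rotary_position_embeddings`: the sinusoidal table produced by `PositionalEncoding` (sin at even feature indices, cos at odd ones) rotates each adjacent feature pair of the query and key by a position-dependent angle, so the q/k dot product depends only on the relative offset between positions. Below is a minimal NumPy stand-in for that rotation under the same sin/cos layout; `rope_rotate` and `rot_at` are illustrative names, not part of the codebase.

import numpy as np

D, base = 8, 10000.0
inv_freq = np.exp(-np.arange(0, D, 2) * np.log(base) / D)   # base^{-2(i-1)/D}, i = 1..D/2


def rope_rotate(x, sinusoidal):
    """x: (T, D); sinusoidal: (T, D) with sin at even and cos at odd indices."""
    cos_pos = np.repeat(sinusoidal[..., 1::2], 2, axis=-1)   # (T, D)
    sin_pos = np.repeat(sinusoidal[..., 0::2], 2, axis=-1)   # (T, D)
    # build [-x2, x1, -x4, x3, ...]: swap each even/odd pair, negate the new first entry
    x2 = np.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(x.shape)
    return x * cos_pos + x2 * sin_pos                        # Eq. 34 in the RoFormer paper


def rot_at(v, m):
    """Rotate a single D-dim vector v as if it sat at position m."""
    s = np.zeros(D)
    s[0::2] = np.sin(m * inv_freq)
    s[1::2] = np.cos(m * inv_freq)
    return rope_rotate(v[None, :], s[None, :])[0]


a, b = np.random.randn(D), np.random.randn(D)
# The dot product only sees the relative offset: shifting both positions
# by the same amount (here +2) leaves the score unchanged.
print(np.allclose(rot_at(a, 5) @ rot_at(b, 3), rot_at(a, 7) @ rot_at(b, 5)))  # True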