diff --git a/examples/aishell/asr1/RESULTS.md b/examples/aishell/asr1/RESULTS.md
index 643d0e22..be771ba5 100644
--- a/examples/aishell/asr1/RESULTS.md
+++ b/examples/aishell/asr1/RESULTS.md
@@ -1,14 +1,31 @@
 # Aishell
-## Conformer
-paddle version: 2.2.2
-paddlespeech version: 1.0.1
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0480 |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
+## RoFormer Streaming
+paddle version: 2.5.0
+paddlespeech version: 1.5.0
+
+Tesla V100-SXM2-32GB: 1 node, 4 cards
+Global batch size: 32 * 4 = 128
+Training time: 1 day, 12:56:39
+### `decoding.decoding_chunk_size=16`
+
+> chunk_size=16, ((16 - 1) * 4 + 7) * 10ms = (16 * 4 + 3) * 10ms = 670ms
+
+| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER (%) |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | 16, -1 | - | 5.63 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 6.13 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 6.13 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 5.44 |
+
+### `decoding.decoding_chunk_size=-1`
+
+| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER (%) |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | -1, -1 | - | 5.39 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | -1, -1 | - | 5.51 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | -1, -1 | - | 5.51 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | -1, -1 | - | 4.99 |
 
 ## Conformer Streaming
 
@@ -24,6 +41,17 @@ Need set `decoding.decoding_chunk_size=16` when decoding.
 | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
 
+## Conformer
+paddle version: 2.2.2
+paddlespeech version: 1.0.1
+| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_prefix_beam_search | - | 0.0480 |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
+
+
 ## Transformer
 
 | Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
diff --git a/examples/aishell/asr1/conf/chunk_roformer.yaml b/examples/aishell/asr1/conf/chunk_roformer.yaml
new file mode 100644
index 00000000..a4051a02
--- /dev/null
+++ b/examples/aishell/asr1/conf/chunk_roformer.yaml
@@ -0,0 +1,98 @@
+############################################
+# Network Architecture                     #
+############################################
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: True
+    cnn_module_kernel: 15
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
+    selfattention_layer_type: 'rel_selfattn' # unused
+    causal: true
+    use_dynamic_chunk: true
+    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
+    use_dynamic_left_chunk: false
+# decoder related
+decoder: transformer # transformer, bitransformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    r_num_blocks: 0     # only for bitransformer
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    reverse_weight: 0.0 # only for bitransformer
+    length_normalized_loss: false
+    init_type: 'kaiming_uniform' # !Warning: needed for convergence
+
+###########################################
+# Data                                    #
+###########################################
+
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test
+
+
+###########################################
+# Dataloader                              #
+###########################################
+
+vocab_filepath: data/lang_char/vocab.txt
+spm_model_prefix: ''
+unit_type: 'char'
+preprocess_config: conf/preprocess.yaml
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 32
+maxlen_in: 512  # if input length > maxlen-in, batchsize is automatically reduced
+maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+minibatches: 0  # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 2
+subsampling_factor: 1
+num_encs: 1
+
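+# Note: trained on 4 V100 GPUs (see RESULTS.md); the effective global batch size is 32 * 4 = 128.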
+###########################################
+# Training                                #
+###########################################
+n_epoch: 240
+accum_grad: 1
+global_grad_clip: 5.0
+dist_sampler: True
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+log_interval: 100
+checkpoint:
+    kbest_n: 50
+    latest_n: 5
diff --git a/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml b/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
new file mode 100644
index 00000000..aa3a0aca
--- /dev/null
+++ b/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
@@ -0,0 +1,98 @@
+############################################
+# Network Architecture                     #
+############################################
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: True
+    cnn_module_kernel: 15
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
+    selfattention_layer_type: 'rel_selfattn' # unused
+    causal: true
+    use_dynamic_chunk: true
+    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
+    use_dynamic_left_chunk: false
+# decoder related
+decoder: bitransformer # transformer, bitransformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 3
+    r_num_blocks: 3     # only for bitransformer
+    dropout_rate: 0.1   # sublayer output dropout
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    reverse_weight: 0.3 # only for bitransformer
+    length_normalized_loss: false
+    init_type: 'kaiming_uniform' # !Warning: needed for convergence
+
+###########################################
+# Data                                    #
+###########################################
+
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test
+
+
+###########################################
+# Dataloader                              #
+###########################################
+
+vocab_filepath: data/lang_char/vocab.txt
+spm_model_prefix: ''
+unit_type: 'char'
+preprocess_config: conf/preprocess.yaml
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 32
+maxlen_in: 512  # if input length > maxlen-in, batchsize is automatically reduced
+maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+minibatches: 0  # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 2
+subsampling_factor: 1
+num_encs: 1
+
+###########################################
+# Training                                #
+###########################################
+n_epoch: 240
+accum_grad: 1
+global_grad_clip: 5.0
+dist_sampler: True
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+log_interval: 100
+checkpoint:
+    kbest_n: 50
+    latest_n: 5
diff --git
a/paddlespeech/dataset/s2t/avg_model.py b/paddlespeech/dataset/s2t/avg_model.py index c5753b72..5bd5cb1f 100755 --- a/paddlespeech/dataset/s2t/avg_model.py +++ b/paddlespeech/dataset/s2t/avg_model.py @@ -20,30 +20,6 @@ import numpy as np import paddle -def define_argparse(): - parser = argparse.ArgumentParser(description='average model') - parser.add_argument('--dst_model', required=True, help='averaged model') - parser.add_argument( - '--ckpt_dir', required=True, help='ckpt model dir for average') - parser.add_argument( - '--val_best', action="store_true", help='averaged model') - parser.add_argument( - '--num', default=5, type=int, help='nums for averaged model') - parser.add_argument( - '--min_epoch', - default=0, - type=int, - help='min epoch used for averaging model') - parser.add_argument( - '--max_epoch', - default=65536, # Big enough - type=int, - help='max epoch used for averaging model') - - args = parser.parse_args() - return args - - def average_checkpoints(dst_model="", ckpt_dir="", val_best=True, @@ -85,7 +61,7 @@ def average_checkpoints(dst_model="", print(path_list) avg = None - num = args.num + num = num assert num == len(path_list) for path in path_list: print(f'Processing {path}') @@ -100,14 +76,14 @@ def average_checkpoints(dst_model="", if avg[k] is not None: avg[k] /= num - paddle.save(avg, args.dst_model) - print(f'Saving to {args.dst_model}') + paddle.save(avg, dst_model) + print(f'Saving to {dst_model}') - meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json' + meta_path = os.path.splitext(dst_model)[0] + '.avg.json' with open(meta_path, 'w') as f: data = json.dumps({ - "mode": 'val_best' if args.val_best else 'latest', - "avg_ckpt": args.dst_model, + "mode": 'val_best' if val_best else 'latest', + "avg_ckpt": dst_model, "val_loss_mean": avg_val_score, "ckpts": path_list, "epochs": selected_epochs.tolist(), @@ -116,9 +92,40 @@ def average_checkpoints(dst_model="", f.write(data + "\n") +def define_argparse(): + parser = argparse.ArgumentParser(description='average model') + parser.add_argument('--dst_model', required=True, help='averaged model') + parser.add_argument( + '--ckpt_dir', required=True, help='ckpt model dir for average') + parser.add_argument( + '--val_best', action="store_true", help='averaged model') + parser.add_argument( + '--num', default=5, type=int, help='nums for averaged model') + parser.add_argument( + '--min_epoch', + default=0, + type=int, + help='min epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') + + args = parser.parse_args() + print(args) + return args + + def main(): args = define_argparse() - average_checkpoints(args) + average_checkpoints( + dst_model=args.dst_model, + ckpt_dir=args.ckpt_dir, + val_best=args.val_best, + num=args.num, + min_epoch=args.min_epoch, + max_epoch=args.max_epoch) if __name__ == '__main__': diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index f716fa3b..2e1c14ac 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer): text_lengths) ctc_time = time.time() - start #logger.debug(f"ctc time: {ctc_time}") - if loss_ctc is None: loss = loss_att elif loss_att is None: @@ -916,6 +915,8 @@ class U2Model(U2DecodeModel): decoder_type = configs.get('decoder', 'transformer') logger.debug(f"U2 Decoder type: {decoder_type}") if decoder_type == 'transformer': + 
            configs['model_conf'].pop('reverse_weight', None)  # only used by bitransformer
+            configs['decoder_conf'].pop('r_num_blocks', None)  # only used by bitransformer
             decoder = TransformerDecoder(vocab_size,
                                          encoder.output_size(),
                                          **configs['decoder_conf'])
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 14336c03..10ab3eae 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -15,6 +15,7 @@
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """Multi-Head Attention layer definition."""
 import math
+from typing import List
 from typing import Tuple
 
 import paddle
@@ -26,7 +27,10 @@ from paddlespeech.s2t.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
+__all__ = [
+    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
+    "RoPERelPositionMultiHeadedAttention"
+]
 
 # Relative Positional Encodings
 # https://www.jianshu.com/p/c0608efcc26f
@@ -165,6 +169,7 @@ class MultiHeadedAttention(nn.Layer):
                 and `head * d_k == size`
         """
+        # (B,T,D) -> (B,H,T,D/H)
         q, k, v = self.forward_qkv(query, key, value)
 
         # when export onnx model, for 1st chunk, we feed
@@ -373,3 +378,139 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                                self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask), new_cache
+
+
+class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with rotary position embedding (RoPE)."""
+
+    def __init__(self,
+                 n_head,
+                 n_feat,
+                 dropout_rate,
+                 adaptive_scale=False,
+                 init_weights=False):
+        """Construct a RoPERelPositionMultiHeadedAttention object.
+        Paper: RoFormer, https://arxiv.org/abs/2104.09864
+        Args:
+            n_head (int): The number of heads.
+            n_feat (int): The number of features.
+            dropout_rate (float): Dropout rate.
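+            adaptive_scale (bool): Unused here; accepted so the constructor
+                matches the Squeezeformer attention interface.
+            init_weights (bool): Unused here; accepted for the same reason.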
+        """
+        super().__init__(n_head, n_feat, dropout_rate)
+
+    def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
+        """Realign `tensor` (a batched version of expand_dims).
+        axes: dim i of the original tensor goes to dim axes[i] of the new tensor;
+        ndim: number of dimensions of the new tensor.
+        """
+        assert len(axes) == tensor.dim()
+        assert ndim or min(axes) >= 0
+
+        ndim = ndim or max(axes) + 1
+
+        # a[0, None, 1] = a[0, np.newaxis, 1]
+        indices = [None] * ndim
+        for i in axes:
+            # slice nothing, a[0, slice(None), 1] = a[0, :, 1]
+            indices[i] = slice(None)
+
+        return tensor[indices]
+
+    def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
+        """Apply RoPE to the given tensors.
+        sinusoidal.shape = [B, T, D]; `tensors` is a list of tensors with
+        tensor.shape = [B, T, ..., D] or (B, H, T, D/H).
+        """
+        assert len(tensors) > 0, 'at least one input tensor'
+        assert all(
+            [tensor.shape == tensors[0].shape
+             for tensor in tensors[1:]]), 'all tensors must have the same shape'
+
+        # (B,H,T,D)
+        ndim = tensors[0].dim()
+        _, H, T, D = tensors[0].shape
+
+        # reshape sinusoidal to match the layout of tensors[0]:
+        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
+
+        # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
+        # like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
+        # [b,T, ..., d/2] -> [b,T, ..., d]
+        cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
+        sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
+        outputs = []
+        for tensor in tensors:
+            # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
+            tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
+
+            # Eq. (34): out = x * cos_pos + x2 * sin_pos
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def forward(self,
+                query: paddle.Tensor,
+                key: paddle.Tensor,
+                value: paddle.Tensor,
+                mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor=paddle.empty([0]),
+                cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rotary position embedding.
+        Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+
+        # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is the position index
+        # q_t is always the chunk size
+        q_t = q.shape[2]
+        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
+        # k will grow during streaming decoding.
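+        # Keys cached from earlier chunks were already rotated at their own
+        # positions when they were computed, so only the current chunk's keys
+        # are rotated here before the cache is concatenated below.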
+        k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)
+
+        # when export onnx model, for 1st chunk, we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        # when export jit model, for 1st chunk, we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.shape[0] > 0:
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
+        # dot(q, k)
+        scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask), new_cache
diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py
index f41a7b5d..1e9f0101 100644
--- a/paddlespeech/s2t/modules/embedding.py
+++ b/paddlespeech/s2t/modules/embedding.py
@@ -85,18 +85,21 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             reverse (bool, optional): Not used. Defaults to False.
         """
         nn.Layer.__init__(self)
-        self.d_model = d_model
+        self.d_model = paddle.to_tensor(d_model)
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
+        self.base = paddle.to_tensor(10000.0)
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
+        # base^{-2(i-1)/d}, i \in (1, 2, ..., d/2)
         div_term = paddle.exp(
-            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            -(math.log(10000.0) / self.d_model))
+            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            (paddle.log(self.base) / self.d_model))
+
+        # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
         self.pe[:, :, 1::2] = paddle.cos(position * div_term)
 
@@ -161,6 +164,98 @@ class RelPositionalEncoding(PositionalEncoding):
         assert offset + x.shape[
             1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
+        x = x * self.xscale
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
+
+
+# RotaryRelPositionalEncoding is the same as RelPositionalEncoding
+class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
+    """Scaled rotary relative positional encoding module.
+    Position Interpolation: https://arxiv.org/pdf/2306.15595v2.pdf
+    """
+
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int=5000,
+                 scale=1):
+        """
+        Args:
+            d_model (int): Embedding dimension.
+            dropout_rate (float): Dropout rate.
+            max_len (int, optional): Maximum input length. Defaults to 5000.
+            scale (int): Interpolation factor; the position table is extended
+                to `scale * max_len` entries.
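+                For example, with `scale=2` an input position `i` is encoded
+                with the sinusoid at position `i / 2`, following the
+                position-interpolation idea.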
+        """
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+        self.pscale = paddle.to_tensor(scale)
+        self.max_len = max_len * scale
+
+    def sinusoidal_embeddings(self,
+                              pos: paddle.Tensor,
+                              dim: paddle.Tensor,
+                              base=10000) -> paddle.Tensor:
+        """Compute the `dim`-dimensional sinusoidal embedding for positions `pos`."""
+        assert dim % 2 == 0
+        # (d/2,)
+        indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
+        indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
+        # pos (1, T), indices (d/2,) -> (1, T, d/2)
+        embeddings = paddle.einsum('...,d->...d', pos, indices)
+        # (1, T, d/2, 2)
+        embeddings = paddle.stack(
+            [paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
+        # (1, T, d)
+        embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
+        return embeddings
+
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        x = x * self.xscale
+
+        B, T, D = x.shape
+        assert D == self.d_model
+
+        # position interpolation
+        start = 0
+        end = T * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
+        position *= 1.0 / self.pscale
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+
+        pos_emb = pe[:, offset:offset + x.shape[1]]
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
+        """ For getting encoding in a streaming fashion
+        Attention!!!!!
+        we apply dropout only once at the whole utterance level in a
+        non-streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+        Args:
+            offset (int): start offset
+            size (int): required size of position encoding
+        Returns:
+            paddle.Tensor: Corresponding position encoding, #[1, T, D].
+        """
+        # position interpolation
+        start = offset
+        end = (offset + size) * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(
+            start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
+        position *= 1.0 / self.pscale
+
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+
+        return self.dropout(pe)
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index d90d69d7..27d7ffbd 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
+from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "rope_pos":
+            pos_enc_class = RelPositionalEncoding  # RoPE reuses the sinusoidal table; the rotation is applied inside the attention layer
         elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
@@ -230,14 +233,14 @@ class BaseEncoder(nn.Layer):
             xs = self.global_cmvn(xs)
 
         # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+        xs, _, _ = self.embed(xs, tmp_masks, offset=offset)
         # after embed, xs=(B=1, chunk_size, hidden-dim)
 
         elayers, _, cache_t1, _ = att_cache.shape
         chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size
 
-        # only used when using `RelPositionMultiHeadedAttention`
+        # only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
             offset=offset - cache_t1, size=attention_key_size)
 
@@ -474,21 +477,35 @@ class ConformerEncoder(BaseEncoder):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        encoder_selfattn_layer = RelPositionMultiHeadedAttention
-        encoder_selfattn_layer_args = (attention_heads, output_size,
-                                       attention_dropout_rate)
+        encoder_dim = output_size
+        if pos_enc_layer_type == "abs_pos":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rel_pos":
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
+
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (output_size, linear_units, dropout_rate,
+        positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
                                    activation)
         # convolution module definition
         convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
 
         self.encoders = nn.LayerList([
ConformerEncoderLayer( - size=output_size, + size=encoder_dim, self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args), feed_forward=positionwise_layer(*positionwise_layer_args), feed_forward_macaron=positionwise_layer( @@ -580,15 +597,23 @@ class SqueezeformerEncoder(nn.Layer): activation = get_activation(activation_type) # self-attention module definition - if pos_enc_layer_type != "rel_pos": + if pos_enc_layer_type == "abs_pos": encoder_selfattn_layer = MultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, output_size, attention_dropout_rate) - else: + elif pos_enc_layer_type == "rel_pos": encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, encoder_dim, attention_dropout_rate, adaptive_scale, init_weights) + elif pos_enc_layer_type == "rope_pos": + encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, encoder_dim, + attention_dropout_rate, + adaptive_scale, init_weights) + else: + raise ValueError( + f"pos_enc_layer_type {pos_enc_layer_type} not supported.") # feed-forward module definition positionwise_layer = PositionwiseFeedForward diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index ecba95e8..0499e742 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer): Args: size (int): Input dimension. self_attn (nn.Layer): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward (nn.Layer): Feed-forward module instance. `PositionwiseFeedForward`, instance can be used as the argument. @@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer): Args: size (int): Input dimension. self_attn (nn.Layer): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward (nn.Layer): Feed-forward module instance. `PositionwiseFeedForward` instance can be used as the argument. @@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer): Args: size (int): Input dimension. self_attn (paddle.nn.Layer): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward1 (paddle.nn.Layer): Feed-forward module instance. `PositionwiseFeedForward` instance can be used as the argument.