diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py
index 8ff2e6636..db5647d6b 100644
--- a/paddlespeech/s2t/modules/embedding.py
+++ b/paddlespeech/s2t/modules/embedding.py
@@ -167,3 +167,43 @@ class RelPositionalEncoding(PositionalEncoding):
         x = x * self.xscale
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
+
+
+# A non-scaled RotaryRelPositionalEncoding would be identical to RelPositionalEncoding
+class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
+    """Scaled rotary relative positional encoding module.
+    Position Interpolation: https://arxiv.org/pdf/2306.15595v2.pdf
+    """
+
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int=5000,
+                 scale: int=1):
+        """
+        Args:
+            d_model (int): Embedding dimension.
+            dropout_rate (float): Dropout rate.
+            max_len (int, optional): Maximum input length. Defaults to 5000.
+            scale (int): Interpolate the maximum input length to
+                `scale * max_len` positions. Defaults to 1.
+        """
+        # allocate the pe table with `scale * max_len` rows so the
+        # interpolated positions fit
+        super().__init__(d_model, dropout_rate, max_len * scale, reverse=True)
+        self.scale = scale
+
+        position = paddle.arange(
+            0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
+        # position interpolation: squeeze `scale * max_len` indices
+        # into the original [0, max_len) position range
+        position *= 1.0 / self.scale
+
+        # base^{-2(i-1)/d}, i in (1, 2, ..., d/2)
+        div_term = paddle.exp(
+            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            (math.log(self.base) / self.d_model))
+
+        # overwrite the parent table in place: [B=1, T, D]
+        self.pe[:, :, 0::2] = paddle.sin(position * div_term)
+        self.pe[:, :, 1::2] = paddle.cos(position * div_term)
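
For context, a minimal sketch (not part of the patch) of what position interpolation does: multiplying the indices by `1.0 / scale` squeezes `scale * max_len` positions into the original `[0, max_len)` range, so every `scale`-th interpolated row reproduces an original row and no position ever exceeds the range the model was trained on. The `base = 10000.0` constant mirrors `PositionalEncoding` in this file; the sizes below are toy values.

    import math
    import paddle

    d_model, max_len, scale, base = 8, 4, 2, 10000.0  # toy sizes

    # base^{-2i/d} frequency terms, as in PositionalEncoding
    div_term = paddle.exp(
        -paddle.arange(0, d_model, 2, dtype=paddle.float32) *
        (math.log(base) / d_model))

    # original table: positions 0 .. max_len - 1
    pos = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
    # interpolated table: scale * max_len indices squeezed into [0, max_len)
    pos_interp = paddle.arange(
        0, max_len * scale, dtype=paddle.float32).unsqueeze(1) / scale

    # every `scale`-th interpolated row lands exactly on an original position
    assert paddle.allclose(
        paddle.sin(pos * div_term),
        paddle.sin(pos_interp * div_term)[::scale])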
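
And a hypothetical usage sketch (argument values are illustrative; it relies on the inherited `forward(x, offset)` visible in the context lines above, which slices `self.pe` up to `x.shape[1]`):

    pos_enc = ScaledRotaryRelPositionalEncoding(
        d_model=256, dropout_rate=0.1, max_len=5000, scale=2)
    x = paddle.randn([1, 8000, 256])  # longer than the original max_len
    x, pos_emb = pos_enc(x)  # the pe table now holds 10000 rows, so this fits

Interpolating rather than extrapolating is the core argument of the cited paper: positions stay inside the range seen during training, at the cost of finer-grained (fractional) position values.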