diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 548564a25..10ab3eaea 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -459,6 +459,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
                 cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
             key (paddle.Tensor): Key tensor (#batch, time2, size).
@@ -476,10 +477,16 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
                 where `cache_t == chunk_size * num_decoding_left_chunks`
                 and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
         # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)

+        # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is the position index
+        # q_t is always chunk_size
+        q_t = q.shape[2]
+        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
+        # k grows during streaming decoding; only the current chunk is rotated here.
+        k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)
+
         # when export onnx model, for 1st chunk, we feed
         #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
         #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
@@ -504,13 +511,6 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
         #   non-trivial to calculate `next_cache_start` here.
         new_cache = paddle.concat((k, v), axis=-1)

-        # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
-        # q_t always is chunk_size
-        q_t = q.shape[2]
-        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
-        # k will increase when in streaming decoding.
-        k = self.apply_rotary_position_embeddings(pos_emb, k)
-
         # dot(q, k)
         scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask), new_cache
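
For context, the sketch below illustrates the rotation that the added comment `f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m` refers to, in the spirit of the referenced llama model.py. The helper name `rope_rotate` and its pairing of feature dimensions are illustrative assumptions; the repo's `apply_rotary_position_embeddings` instead consumes a precomputed `pos_emb` tensor, as the diff shows.

# Standalone RoPE sketch; NOT the repo's apply_rotary_position_embeddings.
import math

import paddle


def rope_rotate(x: paddle.Tensor, base: float = 10000.0) -> paddle.Tensor:
    """Rotate x of shape (batch, head, time, d_k) by its absolute positions."""
    b, h, t, d = x.shape
    # theta_i = base^(-2i/d) for each feature pair i = 0 .. d/2 - 1
    inv_freq = paddle.exp(
        -paddle.arange(0, d, 2, dtype="float32") / d * math.log(base))  # (d/2,)
    pos = paddle.arange(t, dtype="float32")                             # (t,)
    angle = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)                   # (t, d/2)
    cos = paddle.cos(angle).reshape([1, 1, t, d // 2])
    sin = paddle.sin(angle).reshape([1, 1, t, d // 2])
    # Split features into (x1, x2) pairs and apply a 2-D rotation by m * theta_i.
    x_pair = x.reshape([b, h, t, d // 2, 2])
    x1, x2 = x_pair[:, :, :, :, 0], x_pair[:, :, :, :, 1]
    rot = paddle.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rot.reshape([b, h, t, d])


# With a key/value cache, only the newest chunk of queries and keys needs to be
# rotated each step; cached keys already carry their rotation. That is why the
# diff slices pos_emb[:, -q_t:, :] for both q and k before the rotated keys are
# concatenated into new_cache.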