@@ -16,6 +16,7 @@
 """Multi-Head Attention layer definition."""
 import math
 from typing import Tuple
+from typing import List

 import paddle
 from paddle import nn
@@ -418,25 +419,27 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
     def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
         """Apply RoPE to the input tensors.

         Here sinusoidal.shape=[B, T, D], tensors is a list of tensors, and
-        tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
+        tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
         """
         assert len(tensors) > 0, 'at least one input tensor'
         assert all(
             [tensor.shape == tensors[0].shape
              for tensor in tensors[1:]]), 'all tensors must have the same shape'

         # (B,H,T,D)
         ndim = tensors[0].dim()
+        _, H, T, D = tensors[0].shape

         # sinusoidal: same layout as tensors[0]
-        # [B,T,D] -> [B,T,1,D]
-        sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])

         # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
         # like np.repeat: x (s1, s2, s3), axis 1 -> (s1, s2*rep, s3)
         # [b,T, ..., d/2] -> [b,T, ..., d]
         cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
         sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)

         outputs = []
         for tensor in tensors:
             # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
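
[Review note] A minimal runnable sketch, outside the patch, of what the interleaved sin/cos gather and the in-loop rotation compute, assuming the (B,H,T,D/H) layout introduced above; `rotate_every_two` is an illustrative helper name, not from the source file.

```python
import paddle

def rotate_every_two(x):
    # [x1, x2, x3, x4, ...] -> [-x2, x1, -x4, x3, ...] on the last axis
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return paddle.stack([-x2, x1], axis=-1).reshape(x.shape)

B, H, T, D = 2, 4, 8, 16                    # D here is d_k = d_model / H
q = paddle.randn([B, H, T, D])
# interleaved [sin, cos, sin, cos, ...] encodings, already in (1, H, T, D)
sinusoidal = paddle.randn([1, H, T, D])
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
# RoPE: rotate each (x_{2i}, x_{2i+1}) pair by its position angle
q_rot = q * cos_pos + rotate_every_two(q) * sin_pos
print(q_rot.shape)  # [2, 4, 8, 16]
```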

@@ -501,7 +504,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
         new_cache = paddle.concat((k, v), axis=-1)

         # f_{q,k}(x_m, m) = R^d_{\theta,m} W_{q,k} x_m, where m is the position index
-        q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
+        q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
         # dot(q, k)
         scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask), new_cache
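
[Review note] A quick numerical check, outside the patch, of the property the f_{q,k} comment refers to: with rotary embeddings, the q·k score depends only on the relative offset between the two positions. The `rope` helper below is illustrative, using the standard RoPE frequencies.

```python
import math
import paddle

D = 16

def rope(x, m):
    # Rotate one D-dim vector to position m with frequencies
    # theta_i = 10000^(-2i/D).
    theta = paddle.exp(-math.log(10000.0)
                       * paddle.arange(0, D, 2, dtype='float32') / D)
    ang = m * theta
    cos = paddle.repeat_interleave(paddle.cos(ang), 2)
    sin = paddle.repeat_interleave(paddle.sin(ang), 2)
    x1, x2 = x[0::2], x[1::2]
    rot = paddle.stack([-x2, x1], axis=-1).reshape(x.shape)
    return x * cos + rot * sin

q, k = paddle.randn([D]), paddle.randn([D])
s1 = paddle.dot(rope(q, 5), rope(k, 2))  # positions (5, 2), offset 3
s2 = paddle.dot(rope(q, 9), rope(k, 6))  # positions (9, 6), offset 3 again
print(float(s1), float(s2))              # equal up to float error
```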