@@ -15,8 +15,8 @@
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """Multi-Head Attention layer definition."""
 import math
-from typing import Tuple
 from typing import List
+from typing import Tuple
 
 import paddle
 from paddle import nn
@@ -428,7 +428,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
         # (B,H,T,D)
         ndim = tensors[0].dim()
-        _,H,T,D = tensors[0].shape
+        _, H, T, D = tensors[0].shape
 
         # sinusoidal shape same with tensors[0]
         # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
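A minimal, illustrative sketch (not part of the diff) of the reshape named in the hunk's comment, assuming a (B, T, D_model) sinusoidal table is broadcast against (B, H, T, D_head) per-head tensors with D_model == H * D_head; all names and sizes below are made up for the example.

import paddle

B, H, T, D_head = 2, 4, 8, 16                 # illustrative sizes only
pos_emb = paddle.randn([1, T, H * D_head])    # [B, T, D] sinusoidal table
q = paddle.randn([B, H, T, D_head])           # (B, H, T, D) per-head tensor

# [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H), the transformation named in the comment
pos_emb = pos_emb.reshape([1, T, H, D_head]).transpose([0, 2, 1, 3])
print(pos_emb.shape)                          # [1, 4, 8, 16], broadcastable with q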
@@ -476,6 +476,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
                 where `cache_t == chunk_size * num_decoding_left_chunks`
                 and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
         # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
@@ -504,7 +505,12 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
         new_cache = paddle.concat((k, v), axis=-1)
 
         # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
-        q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
+        # q_t always is chunk_size
+        q_t = q.shape[2]
+        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
+        # k will increase when in streaming decoding.
+        k = self.apply_rotary_position_embeddings(pos_emb, k)
 
         # dot(q, k)
         scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask), new_cache
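The change above rotates only the last q_t positions for q because, in streaming decoding, q covers just the current chunk while k spans the cached frames plus the chunk. A minimal sketch of that indexing, assuming pos_emb spans cache plus chunk as the full slice used for k suggests; all names and sizes below are illustrative, not the PR's code.

import paddle

chunk_size, cache_t, H, d_k = 16, 48, 4, 64                  # illustrative sizes only
q = paddle.randn([1, H, chunk_size, d_k])                    # current chunk only
k = paddle.randn([1, H, cache_t + chunk_size, d_k])          # cached frames + chunk
pos_emb = paddle.randn([1, cache_t + chunk_size, H * d_k])   # one position per k frame

q_t = q.shape[2]                   # q_t always equals chunk_size
q_pos = pos_emb[:, -q_t:, :]       # q is rotated with the chunk's own positions
k_pos = pos_emb                    # k is rotated over the full (growing) range
print(q_pos.shape, k_pos.shape)    # [1, 16, 256] [1, 64, 256]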