|
|
|
@ -555,19 +555,19 @@ class MultiheadAttention(nn.Layer):
|
|
|
|
|
q = (
|
|
|
|
|
q.contiguous()
|
|
|
|
|
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
|
|
|
|
.transpose(0, 1)
|
|
|
|
|
.transpose([1, 0, 2])
|
|
|
|
|
)
|
|
|
|
|
if k is not None:
|
|
|
|
|
k = (
|
|
|
|
|
k.contiguous()
|
|
|
|
|
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
|
|
|
|
.transpose(0, 1)
|
|
|
|
|
.transpose([1, 0, 2])
|
|
|
|
|
)
|
|
|
|
|
if v is not None:
|
|
|
|
|
v = (
|
|
|
|
|
v.contiguous()
|
|
|
|
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
|
|
|
|
.transpose(0, 1)
|
|
|
|
|
.transpose([1, 0, 2])
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if saved_state is not None:
|
|
|
|
@ -643,7 +643,8 @@ class MultiheadAttention(nn.Layer):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights = paddle.bmm(q, k.transpose(1, 2))
|
|
|
|
|
attn_weights = paddle.matmul(q, k.transpose([0, 2, 1]))
|
|
|
|
|
|
|
|
|
|
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
|
|
|
|
|
|
|
|
|
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
|
|
|
@ -687,13 +688,13 @@ class MultiheadAttention(nn.Layer):
|
|
|
|
|
assert v is not None
|
|
|
|
|
attn = paddle.bmm(attn_probs, v)
|
|
|
|
|
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
|
|
|
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
|
|
|
|
attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
|
|
|
|
|
attn = self.out_proj(attn)
|
|
|
|
|
attn_weights: Optional[Tensor] = None
|
|
|
|
|
if need_weights:
|
|
|
|
|
attn_weights = attn_weights_float.view(
|
|
|
|
|
bsz, self.num_heads, tgt_len, src_len
|
|
|
|
|
).transpose(1, 0)
|
|
|
|
|
).transpose([1, 0, 2, 3])
|
|
|
|
|
if not need_head_weights:
|
|
|
|
|
# average attention weights over heads
|
|
|
|
|
attn_weights = attn_weights.mean(dim=0)
|
|
|
|
|