Fixed the transpose usages that were previously missed

pull/3242/head
jiamingkong 1 year ago
parent 0e2068e2cf
commit ba874db5dc

@@ -555,19 +555,19 @@ class MultiheadAttention(nn.Layer):
         q = (
             q.contiguous()
             .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
-            .transpose(0, 1)
+            .transpose([1, 0, 2])
         )
         if k is not None:
             k = (
                 k.contiguous()
                 .view(-1, bsz * self.num_heads, self.k_head_dim)
-                .transpose(0, 1)
+                .transpose([1, 0, 2])
             )
         if v is not None:
             v = (
                 v.contiguous()
                 .view(-1, bsz * self.num_heads, self.head_dim)
-                .transpose(0, 1)
+                .transpose([1, 0, 2])
             )

         if saved_state is not None:
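Paddle's Tensor.transpose takes a full permutation of the axes rather than a pair of axes to swap, so the torch-style .transpose(0, 1) left over from the port becomes .transpose([1, 0, 2]) on these 3-D tensors. A minimal sketch of the equivalence, assuming paddle and numpy are installed; the sizes and variable names below are illustrative, not taken from the model:

import numpy as np
import paddle

tgt_len, bsz_heads, head_dim = 5, 8, 16  # made-up sizes
q = paddle.randn([tgt_len, bsz_heads, head_dim])

# Paddle wants the full axis permutation; [1, 0, 2] swaps the first two axes.
q_perm = q.transpose([1, 0, 2])

# Reference: swapping axes 0 and 1 in NumPy gives the same layout.
q_ref = np.swapaxes(q.numpy(), 0, 1)

assert q_perm.shape == [bsz_heads, tgt_len, head_dim]
assert np.allclose(q_perm.numpy(), q_ref)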
@@ -643,7 +643,8 @@ class MultiheadAttention(nn.Layer):
         )
-        attn_weights = paddle.bmm(q, k.transpose(1, 2))
+        attn_weights = paddle.matmul(q, k.transpose([0, 2, 1]))
         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

         assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
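The score computation moves from bmm with a torch-style pairwise transpose to matmul with an explicit permutation; [0, 2, 1] keeps the batch axis and swaps the last two, so the product is still Q·Kᵀ per batch element. A small self-contained check, with made-up shapes standing in for (bsz * num_heads, tgt_len, head_dim):

import paddle

bsz_heads, tgt_len, src_len, head_dim = 8, 5, 7, 16  # illustrative sizes
q = paddle.randn([bsz_heads, tgt_len, head_dim])
k = paddle.randn([bsz_heads, src_len, head_dim])

# New form: batched matmul against K with its last two axes swapped.
scores = paddle.matmul(q, k.transpose([0, 2, 1]))

# paddle.bmm over the same permuted K produces the identical result.
scores_bmm = paddle.bmm(q, k.transpose([0, 2, 1]))

assert scores.shape == [bsz_heads, tgt_len, src_len]
assert bool(paddle.allclose(scores, scores_bmm))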
@@ -687,13 +688,13 @@ class MultiheadAttention(nn.Layer):
         assert v is not None
         attn = paddle.bmm(attn_probs, v)
         assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
             attn_weights = attn_weights_float.view(
                 bsz, self.num_heads, tgt_len, src_len
-            ).transpose(1, 0)
+            ).transpose([1, 0, 2, 3])
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
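The same idea applies when the per-head outputs are merged back: the permutation [1, 0, 2] puts the time axis first before the reshape to (tgt_len, bsz, embed_dim), and on the 4-D weights [1, 0, 2, 3] swaps only the batch and head axes. A sketch with illustrative sizes, not values read from the layer:

import paddle

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 16  # made-up sizes
embed_dim = num_heads * head_dim

# Merge heads: (bsz * num_heads, tgt_len, head_dim) -> (tgt_len, bsz, embed_dim)
attn = paddle.randn([bsz * num_heads, tgt_len, head_dim])
attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
assert attn.shape == [tgt_len, bsz, embed_dim]

# Per-head weights: swap batch and head axes with a 4-D permutation.
w = paddle.randn([bsz, num_heads, tgt_len, src_len])
w = w.transpose([1, 0, 2, 3])
assert w.shape == [num_heads, bsz, tgt_len, src_len]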
