Fixed the transpose usages that were previously missed

pull/3242/head
jiamingkong 2 years ago
parent 0e2068e2cf
commit ba874db5dc

@@ -555,19 +555,19 @@ class MultiheadAttention(nn.Layer):
         q = (
             q.contiguous()
             .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
-            .transpose(0, 1)
+            .transpose([1, 0, 2])
         )
         if k is not None:
             k = (
                 k.contiguous()
                 .view(-1, bsz * self.num_heads, self.k_head_dim)
-                .transpose(0, 1)
+                .transpose([1, 0, 2])
             )
         if v is not None:
             v = (
                 v.contiguous()
                 .view(-1, bsz * self.num_heads, self.head_dim)
-                .transpose(0, 1)
+                .transpose([1, 0, 2])
             )
         if saved_state is not None:
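
Note on the perm lists above: `torch.Tensor.transpose(dim0, dim1)` names just the two axes to swap, while Paddle's `transpose` expects a permutation covering every axis, so swapping the first two axes of the 3-D q/k/v tensors becomes `[1, 0, 2]`. A minimal sketch of that semantics, with placeholder shapes rather than values from the model:

```python
import paddle

# Placeholder shapes standing in for (tgt_len, bsz * num_heads, head_dim).
tgt_len, bsz_heads, head_dim = 4, 6, 8
q = paddle.randn([tgt_len, bsz_heads, head_dim])

# Paddle's transpose takes a full permutation of the axes; [1, 0, 2] swaps
# the first two dimensions of a 3-D tensor, which is what the torch-style
# .transpose(0, 1) expressed in the original code.
q_swapped = q.transpose([1, 0, 2])
assert q_swapped.shape == [bsz_heads, tgt_len, head_dim]
```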
@@ -643,7 +643,8 @@ class MultiheadAttention(nn.Layer):
                 )
-        attn_weights = paddle.bmm(q, k.transpose(1, 2))
+        attn_weights = paddle.matmul(q, k.transpose([0, 2, 1]))
         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
         assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
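
The `bmm` to `matmul` switch is equivalent for these operands: with strictly 3-D inputs `paddle.matmul` batches over the leading axis just like `paddle.bmm`, and `k.transpose([0, 2, 1])` swaps only the last two axes. A quick sanity-check sketch with placeholder shapes, not the model's:

```python
import paddle

# Stand-ins for bsz * num_heads, tgt_len, src_len, head_dim.
B, T, S, D = 6, 4, 5, 8
q = paddle.randn([B, T, D])
k = paddle.randn([B, S, D])

# [0, 2, 1] keeps the batch axis and swaps the last two, so q @ k^T is
# computed per batch element, giving [B, T, S] attention scores.
scores = paddle.matmul(q, k.transpose([0, 2, 1]))
assert scores.shape == [B, T, S]

# For 3-D operands this matches the old bmm formulation.
assert paddle.allclose(scores, paddle.bmm(q, k.transpose([0, 2, 1]))).item()
```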
@@ -687,13 +688,13 @@ class MultiheadAttention(nn.Layer):
         assert v is not None
         attn = paddle.bmm(attn_probs, v)
         assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
             attn_weights = attn_weights_float.view(
                 bsz, self.num_heads, tgt_len, src_len
-            ).transpose(1, 0)
+            ).transpose([1, 0, 2, 3])
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
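
The same perm-list convention applies to the output path: `[1, 0, 2]` moves `tgt_len` to the front before the heads are folded back into `embed_dim` via `reshape` (no `contiguous().view` pair needed here), and `[1, 0, 2, 3]` swaps `bsz` and `num_heads` on the 4-D weights. A hedged sketch with placeholder sizes:

```python
import paddle

bsz, num_heads, tgt_len, src_len, head_dim = 2, 3, 4, 5, 8
embed_dim = num_heads * head_dim

# Output projection input: bring tgt_len to the front, then fold
# (bsz * num_heads, head_dim) back into (bsz, embed_dim) with reshape.
attn = paddle.randn([bsz * num_heads, tgt_len, head_dim])
attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
assert attn.shape == [tgt_len, bsz, embed_dim]

# need_weights path: swap the first two axes of the 4-D weights and keep the
# trailing ones, i.e. torch's .transpose(1, 0) becomes perm [1, 0, 2, 3].
w = paddle.randn([bsz, num_heads, tgt_len, src_len])
w = w.transpose([1, 0, 2, 3])
assert w.shape == [num_heads, bsz, tgt_len, src_len]
```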
