From ba874db5dc94b92be2ca63dc6163e5e813a518a4 Mon Sep 17 00:00:00 2001
From: jiamingkong
Date: Tue, 30 May 2023 17:52:02 +0800
Subject: [PATCH] Fixed the transpose usages ignored before

---
 paddlespeech/s2t/models/wavlm/modules/modules.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/paddlespeech/s2t/models/wavlm/modules/modules.py b/paddlespeech/s2t/models/wavlm/modules/modules.py
index 5ef42e60..f14e4016 100644
--- a/paddlespeech/s2t/models/wavlm/modules/modules.py
+++ b/paddlespeech/s2t/models/wavlm/modules/modules.py
@@ -555,19 +555,19 @@ class MultiheadAttention(nn.Layer):
         q = (
             q.contiguous()
             .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
-            .transpose(0, 1)
+            .transpose([1, 0, 2])
         )
         if k is not None:
             k = (
                 k.contiguous()
                 .view(-1, bsz * self.num_heads, self.k_head_dim)
-                .transpose(0, 1)
+                .transpose([1, 0, 2])
             )
         if v is not None:
             v = (
                 v.contiguous()
                 .view(-1, bsz * self.num_heads, self.head_dim)
-                .transpose(0, 1)
+                .transpose([1, 0, 2])
             )
 
         if saved_state is not None:
@@ -643,7 +643,8 @@ class MultiheadAttention(nn.Layer):
                 )
 
-        attn_weights = paddle.bmm(q, k.transpose(1, 2))
+        attn_weights = paddle.matmul(q, k.transpose([0, 2, 1]))
+
         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
 
         assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
@@ -687,13 +688,13 @@ class MultiheadAttention(nn.Layer):
         assert v is not None
         attn = paddle.bmm(attn_probs, v)
         assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
             attn_weights = attn_weights_float.view(
                 bsz, self.num_heads, tgt_len, src_len
-            ).transpose(1, 0)
+            ).transpose([1, 0, 2, 3])
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
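
Note (not part of the patch): a minimal sketch of the API difference behind these changes. PyTorch's Tensor.transpose(0, 1) swaps exactly two axes, while Paddle's Tensor.transpose expects a permutation of all axes, so the one-to-one replacement for a 3-D tensor is transpose([1, 0, 2]). The shape below is arbitrary and only for illustration.

    import paddle

    # arbitrary (tgt_len, bsz * num_heads, head_dim) shape for illustration
    x = paddle.randn([4, 8, 16])

    # equivalent to torch's x.transpose(0, 1): swap the first two axes,
    # keep the last axis in place
    y = x.transpose([1, 0, 2])
    print(y.shape)  # [8, 4, 16]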