From f9e3eaa024218a5310c24bd504d4468826867bbd Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 26 Sep 2022 11:55:26 +0000
Subject: [PATCH] transpose in matmul

---
 paddlespeech/s2t/modules/attention.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 2d236743a..c02de15e8 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -188,8 +188,9 @@ class MultiHeadedAttention(nn.Layer):
         # non-trivial to calculate `next_cache_start` here.
         new_cache = paddle.concat((k, v), axis=-1)
 
-        scores = paddle.matmul(q,
-                               k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
+        # scores = paddle.matmul(q,
+        #                        k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
+        scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
 
         return self.forward_attention(v, scores, mask), new_cache
 
@@ -309,11 +310,13 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch, head, time1, time2)
-        matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
+        # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
+        matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
 
         # compute matrix b and matrix d
         # (batch, head, time1, time2)
-        matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
+        # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
+        matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
         # Remove rel_shift since it is useless in speech recognition,
         # and it requires special attention for streaming.
         # matrix_bd = self.rel_shift(matrix_bd)
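
Note (not part of the commit): a minimal sketch for sanity-checking the change. It relies on paddle.matmul's transpose_y=True transposing the last two dimensions of the right operand, so it should reproduce the explicit k.transpose([0, 1, 3, 2]) used before. All tensor sizes below are made-up placeholders, not values from the PaddleSpeech model.

    import math
    import paddle

    # hypothetical shapes: (batch, head, time, d_k)
    batch, head, time1, time2, d_k = 2, 4, 5, 7, 16
    q = paddle.randn([batch, head, time1, d_k])
    k = paddle.randn([batch, head, time2, d_k])

    # old formulation: explicitly transpose the last two dims of k
    old_scores = paddle.matmul(q, k.transpose([0, 1, 3, 2])) / math.sqrt(d_k)
    # new formulation: let matmul handle the transpose of k
    new_scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(d_k)

    # expected: True (both score matrices should match numerically)
    print(paddle.allclose(old_scores, new_scores))

Besides being shorter, transpose_y=True lets the matmul kernel handle the transpose rather than building an intermediate transposed tensor, which is presumably the motivation for rewriting the attention score computation this way.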