transpose in matmul

pull/2425/head
Hui Zhang 2 years ago
parent 3d7ca93861
commit f9e3eaa024

@ -188,8 +188,9 @@ class MultiHeadedAttention(nn.Layer):
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
scores = paddle.matmul(q,
k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
# scores = paddle.matmul(q,
# k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
@ -309,11 +310,13 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
# matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
# matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)

Loading…
Cancel
Save