elimiate attn transpose

pull/2425/head
Hui Zhang 3 years ago
parent f9e3eaa024
commit 46088c0a16

@ -271,7 +271,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
@ -302,9 +302,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
# q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
# q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
# compute attention score
# first compute matrix a and matrix c

Loading…
Cancel
Save