eliminate attn transpose

pull/2425/head
Hui Zhang 3 years ago
parent f9e3eaa024
commit 46088c0a16

@@ -271,7 +271,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
         # when export onnx model, for 1st chunk, we feed
         # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
@@ -302,9 +302,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
+        # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
+        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
         # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
+        # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
+        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
         # compute attention score
         # first compute matrix a and matrix c
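
Why this is safe to do: a minimal standalone sketch (not part of the patch) of the equivalence this commit relies on. It assumes q comes out of forward_qkv already shaped (batch, head, time1, d_k) and uses a plain tensor pos_bias_u in place of self.pos_bias_u; shapes are illustrative.

import paddle

batch, head, time1, d_k = 2, 4, 5, 8
q = paddle.randn([batch, head, time1, d_k])   # stands in for forward_qkv output
pos_bias_u = paddle.randn([head, d_k])        # stands in for self.pos_bias_u

# old path: transpose to (batch, time1, head, d_k), add the bias, transpose back
old = (q.transpose([0, 2, 1, 3]) + pos_bias_u).transpose([0, 2, 1, 3])

# new path: unsqueeze the bias to (head, 1, d_k) so it broadcasts over the
# time axis of (batch, head, time1, d_k); no transposes needed
new = q + pos_bias_u.unsqueeze(1)

print(paddle.allclose(old, new).item())       # True

Because the bias is constant over time1, broadcasting it along that axis gives the same q_with_bias_u / q_with_bias_v tensors while removing two transposes from the forward pass, which also simplifies ONNX export of the streaming model.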
