@@ -271,7 +271,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)

         # when export onnx model, for 1st chunk, we feed
         # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
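Per the comment on the removed line, the transpose reordered `q` into (batch, time1, head, d_k); in other words, `forward_qkv` already yields (batch, head, time1, d_k), and with the transpose commented out `q` keeps that layout for the bias additions in the next hunk. A minimal shape sketch, assuming Paddle is installed; the sizes below are illustrative, not values from the patch:

import paddle

batch, head, time1, d_k = 2, 4, 8, 16
q = paddle.randn([batch, head, time1, d_k])   # layout returned by forward_qkv
print(q.shape)                                # [2, 4, 8, 16] -> (batch, head, time1, d_k)
print(q.transpose([0, 2, 1, 3]).shape)        # [2, 8, 4, 16] -> the old (batch, time1, head, d_k)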
@@ -302,9 +302,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)

         # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
+        # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
+        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
         # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
+        # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
+        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)

         # compute attention score
         # first compute matrix a and matrix c
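The two new assignments are numerically equivalent to the removed transpose/add/transpose path: `pos_bias_u` and `pos_bias_v` are (head, d_k) parameters, and `unsqueeze(1)` makes them (head, 1, d_k) so they broadcast over time1 while `q` stays in the (batch, head, time1, d_k) layout, saving two transposes per call. A minimal check, assuming Paddle is installed; `pos_bias` and the sizes are illustrative stand-ins for `self.pos_bias_u` / `self.pos_bias_v`, not values from the patch:

import paddle

batch, head, time1, d_k = 2, 4, 8, 16
q = paddle.randn([batch, head, time1, d_k])   # layout returned by forward_qkv
pos_bias = paddle.randn([head, d_k])          # stand-in for pos_bias_u / pos_bias_v

old = (q.transpose([0, 2, 1, 3]) + pos_bias).transpose([0, 2, 1, 3])  # removed path
new = q + pos_bias.unsqueeze(1)               # added path, broadcasts over time1
print(paddle.allclose(old, new))              # expect True

Because the bias-added tensors are identical, the attention score computation that follows is unchanged.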