|
|
|
@ -70,10 +70,11 @@ class MultiHeadedAttention(nn.Layer):
|
|
|
|
|
paddle.Tensor: Transformed value tensor, size
|
|
|
|
|
(#batch, n_head, time2, d_k).
|
|
|
|
|
"""
|
|
|
|
|
n_batch = query.size(0)
|
|
|
|
|
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
|
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
|
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
|
# n_batch = query.size(0)
|
|
|
|
|
n_batch = query.shape[0]
|
|
|
|
|
q = self.linear_q(query).reshape([n_batch, -1, self.h, self.d_k])
|
|
|
|
|
k = self.linear_k(key).reshape([n_batch, -1, self.h, self.d_k])
|
|
|
|
|
v = self.linear_v(value).reshape([n_batch, -1, self.h, self.d_k])
|
|
|
|
|
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
|
|
|
|
|
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
|
|
|
|
|
v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
|
|
|
|
@ -96,7 +97,8 @@ class MultiHeadedAttention(nn.Layer):
|
|
|
|
|
paddle.Tensor: Transformed value weighted
|
|
|
|
|
by the attention score, (#batch, time1, d_model).
|
|
|
|
|
"""
|
|
|
|
|
n_batch = value.size(0)
|
|
|
|
|
# n_batch = value.size(0)
|
|
|
|
|
n_batch = value.shape[0]
|
|
|
|
|
if mask is not None:
|
|
|
|
|
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
|
|
|
|
scores = scores.masked_fill(mask, -float('inf'))
|
|
|
|
@ -205,8 +207,10 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|
|
|
|
q, k, v = self.forward_qkv(query, key, value)
|
|
|
|
|
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
|
|
|
|
|
|
|
|
|
|
n_batch_pos = pos_emb.size(0)
|
|
|
|
|
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
|
|
|
|
#n_batch_pos = pos_emb.size(0)
|
|
|
|
|
n_batch_pos = pos_emb.shape[0]
|
|
|
|
|
# p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
|
|
|
|
p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
|
|
|
|
|
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
|
|
|
|
|
|
|
|
|
|
# (batch, head, time1, d_k)
|
|
|
|
|