fix develop bug function:view to reshape (#3633)

pull/3711/head
luyao-cv 7 months ago committed by GitHub
parent a1f9339181
commit f2416ff365
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -79,9 +79,9 @@ class MultiHeadedAttention(nn.Layer):
"""
n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = self.linear_q(query).reshape([n_batch, -1, self.h, self.d_k])
k = self.linear_k(key).reshape([n_batch, -1, self.h, self.d_k])
v = self.linear_v(value).reshape([n_batch, -1, self.h, self.d_k])
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
@ -129,8 +129,8 @@ class MultiHeadedAttention(nn.Layer):
p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h *
self.d_k) # (batch, time1, d_model)
x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
self.d_k]) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
@ -349,7 +349,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
new_cache = paddle.concat((k, v), axis=-1)
n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)

Loading…
Cancel
Save