fix attn cannot train

pull/2502/head
Hui Zhang 3 years ago
parent 1f4f98b171
commit 9277fcb8a8

@ -60,6 +60,7 @@ class MultiHeadedAttention(nn.Layer):
    super()._build_once(*args, **kwargs)
    # if self.self_att:
    #     self.linear_kv = Linear(self.n_feat, self.n_feat*2)
    if not self.training:
        self.weight = paddle.concat(
            [self.linear_k.weight, self.linear_v.weight], axis=-1)
        self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
@ -86,8 +87,10 @@ class MultiHeadedAttention(nn.Layer):
    n_batch = query.shape[0]
    q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
    if self.training:
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
    else:
        k, v = F.linear(key, self.weight, self.bias).view(
            n_batch, -1, 2 * self.h, self.d_k).split(
                2, axis=2)

Loading…
Cancel
Save