|
|
|
@ -101,10 +101,7 @@ class MultiHeadedAttention(nn.Layer):
|
|
|
|
|
if self.train_stage:
|
|
|
|
|
self.build_kv()
|
|
|
|
|
self.train_stage = False
|
|
|
|
|
weight = paddle.concat(
|
|
|
|
|
[self.linear_k.weight, self.linear_v.weight], axis=-1)
|
|
|
|
|
bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
|
|
|
|
|
k, v = F.linear(key, weight, bias).view(n_batch, -1, 2 * self.h,
|
|
|
|
|
k, v = F.linear(key, self.weight, self.bias).view(n_batch, -1, 2 * self.h,
|
|
|
|
|
self.d_k).split(
|
|
|
|
|
2, axis=2)
|
|
|
|
|
|
|
|
|
|