fix attention.py validation bug

pull/2543/head
tianhao zhang 3 years ago
parent 62c1b6c6ee
commit 125932d4f3

@@ -3,10 +3,10 @@
 . path.sh || exit 1;
 set -e
-gpus=0,1,2,3,4,5,6,7
-stage=0
+gpus=1
+stage=3
 stop_stage=100
-conf_path=conf/conformer.yaml
+conf_path=conf/bitransformer_decoder_conformer.yaml
 ips= #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
 decode_conf_path=conf/tuning/decode.yaml
 average_checkpoint=true

@@ -101,10 +101,7 @@ class MultiHeadedAttention(nn.Layer):
         if self.train_stage:
             self.build_kv()
             self.train_stage = False
-        weight = paddle.concat(
-            [self.linear_k.weight, self.linear_v.weight], axis=-1)
-        bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
-        k, v = F.linear(key, weight, bias).view(n_batch, -1, 2 * self.h,
+        k, v = F.linear(key, self.weight, self.bias).view(n_batch, -1, 2 * self.h,
                                                           self.d_k).split(
                                                               2, axis=2)
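
For readers outside the diff context, below is a minimal, runnable sketch of the fused key/value projection this change relies on. The class name PackedKV, its constructor signature, and the caching of the fused parameters as self.weight / self.bias inside build_kv() are assumptions inferred from how the patched forward() uses them, not the repository's actual code.

    # Hedged sketch: illustrative names, not PaddleSpeech's real class.
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F


    class PackedKV(nn.Layer):
        def __init__(self, n_head: int, n_feat: int):
            super().__init__()
            assert n_feat % n_head == 0
            self.h = n_head
            self.d_k = n_feat // n_head
            self.linear_k = nn.Linear(n_feat, n_feat)
            self.linear_v = nn.Linear(n_feat, n_feat)
            self.weight = None  # fused K/V weight, built lazily
            self.bias = None

        def build_kv(self):
            # paddle.nn.Linear stores weight as [in_features, out_features],
            # so concatenating along the last axis stacks the K and V output
            # columns into one [n_feat, 2 * n_feat] matrix.
            self.weight = paddle.concat(
                [self.linear_k.weight, self.linear_v.weight], axis=-1)
            self.bias = paddle.concat(
                [self.linear_k.bias, self.linear_v.bias])

        def forward(self, key):
            if self.weight is None:
                self.build_kv()
            n_batch = key.shape[0]
            # One GEMM yields both projections: [n_batch, T, 2 * n_feat].
            kv = F.linear(key, self.weight, self.bias)
            # The first h "heads" hold K, the second h hold V; split them
            # apart along the head axis.
            kv = paddle.reshape(kv, [n_batch, -1, 2 * self.h, self.d_k])
            k, v = paddle.split(kv, 2, axis=2)
            return k, v


    if __name__ == "__main__":
        layer = PackedKV(n_head=4, n_feat=256)
        k, v = layer(paddle.randn([2, 50, 256]))
        print(k.shape, v.shape)  # [2, 50, 4, 64] [2, 50, 4, 64]

Fusing K and V into one wider matmul saves a kernel launch; what the commit changes is that the fused weight and bias are cached once via build_kv() instead of being re-concatenated on every call, as the removed lines did.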
