diff --git a/examples/wenetspeech/asr1/run.sh b/examples/wenetspeech/asr1/run.sh
index ddce0a9c8..772842b69 100644
--- a/examples/wenetspeech/asr1/run.sh
+++ b/examples/wenetspeech/asr1/run.sh
@@ -3,10 +3,10 @@
 . path.sh || exit 1;
 set -e
 
-gpus=0,1,2,3,4,5,6,7
-stage=0
+gpus=1
+stage=3
 stop_stage=100
-conf_path=conf/conformer.yaml
+conf_path=conf/bitransformer_decoder_conformer.yaml
 ips=            #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
 decode_conf_path=conf/tuning/decode.yaml
 average_checkpoint=true
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 12ac01bdb..dd7019fb6 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -101,10 +101,7 @@ class MultiHeadedAttention(nn.Layer):
 
         if self.train_stage:
             self.build_kv()
             self.train_stage = False
-        weight = paddle.concat(
-            [self.linear_k.weight, self.linear_v.weight], axis=-1)
-        bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
-        k, v = F.linear(key, weight, bias).view(n_batch, -1, 2 * self.h,
+        k, v = F.linear(key, self.weight, self.bias).view(n_batch, -1, 2 * self.h,
                                                 self.d_k).split(
                                                     2, axis=2)
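
The attention.py hunk replaces a per-call fusion of the K and V projections with parameters cached ahead of time: the removed lines rebuilt the concatenated weight and bias on every forward pass, while the new code reuses self.weight and self.bias, which build_kv() is expected to populate once before self.train_stage flips to False. Below is a minimal standalone sketch of that equivalence, assuming build_kv() simply concatenates the linear_k / linear_v parameters; the shapes and names outside the diff (h, d_k, the reference check) are illustrative, not the repository's exact API. The sketch uses .reshape() where the diff uses the repo's .view() alias.

    import paddle
    import paddle.nn.functional as F

    h, d_k = 4, 16
    d_model = h * d_k
    linear_k = paddle.nn.Linear(d_model, d_model)
    linear_v = paddle.nn.Linear(d_model, d_model)

    # One-time fusion (what build_kv() presumably caches as self.weight / self.bias).
    # paddle.nn.Linear stores weight as (in_features, out_features), so the two
    # projections concatenate along the last axis into one (d_model, 2*d_model) matrix.
    weight = paddle.concat([linear_k.weight, linear_v.weight], axis=-1)
    bias = paddle.concat([linear_k.bias, linear_v.bias])

    key = paddle.randn([2, 10, d_model])  # (batch, time, d_model)
    n_batch = key.shape[0]

    # Single matmul, then split the 2*h "heads" back into k and v.
    k, v = (F.linear(key, weight, bias)
            .reshape([n_batch, -1, 2 * h, d_k])
            .split(2, axis=2))

    # Matches the two separate projections.
    k_ref = linear_k(key).reshape([n_batch, -1, h, d_k])
    v_ref = linear_v(key).reshape([n_batch, -1, h, d_k])
    assert paddle.allclose(k, k_ref, atol=1e-5)
    assert paddle.allclose(v, v_ref, atol=1e-5)

One design point worth noting in review: because the fusion runs only once (guarded by self.train_stage), any later updates to linear_k / linear_v would not be reflected in the cached self.weight / self.bias, so this path appears intended for inference rather than continued training.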