From 9277fcb8a85d7a064f90eebdc7f9ba547abec13e Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Sat, 8 Oct 2022 08:15:51 +0000
Subject: [PATCH] fix attn can not train

---
 paddlespeech/s2t/modules/attention.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index d9ee763f1..128f87c07 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -60,9 +60,10 @@ class MultiHeadedAttention(nn.Layer):
         super()._build_once(*args, **kwargs)
         # if self.self_att:
         #     self.linear_kv = Linear(self.n_feat, self.n_feat*2)
-        self.weight = paddle.concat(
-            [self.linear_k.weight, self.linear_v.weight], axis=-1)
-        self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
+        if not self.training:
+            self.weight = paddle.concat(
+                [self.linear_k.weight, self.linear_v.weight], axis=-1)
+            self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
         self._built = True
 
     def forward_qkv(self,
@@ -86,11 +87,13 @@
         n_batch = query.shape[0]
 
         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        # k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        # v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        k, v = F.linear(key, self.weight, self.bias).view(
-            n_batch, -1, 2 * self.h, self.d_k).split(
-                2, axis=2)
+        if self.training:
+            k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+            v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        else:
+            k, v = F.linear(key, self.weight, self.bias).view(
+                n_batch, -1, 2 * self.h, self.d_k).split(
+                    2, axis=2)
         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
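
Background sketch (not part of the patch): the fix gates the fused key/value
projection behind self.training. The fused self.weight/self.bias built in
_build_once are new tensors produced by paddle.concat and are therefore
detached copies of the linear_k/linear_v parameters; if the fused F.linear
path were also used during training, the k/v projections would never receive
gradient updates, which is presumably why the attention could not train. The
minimal standalone example below illustrates the same pattern: separate
linears while training, one fused matmul at inference. The FusedKVProjection
class and its eval() hook are illustrative assumptions for this sketch, not
paddlespeech APIs, and the fused path computes both k and v from `key`,
mirroring the self-attention case in the patch where key and value are the
same tensor.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class FusedKVProjection(nn.Layer):
    """Separate k/v linears for training; a fused matmul for inference.

    Illustrative sketch only, not a paddlespeech class.
    """

    def __init__(self, n_feat: int):
        super().__init__()
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self._fused_weight = None
        self._fused_bias = None

    def eval(self):
        # Build the fused copy when switching to inference. Doing this at
        # build time would freeze k/v: the concatenated tensors are copies
        # that the optimizer never updates.
        super().eval()
        self._fused_weight = paddle.concat(
            [self.linear_k.weight, self.linear_v.weight], axis=-1)
        self._fused_bias = paddle.concat(
            [self.linear_k.bias, self.linear_v.bias])
        return self

    def forward(self, key, value):
        if self.training:
            # Trainable path: gradients reach linear_k/linear_v parameters.
            return self.linear_k(key), self.linear_v(value)
        # Inference path: one matmul produces both projections, then split.
        kv = F.linear(key, self._fused_weight, self._fused_bias)
        k, v = paddle.split(kv, 2, axis=-1)
        return k, v


if __name__ == "__main__":
    layer = FusedKVProjection(8)
    x = paddle.randn([2, 5, 8])
    k_train, v_train = layer(x, x)         # training path (layer.training is True)
    k_infer, v_infer = layer.eval()(x, x)  # fused inference path
    print(paddle.allclose(k_train, k_infer), paddle.allclose(v_train, v_infer))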