From 62c1b6c6eee68670ebba5668b4941871cf134e9d Mon Sep 17 00:00:00 2001
From: tianhao zhang <15600919271@163.com>
Date: Tue, 18 Oct 2022 07:17:45 +0000
Subject: [PATCH] fix attention.py validation bug

---
 paddlespeech/s2t/modules/attention.py | 37 ++++++++++++++++++---------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 128f87c07..12ac01bdb 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -56,15 +56,17 @@ class MultiHeadedAttention(nn.Layer):
         self.linear_out = Linear(n_feat, n_feat)
         self.dropout = nn.Dropout(p=dropout_rate)
 
-    def _build_once(self, *args, **kwargs):
-        super()._build_once(*args, **kwargs)
-        # if self.self_att:
-        #     self.linear_kv = Linear(self.n_feat, self.n_feat*2)
-        if not self.training:
-            self.weight = paddle.concat(
-                [self.linear_k.weight, self.linear_v.weight], axis=-1)
-            self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
-        self._built = True
+        if self.training:
+            self.train_stage = True
+        else:
+            self.build_kv()
+            self.train_stage = False
+
+        # self._built = True
+    def build_kv(self):
+        self.weight = paddle.concat(
+            [self.linear_k.weight, self.linear_v.weight], axis=-1)
+        self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
 
     def forward_qkv(self,
                     query: paddle.Tensor,
@@ -88,12 +90,23 @@ class MultiHeadedAttention(nn.Layer):
         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
 
         if self.training:
+            if not self.train_stage:
+                del self.weight
+                del self.bias
+                self.train_stage = True
             k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
             v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+
         else:
-            k, v = F.linear(key, self.weight, self.bias).view(
-                n_batch, -1, 2 * self.h, self.d_k).split(
-                    2, axis=2)
+            if self.train_stage:
+                self.build_kv()
+                self.train_stage = False
+            weight = paddle.concat(
+                [self.linear_k.weight, self.linear_v.weight], axis=-1)
+            bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
+            k, v = F.linear(key, weight, bias).view(n_batch, -1, 2 * self.h,
+                                                    self.d_k).split(
+                                                        2, axis=2)
 
         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
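
Note on the technique the patch relies on: paddle.nn.Linear stores its weight as
[in_features, out_features], so concatenating the weights of linear_k and linear_v along
the last axis lets a single F.linear call produce both K and V at eval time, which are then
split apart. The train_stage bookkeeping above rebuilds or discards the fused copies when
the layer switches between training and eval, since the concatenated tensors would
otherwise go stale while linear_k/linear_v keep updating. The standalone sketch below is
not part of the patch; it only illustrates that fused K/V projection. The class name
FusedKVDemo and the shapes in the equivalence check are assumptions for demonstration.

# Minimal sketch (illustration only, not from the patch): fuse the K and V
# projections into one matmul at eval time and verify it matches the
# separate layers.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class FusedKVDemo(nn.Layer):
    def __init__(self, n_feat: int):
        super().__init__()
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)

    def forward(self, x: paddle.Tensor):
        if self.training:
            # Training: keep K and V as separate layers so gradients flow
            # through the original parameters.
            k = self.linear_k(x)
            v = self.linear_v(x)
        else:
            # Eval: concatenate the current weights/biases on the fly and do
            # one matmul, then split the result back into K and V.
            # Paddle's Linear weight is [in_features, out_features], so the
            # concat is along the last axis.
            weight = paddle.concat(
                [self.linear_k.weight, self.linear_v.weight], axis=-1)
            bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
            kv = F.linear(x, weight, bias)
            k, v = paddle.split(kv, 2, axis=-1)
        return k, v


if __name__ == "__main__":
    # Equivalence check: the fused eval path should match the separate
    # projections (shapes here are arbitrary demo values).
    layer = FusedKVDemo(8)
    x = paddle.randn([2, 4, 8])
    layer.eval()
    k_fused, v_fused = layer(x)
    print(paddle.allclose(k_fused, layer.linear_k(x)).item(),
          paddle.allclose(v_fused, layer.linear_v(x)).item())

Recomputing weight and bias on the fly inside the eval branch, as the patched forward_qkv
does, trades a small amount of per-call concat work for never serving stale fused
parameters during validation, which matches the bug named in the subject line.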