From 1ea828c30eb9bd35becb3efe5cf55940144bb7be Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Tue, 18 Oct 2022 09:06:34 +0000 Subject: [PATCH] fix attention val bug --- paddlespeech/s2t/modules/attention.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 128f87c0..d9568dcc 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -19,7 +19,6 @@ from typing import Tuple import paddle from paddle import nn -from paddle.nn import functional as F from paddle.nn import initializer as I from paddlespeech.s2t.modules.align import Linear @@ -56,16 +55,6 @@ class MultiHeadedAttention(nn.Layer): self.linear_out = Linear(n_feat, n_feat) self.dropout = nn.Dropout(p=dropout_rate) - def _build_once(self, *args, **kwargs): - super()._build_once(*args, **kwargs) - # if self.self_att: - # self.linear_kv = Linear(self.n_feat, self.n_feat*2) - if not self.training: - self.weight = paddle.concat( - [self.linear_k.weight, self.linear_v.weight], axis=-1) - self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias]) - self._built = True - def forward_qkv(self, query: paddle.Tensor, key: paddle.Tensor, @@ -87,13 +76,8 @@ class MultiHeadedAttention(nn.Layer): n_batch = query.shape[0] q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - if self.training: - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - else: - k, v = F.linear(key, self.weight, self.bias).view( - n_batch, -1, 2 * self.h, self.d_k).split( - 2, axis=2) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)