|
|
@ -19,7 +19,6 @@ from typing import Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import nn
|
|
|
|
from paddle import nn
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.modules.align import Linear
|
|
|
|
from paddlespeech.s2t.modules.align import Linear
|
|
|
@ -56,16 +55,6 @@ class MultiHeadedAttention(nn.Layer):
|
|
|
|
self.linear_out = Linear(n_feat, n_feat)
|
|
|
|
self.linear_out = Linear(n_feat, n_feat)
|
|
|
|
self.dropout = nn.Dropout(p=dropout_rate)
|
|
|
|
self.dropout = nn.Dropout(p=dropout_rate)
|
|
|
|
|
|
|
|
|
|
|
|
def _build_once(self, *args, **kwargs):
|
|
|
|
|
|
|
|
super()._build_once(*args, **kwargs)
|
|
|
|
|
|
|
|
# if self.self_att:
|
|
|
|
|
|
|
|
# self.linear_kv = Linear(self.n_feat, self.n_feat*2)
|
|
|
|
|
|
|
|
if not self.training:
|
|
|
|
|
|
|
|
self.weight = paddle.concat(
|
|
|
|
|
|
|
|
[self.linear_k.weight, self.linear_v.weight], axis=-1)
|
|
|
|
|
|
|
|
self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
|
|
|
|
|
|
|
|
self._built = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_qkv(self,
|
|
|
|
def forward_qkv(self,
|
|
|
|
query: paddle.Tensor,
|
|
|
|
query: paddle.Tensor,
|
|
|
|
key: paddle.Tensor,
|
|
|
|
key: paddle.Tensor,
|
|
|
@ -87,13 +76,8 @@ class MultiHeadedAttention(nn.Layer):
|
|
|
|
n_batch = query.shape[0]
|
|
|
|
n_batch = query.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
if self.training:
|
|
|
|
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
k, v = F.linear(key, self.weight, self.bias).view(
|
|
|
|
|
|
|
|
n_batch, -1, 2 * self.h, self.d_k).split(
|
|
|
|
|
|
|
|
2, axis=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
|
|
|
|
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
|
|
|
|
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
|
|
|
|
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
|
|
|
|