@@ -20,6 +20,7 @@ from typing import Tuple
 import paddle
 from paddle import nn
 from paddle.nn import initializer as I
+from paddle.nn import functional as F
 
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log
@@ -45,6 +46,7 @@ class MultiHeadedAttention(nn.Layer):
         """
         super().__init__()
         assert n_feat % n_head == 0
+        self.n_feat = n_feat
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
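As a quick sanity check of the head arithmetic above (example values assumed, not taken from the patch):

n_feat, n_head = 256, 4      # assumed example config
assert n_feat % n_head == 0  # required by the constructor
d_k = n_feat // n_head       # 64 features per head; d_v is taken equal to d_k
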
@@ -54,6 +56,15 @@ class MultiHeadedAttention(nn.Layer):
         self.linear_out = Linear(n_feat, n_feat)
         self.dropout = nn.Dropout(p=dropout_rate)
 
+
+    def _build_once(self, *args, **kwargs):
+        super()._build_once(*args, **kwargs)
+        # if self.self_att:
+        #     self.linear_kv = Linear(self.n_feat, self.n_feat*2)
+        self.weight = paddle.concat([self.linear_k.weight, self.linear_v.weight], axis=-1)
+        self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
+        self._built = True
+
     def forward_qkv(self,
                     query: paddle.Tensor,
                     key: paddle.Tensor,
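The fusion relies on Paddle's nn.Linear storing its weight as (in_features, out_features), so concatenating the k/v weights along the last axis stacks the two output projections side by side. A minimal standalone sketch of the layout (sizes assumed, not from the patch); note that paddle.concat returns new tensors, so the fused weight/bias are copies of the k/v parameters as they exist when _build_once first runs before the initial forward call.

import paddle
from paddle import nn

n_feat = 8
linear_k, linear_v = nn.Linear(n_feat, n_feat), nn.Linear(n_feat, n_feat)
weight = paddle.concat([linear_k.weight, linear_v.weight], axis=-1)  # (n_feat, 2*n_feat)
bias = paddle.concat([linear_k.bias, linear_v.bias])                 # (2*n_feat,)
print(weight.shape, bias.shape)  # [8, 16] [16]
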
@@ -73,9 +84,12 @@ class MultiHeadedAttention(nn.Layer):
                 (#batch, n_head, time2, d_k).
         """
         n_batch = query.shape[0]
+
         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        # k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        # v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        k, v = F.linear(key, self.weight, self.bias).view(n_batch, -1, 2 * self.h, self.d_k).split(2, axis=2)
+
         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
         v = v.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
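A standalone sketch (toy sizes assumed, not from the patch) of the fused projection path: one F.linear over the concatenated parameters, a reshape to (batch, time2, 2*h, d_k), and a split of the doubled head axis. paddle.reshape/paddle.split are used here in place of the Tensor .view/.split calls in the patch.

import paddle
from paddle import nn
from paddle.nn import functional as F

n_batch, time2, h, d_k = 2, 5, 4, 16
n_feat = h * d_k
linear_k, linear_v = nn.Linear(n_feat, n_feat), nn.Linear(n_feat, n_feat)
weight = paddle.concat([linear_k.weight, linear_v.weight], axis=-1)
bias = paddle.concat([linear_k.bias, linear_v.bias])

key = paddle.randn([n_batch, time2, n_feat])
kv = F.linear(key, weight, bias).reshape([n_batch, -1, 2 * h, d_k])
k, v = paddle.split(kv, 2, axis=2)      # each (batch, time2, h, d_k)
k = k.transpose([0, 2, 1, 3])           # (batch, head, time2, d_k)
v = v.transpose([0, 2, 1, 3])

k_ref = linear_k(key).reshape([n_batch, -1, h, d_k]).transpose([0, 2, 1, 3])
print(paddle.allclose(k, k_ref).item())  # True, up to float error

The split along the doubled head axis works because n_feat = h * d_k: after the reshape, the first h rows of the feature axis hold the key projection and the last h rows hold the value projection.
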
@@ -108,10 +122,10 @@ class MultiHeadedAttention(nn.Layer):
         # When will `if mask.size(2) > 0` be False?
         # 1. onnx(16/-1, -1/-1, 16/0)
         # 2. jit (16/-1, -1/-1, 16/0, 16/4)
-        if paddle.shape(mask)[2] > 0:  # time2 > 0
+        if mask.shape[2] > 0:  # time2 > 0
             mask = mask.unsqueeze(1).equal(0)  # (batch, 1, *, time2)
             # for last chunk, time2 might be larger than scores.size(-1)
-            mask = mask[:, :, :, :paddle.shape(scores)[-1]]
+            mask = mask[:, :, :, :scores.shape[-1]]
             scores = scores.masked_fill(mask, -float('inf'))
             attn = paddle.softmax(
                 scores, axis=-1).masked_fill(mask,
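For reference, a dygraph sketch of the masking sequence (toy sizes assumed; paddle.where stands in for the masked_fill helper used in the repo): padded positions are pushed to -inf before the softmax and zeroed again afterwards, so even a fully masked row comes out as zeros rather than NaN.

import paddle
from paddle.nn import functional as F

scores = paddle.randn([1, 2, 3, 4])            # (batch, head, time1, time2)
pad = paddle.to_tensor([[[1, 1, 1, 0]]]) == 0  # (batch, 1, time2): True where padded
mask = pad.unsqueeze(1)                        # (batch, 1, 1, time2)
mask = mask[:, :, :, :scores.shape[-1]]        # trim in case time2 > scores width
mask = paddle.expand(mask, scores.shape)       # make shapes explicit for where()

scores = paddle.where(mask, paddle.full_like(scores, -float('inf')), scores)
attn = F.softmax(scores, axis=-1)
attn = paddle.where(mask, paddle.zeros_like(attn), attn)
print(attn[0, 0, 0])                           # weight on the padded frame is exactly 0
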
@@ -179,7 +193,7 @@ class MultiHeadedAttention(nn.Layer):
         # >>> torch.equal(b, c)  # True
         # >>> d = torch.split(a, 2, dim=-1)
         # >>> torch.equal(d[0], d[1])  # True
-        if paddle.shape(cache)[0] > 0:
+        if cache.shape[0] > 0:
             # last dim `d_k * 2` for (key, val)
             key_cache, value_cache = paddle.split(cache, 2, axis=-1)
             k = paddle.concat([key_cache, k], axis=2)
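A standalone sketch (toy sizes assumed, not from the patch) of the cache layout this branch expects: the cache carries concat(k, v) in its last axis, hence d_k * 2 per head, is split on read, and new frames are appended along the time axis before the halves are re-concatenated for the next chunk.

import paddle

batch, head, d_k = 1, 2, 4
cache = paddle.randn([batch, head, 6, d_k * 2])  # 6 cached frames of (key, val)
k_new = paddle.randn([batch, head, 3, d_k])      # 3 fresh frames
v_new = paddle.randn([batch, head, 3, d_k])

key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k_new], axis=2)    # (batch, head, 9, d_k)
v = paddle.concat([value_cache, v_new], axis=2)
new_cache = paddle.concat((k, v), axis=-1)       # carried over to the next chunk
print(new_cache.shape)                           # [1, 2, 9, 8]
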
@@ -288,7 +302,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         # >>> torch.equal(b, c)  # True
         # >>> d = torch.split(a, 2, dim=-1)
         # >>> torch.equal(d[0], d[1])  # True
-        if paddle.shape(cache)[0] > 0:
+        if cache.shape[0] > 0:
             # last dim `d_k * 2` for (key, val)
             key_cache, value_cache = paddle.split(cache, 2, axis=-1)
             k = paddle.concat([key_cache, k], axis=2)
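The one-line change repeated in the last three hunks swaps paddle.shape(x)[i] for x.shape[i]. A small dygraph-mode sketch of the difference (behaviour under jit/onnx export, which the original comments mention, is not covered here):

import paddle

x = paddle.randn([16, 1, 4])
print(type(x.shape[2]))          # <class 'int'>: a plain Python int in dygraph
print(type(paddle.shape(x)[2]))  # paddle.Tensor holding the same value
print(x.shape[2] > 0)            # plain Python bool comparison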