@@ -86,7 +86,8 @@ class MultiHeadedAttention(nn.Layer):
             self,
             value: paddle.Tensor,
             scores: paddle.Tensor,
-            mask: paddle.Tensor, ) -> paddle.Tensor:
+            mask: paddle.Tensor,  # paddle.ones([0, 0, 0], dtype=paddle.bool)
+    ) -> paddle.Tensor:
         """Compute attention context vector.
         Args:
             value (paddle.Tensor): Transformed value, size
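
The inline comment on `mask` documents the placeholder a caller is expected to pass when no mask applies. A minimal sketch of that convention, using the shapes from the docstring; the import path and layer sizes are illustrative assumptions, not part of the patch:

import paddle
from paddlespeech.s2t.modules.attention import MultiHeadedAttention  # import path assumed

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)  # sizes assumed
value = paddle.randn([1, 4, 10, 64])   # (#batch, n_head, time2, d_k)
scores = paddle.randn([1, 4, 10, 10])  # (#batch, n_head, time1, time2)
no_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)  # empty mask = "no mask", per the signature comment
context = attn.forward_attention(value, scores, no_mask)  # (#batch, time1, d_model)
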
@@ -126,13 +127,15 @@ class MultiHeadedAttention(nn.Layer):
         return self.linear_out(x)  # (batch, time1, d_model)
 
-    def forward(self,
-                query: paddle.Tensor,
-                key: paddle.Tensor,
-                value: paddle.Tensor,
-                mask: paddle.Tensor,
-                pos_emb: paddle.Tensor,
-                cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
+    def forward(
+            self,
+            query: paddle.Tensor,
+            key: paddle.Tensor,
+            value: paddle.Tensor,
+            mask: paddle.Tensor,  # paddle.ones([0,0,0], dtype=paddle.bool)
+            pos_emb: paddle.Tensor,  # paddle.empty([0])
+            cache: paddle.Tensor  # paddle.zeros([0,0,0,0])
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute scaled dot product attention.
 
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
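
With every argument now a concrete tensor, a plain self-attention call passes the empty placeholders named in the signature comments for `mask`, `pos_emb`, and `cache`. A sketch reusing the `attn` layer from the previous snippet (shapes assumed):

import paddle

x = paddle.randn([1, 10, 256])  # (batch, time1, n_feat); size assumed
out, new_cache = attn(
    x, x, x,
    mask=paddle.ones([0, 0, 0], dtype=paddle.bool),  # "no mask" placeholder
    pos_emb=paddle.empty([0]),                       # placeholder from the signature comment
    cache=paddle.zeros([0, 0, 0, 0]))                # no cached key/value yet
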
@@ -241,13 +244,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         return x
 
-    def forward(self,
-                query: paddle.Tensor,
-                key: paddle.Tensor,
-                value: paddle.Tensor,
-                mask: paddle.Tensor,
-                pos_emb: paddle.Tensor,
-                cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
+    def forward(
+            self,
+            query: paddle.Tensor,
+            key: paddle.Tensor,
+            value: paddle.Tensor,
+            mask: paddle.Tensor,  # paddle.ones([0,0,0], dtype=paddle.bool)
+            pos_emb: paddle.Tensor,  # paddle.empty([0])
+            cache: paddle.Tensor  # paddle.zeros([0,0,0,0])
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
 
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
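
The relative-position variant keeps the same placeholder convention for `mask` and `cache`, while `pos_emb` carries actual positional encodings. A sketch under the same assumptions; the constructor arguments and `pos_emb` shape are illustrative (in the real pipeline it comes from the positional-encoding layer):

import paddle
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention  # import path assumed

rel_attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)  # sizes assumed
x = paddle.randn([1, 10, 256])
pos_emb = paddle.randn([1, 10, 256])  # shape assumed: (1, time1, n_feat)
out, new_cache = rel_attn(
    x, x, x,
    mask=paddle.ones([0, 0, 0], dtype=paddle.bool),
    pos_emb=pos_emb,
    cache=paddle.zeros([0, 0, 0, 0]))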