@@ -84,9 +84,10 @@ class MultiHeadedAttention(nn.Layer):
         return q, k, v
 
     def forward_attention(self,
-                          value: paddle.Tensor,
-                          scores: paddle.Tensor,
-                          mask: Optional[paddle.Tensor]) -> paddle.Tensor:
+                          value: paddle.Tensor,
+                          scores: paddle.Tensor,
+                          mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                          ) -> paddle.Tensor:
         """Compute attention context vector.
         Args:
             value (paddle.Tensor): Transformed value, size
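
Note: replacing `Optional[paddle.Tensor]` with a zero-shaped boolean default is what keeps the exported graph static: ONNX cannot branch on `None`, but it can branch on a tensor's shape. A minimal sketch of the convention, assuming only the public paddle API:

    import paddle

    fake_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)  # "no mask" marker
    real_mask = paddle.ones([2, 1, 8], dtype=paddle.bool)  # (#batch, 1, time2)

    # downstream code branches on time2 instead of testing for None
    print(bool(paddle.shape(fake_mask)[2] > 0))  # False -> skip masking
    print(bool(paddle.shape(real_mask)[2] > 0))  # True  -> apply masking
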
@@ -94,14 +95,23 @@ class MultiHeadedAttention(nn.Layer):
             scores (paddle.Tensor): Attention score, size
                 (#batch, n_head, time1, time2).
             mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
-                (#batch, time1, time2).
+                (#batch, time1, time2); (0, 0, 0) means fake mask.
         Returns:
-            paddle.Tensor: Transformed value weighted
-                by the attention score, (#batch, time1, d_model).
+            paddle.Tensor: Transformed value (#batch, time1, d_model),
+                weighted by the attention score (#batch, time1, time2).
         """
         n_batch = value.shape[0]
-        if mask is not None:
+        # When is `paddle.shape(mask)[2] > 0` True?
+        #   1. During training.
+        #   2. For onnx export in 16/4 (chunk_size/history_size) mode, where the
+        #      real cache and real mask are fed for the 1st chunk.
+        # When is it False?
+        #   1. For onnx export in 16/-1, -1/-1 and 16/0 modes.
+        #   2. For jit export in 16/-1, -1/-1, 16/0 and 16/4 modes.
+        if paddle.shape(mask)[2] > 0:  # time2 > 0
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            # for the last chunk, time2 might be larger than scores.shape[-1]
+            mask = mask[:, :, :, :paddle.shape(scores)[-1]]
             scores = scores.masked_fill(mask, -float('inf'))
             attn = paddle.softmax(
                 scores, axis=-1).masked_fill(mask,
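
A quick numeric illustration of why the scores are filled with -inf before the softmax while the attention weights are filled with 0.0 after it: a row whose positions are all masked would otherwise softmax to NaN. A minimal sketch, emulating the repo's `masked_fill` helper with `paddle.where`:

    import paddle
    import paddle.nn.functional as F

    scores = paddle.to_tensor([[0.5, 0.2, -1.0],
                               [0.3, 0.3, 0.3]])   # (time1=2, time2=3)
    mask = paddle.to_tensor([[False, True, True],
                             [True, True, True]])  # row 2 is fully padded

    neg_inf = paddle.full_like(scores, -float('inf'))
    attn = F.softmax(paddle.where(mask, neg_inf, scores), axis=-1)
    # row 2 is now [nan, nan, nan]; the post-softmax fill repairs it
    attn = paddle.where(mask, paddle.zeros_like(attn), attn)
    print(attn)  # [[1., 0., 0.], [0., 0., 0.]]
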
@@ -121,21 +131,67 @@ class MultiHeadedAttention(nn.Layer):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                mask: Optional[paddle.Tensor]) -> paddle.Tensor:
+                mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor = paddle.empty([0]),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute scaled dot product attention.
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
                 (#batch, time1, time2).
+                1. When applying cross attention between decoder and encoder,
+                   the batch padding mask for the input is of shape (#batch, 1, T).
+                2. When applying self attention of the encoder,
+                   the mask is of shape (#batch, T, T).
+                3. When applying self attention of the decoder,
+                   the mask is of shape (#batch, L, L).
+                4. If different positions in the decoder see different blocks
+                   of the encoder, such as Mocha, the passed-in mask could be
+                   of shape (#batch, L, T). But there is no such case in current
+                   WeNet.
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`.
         Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`.
 
         """
         q, k, v = self.forward_qkv(query, key, value)
 
+        # When exporting an onnx model, for the 1st chunk we feed
+        # cache(1, head, 0, d_k * 2) (16/-1, -1/-1 and 16/0 modes)
+        # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        # In all onnx modes, `paddle.shape(cache)[0] > 0` is always True,
+        # so we always do the splitting and concatenation
+        # (this simplifies onnx export). Note that it is OK to
+        # concat & split zero-shaped tensors (see the code below).
+        # When exporting a jit model, for the 1st chunk we always feed
+        # cache(0, 0, 0, 0), since jit supports dynamic if-branches.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            # the last dim `d_k * 2` packs (key, value); note that
+            # paddle.split takes the number of sections, not the section size
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
 
         scores = paddle.matmul(q,
                                k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache
 
 
 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
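
The torch doctest in the comments carries over to paddle unchanged in spirit; a minimal check of the zero-shaped concat/split trick, assuming only the public paddle API (shapes are illustrative):

    import paddle

    a = paddle.ones([1, 2, 0, 4])           # empty cache: zero-length time axis
    b = paddle.ones([1, 2, 3, 4])
    c = paddle.concat([a, b], axis=2)       # concat with a zero-shaped tensor
    print(bool((b == c).all()))             # True: the concat is a no-op

    cache = paddle.ones([1, 2, 3, 8])       # last dim d_k * 2 packs (key, value)
    k, v = paddle.split(cache, 2, axis=-1)  # 2 = number of sections, not size
    print(k.shape, v.shape)                 # [1, 2, 3, 4] [1, 2, 3, 4]
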
@@ -192,23 +248,55 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                pos_emb: paddle.Tensor,
-                mask: Optional[paddle.Tensor]):
+                mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor = paddle.empty([0]),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
             key (paddle.Tensor): Key tensor (#batch, time2, size).
             value (paddle.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (paddle.Tensor): Positional embedding tensor
-                (#batch, time1, size).
             mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
+                (#batch, time1, time2); (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`.
         Returns:
             paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`.
         """
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
 
+        # When exporting an onnx model, for the 1st chunk we feed
+        # cache(1, head, 0, d_k * 2) (16/-1, -1/-1 and 16/0 modes)
+        # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        # In all onnx modes, `paddle.shape(cache)[0] > 0` is always True,
+        # so we always do the splitting and concatenation
+        # (this simplifies onnx export). Note that it is OK to
+        # concat & split zero-shaped tensors (see the code below).
+        # When exporting a jit model, for the 1st chunk we always feed
+        # cache(0, 0, 0, 0), since jit supports dynamic if-branches.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            # paddle.split takes the number of sections, not the section size
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
 
         n_batch_pos = pos_emb.shape[0]
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
@@ -234,4 +322,4 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         scores = (matrix_ac + matrix_bd) / math.sqrt(
             self.d_k)  # (batch, head, time1, time2)
 
-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache
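
To make the cache contract concrete, here is a hedged sketch of how a caller such as encoder.forward_chunk would thread the cache through successive chunks (the head count, feature size and chunk length are illustrative, not taken from a real config):

    import paddle

    # 4 heads, d_model = 256, so d_k = 64 and the cache's last dim is d_k * 2 = 128
    attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

    cache = paddle.zeros([0, 0, 0, 0])                     # 1st chunk: empty cache
    fake_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)  # attend over everything
    for _ in range(2):
        chunk = paddle.randn([1, 16, 256])                 # (#batch, time1, size)
        out, cache = attn(chunk, chunk, chunk, mask=fake_mask, cache=cache)
        print(out.shape, cache.shape)
    # out:   [1, 16, 256] each step
    # cache: [1, 4, 16, 128], then [1, 4, 32, 128]; encoder.forward_chunk is
    # responsible for slicing it back down to cache_t frames
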