refactor attention cache

pull/2124/head
Hui Zhang 2 years ago
parent e153495519
commit fb40602d94

@ -605,29 +605,42 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
subsampling_cache: Optional[paddle.Tensor]=None,
elayers_output_cache: Optional[List[paddle.Tensor]]=None,
conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
paddle.Tensor]]:
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
Args:
xs (paddle.Tensor): chunk input
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
computation
>=0: actual cache size
<0: means all history cache is required
att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
`d_k * 2` for att key & value.
cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
paddle.Tensor: output, it ranges from time 0 to current chunk.
paddle.Tensor: subsampling cache
List[paddle.Tensor]: attention cache
List[paddle.Tensor]: conformer cnn cache
paddle.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
paddle.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, T(?), d_k * 2)
depending on required_cache_size.
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
xs, offset, required_cache_size, att_cache, cnn_cache)
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:

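The refactor above replaces the three per-chunk cache lists (subsampling, encoder-layer output, conformer cnn) with two stacked tensors, `att_cache` and `cnn_cache`. A minimal streaming sketch of how a caller could drive the new interface; `encoder` (a loaded streaming U2 encoder) and `chunk_iterator` (yielding (1, time, mel-dim) feature chunks) are assumptions, not part of this commit:

import paddle

# encoder: a loaded streaming U2 encoder (assumption, construction omitted);
# chunk_iterator: assumed helper yielding (1, time, mel-dim) feature chunks.
att_cache = paddle.zeros([0, 0, 0, 0])    # empty caches for the first chunk
cnn_cache = paddle.zeros([0, 0, 0, 0])
offset = 0
required_cache_size = 16 * 4              # decoding_chunk_size * num_decoding_left_chunks

for xs_chunk in chunk_iterator:
    ys, att_cache, cnn_cache = encoder.forward_chunk(
        xs_chunk, offset, required_cache_size, att_cache, cnn_cache)
    offset += ys.shape[1]                 # advance by the emitted encoder frames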
@ -84,9 +84,10 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v
def forward_attention(self,
value: paddle.Tensor,
scores: paddle.Tensor,
mask: Optional[paddle.Tensor]) -> paddle.Tensor:
value: paddle.Tensor,
scores: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
) -> paddle.Tensor:
"""Compute attention context vector.
Args:
value (paddle.Tensor): Transformed value, size
@ -94,14 +95,23 @@ class MultiHeadedAttention(nn.Layer):
scores (paddle.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2).
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
paddle.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.shape[0]
if mask is not None:
# When `if mask.size(2) > 0` is True:
# 1. training.
# 2. onnx(16/4, chunk_size/history_size), feed real cache and real mask for the 1st chunk.
# When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
if paddle.shape(mask)[2] > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# for last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :paddle.shape(scores)[-1]]
scores = scores.masked_fill(mask, -float('inf'))
attn = paddle.softmax(
scores, axis=-1).masked_fill(mask,
@ -121,21 +131,67 @@ class MultiHeadedAttention(nn.Layer):
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: Optional[paddle.Tensor]) -> paddle.Tensor:
mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
pos_emb: paddle.Tensor = paddle.empty([0]),
cache: paddle.Tensor = paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If the different position in decoder see different block
of the encoder, such as Mocha, the passed in mask could be
in (#batch, L, T) shape. But there is no such case in current
Wenet.
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors (see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if paddle.shape(cache)[0] > 0:
# last dim `d_k * 2` for (key, val)
# paddle.split takes the number of sections (not the chunk size as in torch)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
scores = paddle.matmul(q,
k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
return self.forward_attention(v, scores, mask), new_cache
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
@ -192,23 +248,55 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
pos_emb: paddle.Tensor,
mask: Optional[paddle.Tensor]):
mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
pos_emb: paddle.Tensor = paddle.empty([0]),
cache: paddle.Tensor = paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time1, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors (see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if paddle.shape(cache)[0] > 0:
# last dim `d_k * 2` holds (key, value); split it into two equal sections
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
@ -234,4 +322,4 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
return self.forward_attention(v, scores, mask), new_cache

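The comments above assert that concatenating and splitting zero-shaped tensors is safe, quoting a torch doctest. A quick check of the same property with Paddle ops, assuming zero-sized dimensions are supported by the installed Paddle version:

import paddle

a = paddle.ones([1, 2, 0, 4])        # fake/empty cache (cache_t == 0)
b = paddle.ones([1, 2, 3, 4])        # real key or value tensor
c = paddle.concat([a, b], axis=2)    # concatenating an empty cache is a no-op
print(bool((b == c).all()))          # expected: True

kv = paddle.ones([1, 2, 3, 8])       # last dim == d_k * 2, key and value packed
k, v = paddle.split(kv, 2, axis=-1)  # two equal sections along the last dim
print(k.shape, v.shape)              # expected: [1, 2, 3, 4] [1, 2, 3, 4]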
@ -108,15 +108,17 @@ class ConvolutionModule(nn.Layer):
def forward(self,
x: paddle.Tensor,
mask_pad: Optional[paddle.Tensor]=None,
cache: Optional[paddle.Tensor]=None
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
cache: paddle.Tensor= paddle.zeros([0,0,0]),
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
mask_pad (paddle.Tensor): used for batch padding, (#batch, channels, time).
mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (paddle.Tensor): left context cache, it is only
used in causal convolution. (#batch, channels, time')
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) means fake cache.
Returns:
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time')
@ -125,11 +127,11 @@ class ConvolutionModule(nn.Layer):
x = x.transpose([0, 2, 1]) # [B, C, T]
# mask batch padding
if mask_pad is not None:
if paddle.shape(mask_pad)[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0)
if self.lorder > 0:
if cache is None:
if paddle.shape(cache)[2] == 0: # cache_t == 0
x = nn.functional.pad(
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else:
@ -143,7 +145,7 @@ class ConvolutionModule(nn.Layer):
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = paddle.zeros([1], dtype=x.dtype)
new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
@ -159,7 +161,7 @@ class ConvolutionModule(nn.Layer):
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad is not None:
if paddle.shape(mask_pad)[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0)
x = x.transpose([0, 2, 1]) # [B, T, C]

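The ConvolutionModule changes above swap `Optional`/`None` checks for shape checks against a fake `(0, 0, 0)` cache. A toy sketch of the left-context handling, with an assumed `lorder = 4` (kernel_size - 1) and random input; the real module then runs its pointwise/depthwise convolutions on `x`:

import paddle
import paddle.nn as nn

lorder = 4                                 # kernel_size - 1, assumed for illustration
x = paddle.randn([1, 8, 16])               # (B, channels, time), already transposed
cache = paddle.zeros([0, 0, 0])            # fake cache for the first chunk

if paddle.shape(cache)[2] == 0:            # cache_t == 0: first chunk, zero-pad left
    x = nn.functional.pad(x, [lorder, 0], 'constant', 0.0, data_format='NCL')
else:                                      # later chunks: prepend the real left context
    x = paddle.concat([cache, x], axis=2)

new_cache = x[:, :, -lorder:]              # left context to carry into the next chunk
print(x.shape, new_cache.shape)            # [1, 8, 20] [1, 8, 4]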
@ -121,11 +121,11 @@ class DecoderLayer(nn.Layer):
if self.concat_after:
tgt_concat = paddle.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
if not self.normalize_before:
x = self.norm1(x)
@ -134,11 +134,11 @@ class DecoderLayer(nn.Layer):
x = self.norm2(x)
if self.concat_after:
x_concat = paddle.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
(x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask))
self.src_attn(x, memory, memory, memory_mask)[0])
if not self.normalize_before:
x = self.norm2(x)

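The decoder-layer edits above only add `[0]` because the attention modules now return an `(output, new_cache)` tuple instead of a single tensor. A shape-level sketch, assuming `attn` is an already initialized MultiHeadedAttention over 256-dim features (construction omitted):

import paddle

# attn: initialized MultiHeadedAttention instance (assumption, construction omitted)
q = k = v = paddle.randn([1, 10, 256])
mask = paddle.ones([1, 1, 10], dtype=paddle.bool)

out = attn(q, k, v, mask)[0]            # [0] is the context vector, (1, 10, 256)
out, new_cache = attn(q, k, v, mask)    # or unpack both; new_cache packs key & value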
@ -131,7 +131,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
offset (int): start offset
size (int): required size of position encoding
Returns:
paddle.Tensor: Corresponding position encoding
paddle.Tensor: Corresponding position encoding, with shape [1, T, D].
"""
assert offset + size < self.max_len
return self.dropout(self.pe[:, offset:offset + size])

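The `[1, T, D]` shape matters because `forward_chunk` (in the encoder hunk below) requests position encodings at `offset - cache_t1` with size `cache_t1 + chunk_size`, so one slice covers both the cached keys and the new chunk. A small illustration, assuming `pos_enc` is an initialized PositionalEncoding with d_model = 256 and a sufficiently large max_len (construction omitted):

import paddle

offset, cache_t1, chunk_size = 80, 64, 16
pos_emb = pos_enc.position_encoding(offset=offset - cache_t1,
                                     size=cache_t1 + chunk_size)
print(pos_emb.shape)                     # expected: [1, 80, 256]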
@ -177,7 +177,7 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
@ -190,30 +190,31 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
subsampling_cache: Optional[paddle.Tensor]=None,
elayers_output_cache: Optional[List[paddle.Tensor]]=None,
conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
paddle.Tensor]]:
att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:
xs (paddle.Tensor): chunk input, [B=1, T, D]
xs (paddle.Tensor): chunk audio feat input, [B=1, T, D], where
`T==(chunk_size-1)*subsampling_rate + subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
computation
>=0: actual cache size
<0: means all history cache is required
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
att_cache (paddle.Tensor): cache tensor for key & val in
transformer/conformer attention. Shape is
(elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim`
and `cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
(elayers, B=1, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1`
Returns:
paddle.Tensor: output of current input xs
paddle.Tensor: subsampling cache required for next chunk computation
List[paddle.Tensor]: encoder layers output cache required for next
chunk computation
List[paddle.Tensor]: conformer cnn cache
paddle.Tensor: output of current input xs, (B=1, chunk_size, hidden-dim)
paddle.Tensor: new attention cache required for next chunk, dynamic shape
(elayers, head, T, d_k*2) depending on required_cache_size
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache
"""
assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility
@ -225,50 +226,49 @@ class BaseEncoder(nn.Layer):
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, _ = self.embed(
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
# before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim)
if subsampling_cache is not None:
cache_size = subsampling_cache.shape[1] #T
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2]
chunk_size = paddle.shape(xs)[1]
attention_key_size = cache_t1 + chunk_size
# only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.shape[1])
offset=offset - cache_t1, size=attention_key_size)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = xs.shape[1]
next_cache_start = attention_key_size
else:
next_cache_start = xs.shape[1] - required_cache_size
r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T]
r_elayers_output_cache = []
r_conformer_cnn_cache = []
next_cache_start = max(attention_key_size - required_cache_size, 0)
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoders):
attn_cache = None if elayers_output_cache is None else elayers_output_cache[
i]
cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[
i]
xs, _, new_cnn_cache = layer(
xs,
masks,
pos_emb,
output_cache=attn_cache,
cnn_cache=cnn_cache)
r_elayers_output_cache.append(xs[:, next_cache_start:, :])
r_conformer_cnn_cache.append(new_cnn_cache)
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i] = (B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
)
# new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:,:, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
if self.normalize_before:
xs = self.after_norm(xs)
return (xs[:, cache_size:, :], r_subsampling_cache,
r_elayers_output_cache, r_conformer_cnn_cache)
# r_att_cache (elayers, head, T, d_k*2)
# r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
r_att_cache = paddle.concat(r_att_cache, axis=0)
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk(
self,
@ -313,25 +313,24 @@ class BaseEncoder(nn.Layer):
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
subsampling_cache: Optional[paddle.Tensor] = None
elayers_output_cache: Optional[List[paddle.Tensor]] = None
conformer_cnn_cache: Optional[List[paddle.Tensor]] = None
att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
outputs = []
offset = 0
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, subsampling_cache, elayers_output_cache,
conformer_cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
(y, att_cache, cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
outputs.append(y)
offset += y.shape[1]
ys = paddle.cat(outputs, 1)
# fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
masks = paddle.ones([1, 1, ys.shape[1]], dtype=paddle.bool)
return ys, masks

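The cache bookkeeping in the new `forward_chunk` can be exercised in isolation with toy numbers (2 layers, 4 heads, d_k = 8, chunk_size = 16, required_cache_size = 64); the random tensor below stands in for the per-layer cache an encoder layer would actually return:

import paddle

elayers, head, d_k = 2, 4, 8
cache_t1, chunk_size, required_cache_size = 64, 16, 64

attention_key_size = cache_t1 + chunk_size
next_cache_start = max(attention_key_size - required_cache_size, 0)

r_att_cache = []
for i in range(elayers):
    # stand-in for the (1, head, attention_key_size, d_k * 2) cache a layer returns
    new_att_cache = paddle.randn([1, head, attention_key_size, d_k * 2])
    r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])

r_att_cache = paddle.concat(r_att_cache, axis=0)
print(r_att_cache.shape)                 # [2, 4, 64, 16]: stacked cache for the next chunk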
@ -75,49 +75,45 @@ class TransformerEncoderLayer(nn.Layer):
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: Optional[paddle.Tensor]=None,
mask_pad: Optional[paddle.Tensor]=None,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): Input tensor (#batch, time, size).
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
x (paddle.Tensor): (#batch, time, size)
mask (paddle.Tensor): Mask tensor for the input (#batch, time, time),
(0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
mask_pad (paddle.Tensor): not used here, it's for interface
compatibility to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface
compatibility to ConformerEncoderLayer
mask_pad (paddle.Tensor): not used in transformer layer,
just for unified api with conformer.
att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2), not used here, it's for interface
compatibility to ConformerEncoderLayer.
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
paddle.Tensor: Fake cnn cache tensor for api compatibility with Conformer (#batch, channels, time').
paddle.Tensor: Mask tensor (#batch, time, time).
paddle.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
paddle.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if output_cache is None:
x_q = x
else:
assert output_cache.shape[0] == x.shape[0]
assert output_cache.shape[1] < x.shape[1]
assert output_cache.shape[2] == self.size
chunk = x.shape[1] - output_cache.shape[1]
x_q = x[:, -chunk:, :]
residual = residual[:, -chunk:, :]
mask = mask[:, -chunk:, :]
x_att, new_att_cache = self.self_attn(
x, x, x, mask, cache=att_cache
)
if self.concat_after:
x_concat = paddle.concat(
(x, self.self_attn(x_q, x, x, mask)), axis=-1)
x_concat = paddle.concat((x, x_att), axis=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
@ -128,11 +124,8 @@ class TransformerEncoderLayer(nn.Layer):
if not self.normalize_before:
x = self.norm2(x)
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
return x, mask, fake_cnn_cache
fake_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
return x, mask, new_att_cache, fake_cnn_cache
class ConformerEncoderLayer(nn.Layer):
@ -192,32 +185,41 @@ class ConformerEncoderLayer(nn.Layer):
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.concat_linear = Linear(size + size, size)
if self.concat_after:
self.concat_linear = Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
def forward(
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: Optional[paddle.Tensor]=None,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): (#batch, time, size)
mask (paddle.Tensor): Mask tensor for the input (#batch, timetime).
pos_emb (paddle.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (paddle.Tensor): batch padding mask used for conv module, (B, 1, T).
output_cache (paddle.Tensor): Cache tensor of the encoder output
(#batch, time2, size), time2 < time in x.
x (paddle.Tensor): Input tensor (#batch, time, size).
mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
(0,0,0) means fake mask.
pos_emb (paddle.Tensor): positional encoding, must not be None
for ConformerEncoderLayer
mask_pad (paddle.Tensor): batch padding mask used for conv module,
(#batch, 1, time), (0, 0, 0) means fake mask.
att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
paddle.Tensor: New cnn cache tensor (#batch, channels, time').
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time, time).
paddle.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
paddle.Tensor: cnn_cache tensor (#batch, size, cache_t2).
"""
# whether to use macaron style FFN
if self.feed_forward_macaron is not None:
residual = x
@ -233,18 +235,8 @@ class ConformerEncoderLayer(nn.Layer):
if self.normalize_before:
x = self.norm_mha(x)
if output_cache is None:
x_q = x
else:
assert output_cache.shape[0] == x.shape[0]
assert output_cache.shape[1] < x.shape[1]
assert output_cache.shape[2] == self.size
chunk = x.shape[1] - output_cache.shape[1]
x_q = x[:, -chunk:, :]
residual = residual[:, -chunk:, :]
mask = mask[:, -chunk:, :]
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, pos_emb, cache=att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
@ -257,7 +249,7 @@ class ConformerEncoderLayer(nn.Layer):
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = paddle.zeros([1], dtype=x.dtype)
new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
if self.conv_module is not None:
residual = x
if self.normalize_before:
@ -282,7 +274,4 @@ class ConformerEncoderLayer(nn.Layer):
if self.conv_module is not None:
x = self.norm_final(x)
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
return x, mask, new_cnn_cache
return x, mask, new_att_cache, new_cnn_cache

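Both encoder layer classes above now share one call contract: `(x, mask, pos_emb, mask_pad, att_cache, cnn_cache)` in, `(x, mask, new_att_cache, new_cnn_cache)` out, with the old `output_cache` slicing removed. A shape-level sketch, assuming `layer` is an initialized TransformerEncoderLayer with size = 256 and 4 attention heads (construction omitted):

import paddle

xs = paddle.randn([1, 16, 256])                          # current chunk
att_mask = paddle.ones([1, 16, 80], dtype=paddle.bool)   # (B, chunk, cache_t1 + chunk)
pos_emb = paddle.empty([0])                              # ignored by the transformer layer
att_cache = paddle.zeros([1, 4, 64, 128])                # (1, head, cache_t1, d_k * 2)
cnn_cache = paddle.zeros([0, 0, 0, 0])                   # fake cache, unused here

xs, mask, new_att_cache, new_cnn_cache = layer(
    xs, att_mask, pos_emb, att_cache=att_cache, cnn_cache=cnn_cache)
print(new_att_cache.shape)                               # expected: [1, 4, 80, 128]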
@ -71,7 +71,8 @@ base = [
"colorlog",
"pathos == 0.2.8",
"braceexpand",
"pyyaml"
"pyyaml",
"pybind11",
]
server = [
@ -90,7 +91,6 @@ requirements = {
"gpustat",
"paddlespeech_ctcdecoders",
"phkit",
"pybind11",
"pypi-kenlm",
"snakeviz",
"sox",
