refactor attention cache

pull/2124/head
Hui Zhang 2 years ago
parent e153495519
commit fb40602d94

@@ -605,29 +605,42 @@ class U2BaseModel(ASRInterface, nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            subsampling_cache: Optional[paddle.Tensor]=None,
-            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
-            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
-            paddle.Tensor]]:
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
             output from time 0 to current chunk.
         Args:
-            xs (paddle.Tensor): chunk input
-            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
-            elayers_output_cache (Optional[List[paddle.Tensor]]):
-                transformer/conformer encoder layers output cache
-            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
-                cnn cache
+            xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
+                where `time == (chunk_size - 1) * subsample_rate + \
+                        subsample.right_context + 1`
+            offset (int): current offset in encoder output time stamp
+            required_cache_size (int): cache size required for next chunk
+                computation
+                >=0: actual cache size
+                <0: means all history cache is required
+            att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
+                transformer/conformer attention, with shape
+                (elayers, head, cache_t1, d_k * 2), where
+                `head * d_k == hidden-dim` and
+                `cache_t1 == chunk_size * num_decoding_left_chunks`.
+                `d_k * 2` for att key & value.
+            cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
+                (elayers, b=1, hidden-dim, cache_t2), where
+                `cache_t2 == cnn.lorder - 1`
         Returns:
-            paddle.Tensor: output, it ranges from time 0 to current chunk.
-            paddle.Tensor: subsampling cache
-            List[paddle.Tensor]: attention cache
-            List[paddle.Tensor]: conformer cnn cache
+            paddle.Tensor: output of current input xs,
+                with shape (b=1, chunk_size, hidden-dim).
+            paddle.Tensor: new attention cache required for next chunk, with
+                dynamic shape (elayers, head, T(?), d_k * 2)
+                depending on required_cache_size.
+            paddle.Tensor: new conformer cnn cache required for next chunk, with
+                same shape as the original cnn_cache.
         """
         return self.encoder.forward_chunk(
-            xs, offset, required_cache_size, subsampling_cache,
-            elayers_output_cache, conformer_cnn_cache)
+            xs, offset, required_cache_size, att_cache, cnn_cache)

     # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
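
Editor's note: the refactor collapses the three per-layer cache lists into two stacked tensors, so a streaming caller only threads `att_cache` and `cnn_cache` through successive calls. A minimal driving sketch against the new `encoder.forward_chunk` signature (assumptions: `model` is a trained U2 model, `feature_chunks` yields chunk tensors of shape (1, time, mel-dim), and 16/4 means chunk_size=16 with 4 left decoding chunks):

    import paddle

    offset = 0
    required_cache_size = 16 * 4            # chunk_size * num_decoding_left_chunks
    att_cache = paddle.zeros([0, 0, 0, 0])  # zero-shaped "fake" caches for the 1st chunk
    cnn_cache = paddle.zeros([0, 0, 0, 0])

    for chunk in feature_chunks:            # assumed iterable of (1, time, mel-dim) tensors
        y, att_cache, cnn_cache = model.encoder.forward_chunk(
            chunk, offset, required_cache_size, att_cache, cnn_cache)
        offset += y.shape[1]                # advance by this chunk's encoder output length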

@@ -84,9 +84,10 @@ class MultiHeadedAttention(nn.Layer):
         return q, k, v

     def forward_attention(self,
                           value: paddle.Tensor,
                           scores: paddle.Tensor,
-                          mask: Optional[paddle.Tensor]) -> paddle.Tensor:
+                          mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                          ) -> paddle.Tensor:
         """Compute attention context vector.
         Args:
             value (paddle.Tensor): Transformed value, size
@@ -94,14 +95,23 @@ class MultiHeadedAttention(nn.Layer):
             scores (paddle.Tensor): Attention score, size
                 (#batch, n_head, time1, time2).
             mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
-                (#batch, time1, time2).
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
         Returns:
-            paddle.Tensor: Transformed value weighted
-                by the attention score, (#batch, time1, d_model).
+            paddle.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
         """
         n_batch = value.shape[0]
-        if mask is not None:
+        # When will `if mask.size(2) > 0` be True?
+        #   1. training.
+        #   2. onnx(16/4, chunk_size/history_size), feed real cache and real mask for the 1st chunk.
+        # When will `if mask.size(2) > 0` be False?
+        #   1. onnx(16/-1, -1/-1, 16/0)
+        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
+        if paddle.shape(mask)[2] > 0:  # time2 > 0
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            # for last chunk, time2 might be larger than scores.size(-1)
+            mask = mask[:, :, :, :paddle.shape(scores)[-1]]
             scores = scores.masked_fill(mask, -float('inf'))
             attn = paddle.softmax(
                 scores, axis=-1).masked_fill(mask,
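
Editor's note: the `(0, 0, 0)` "fake mask" is the export-friendly replacement for `Optional[...]`/`None`: the branch is decided by a runtime shape check rather than a Python `is None` test, which jit/onnx export cannot trace. A small self-contained illustration of the same check (shapes are example values):

    import paddle

    fake_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)   # (0, 0, 0): no masking
    real_mask = paddle.ones([1, 1, 16], dtype=paddle.bool)  # (#batch, 1, time2)

    for mask in (fake_mask, real_mask):
        if paddle.shape(mask)[2] > 0:   # same test as in forward_attention above
            print("apply mask, time2 =", int(paddle.shape(mask)[2]))
        else:
            print("skip masking")
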
@@ -121,21 +131,67 @@ class MultiHeadedAttention(nn.Layer):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                mask: Optional[paddle.Tensor]) -> paddle.Tensor:
+                mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor = paddle.empty([0]),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute scaled dot product attention.
         Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
                 (#batch, time1, time2).
+                1. When applying cross attention between decoder and encoder,
+                   the batch padding mask for input is in (#batch, 1, T) shape.
+                2. When applying self attention of encoder,
+                   the mask is in (#batch, T, T) shape.
+                3. When applying self attention of decoder,
+                   the mask is in (#batch, L, L) shape.
+                4. If different positions in the decoder see different blocks
+                   of the encoder, such as Mocha, the passed-in mask could be
+                   in (#batch, L, T) shape. But there is no such case in current
+                   Wenet.
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)

+        # When exporting an onnx model, for the 1st chunk we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        # In all modes, `if cache.size(0) > 0` will always be `True`
+        # and we will always do splitting and
+        # concatenation (this simplifies onnx export). Note that
+        # it's OK to concat & split zero-shaped tensors (see code below).
+        # When exporting a jit model, for the 1st chunk we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(
+                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)

         scores = paddle.matmul(q,
                                k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache


 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
@@ -192,23 +248,55 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                pos_emb: paddle.Tensor,
-                mask: Optional[paddle.Tensor]):
+                mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor = paddle.empty([0]),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
             key (paddle.Tensor): Key tensor (#batch, time2, size).
             value (paddle.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (paddle.Tensor): Positional embedding tensor
-                (#batch, time1, size).
             mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         Returns:
             paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)

+        # When exporting an onnx model, for the 1st chunk we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        # In all modes, `if cache.size(0) > 0` will always be `True`
+        # and we will always do splitting and
+        # concatenation (this simplifies onnx export). Note that
+        # it's OK to concat & split zero-shaped tensors (see code below).
+        # When exporting a jit model, for the 1st chunk we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            key_cache, value_cache = paddle.split(
+                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)

         n_batch_pos = pos_emb.shape[0]
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)

@@ -234,4 +322,4 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         scores = (matrix_ac + matrix_bd) / math.sqrt(
             self.d_k)  # (batch, head, time1, time2)

-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache
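
Editor's note: the doctest in the comments above is written against torch; the same zero-shaped concat/split trick is what lets a single fused key/value `cache` tensor cover both the "no history" and "real history" cases in Paddle. A quick sanity sketch (assumes a Paddle version with zero-size tensor support, which the zero-shaped default caches in this patch already rely on):

    import paddle

    d_k, head = 4, 2
    fake_cache = paddle.ones([1, head, 0, d_k * 2])   # cache_t == 0 for the 1st chunk
    kv = paddle.ones([1, head, 3, d_k * 2])           # current chunk's fused key/value

    merged = paddle.concat([fake_cache, kv], axis=2)   # concatenating a zero-length cache is a no-op
    assert merged.shape == [1, head, 3, d_k * 2]

    key_cache, value_cache = paddle.split(merged, 2, axis=-1)  # two equal halves: key and value
    assert key_cache.shape == value_cache.shape == [1, head, 3, d_k]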

@@ -108,15 +108,17 @@ class ConvolutionModule(nn.Layer):
     def forward(self,
                 x: paddle.Tensor,
-                mask_pad: Optional[paddle.Tensor]=None,
-                cache: Optional[paddle.Tensor]=None
+                mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0]),
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute convolution module.
         Args:
             x (paddle.Tensor): Input tensor (#batch, time, channels).
-            mask_pad (paddle.Tensor): used for batch padding, (#batch, channels, time).
+            mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask.
             cache (paddle.Tensor): left context cache, it is only
-                used in causal convolution. (#batch, channels, time')
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, channels).
             paddle.Tensor: Output cache tensor (#batch, channels, time')

@@ -125,11 +127,11 @@ class ConvolutionModule(nn.Layer):
         x = x.transpose([0, 2, 1])  # [B, C, T]

         # mask batch padding
-        if mask_pad is not None:
+        if paddle.shape(mask_pad)[2] > 0:  # time > 0
             x = x.masked_fill(mask_pad, 0.0)

         if self.lorder > 0:
-            if cache is None:
+            if paddle.shape(cache)[2] == 0:  # cache_t == 0
                 x = nn.functional.pad(
                     x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
             else:

@@ -143,7 +145,7 @@ class ConvolutionModule(nn.Layer):
             # It's better we just return None if no cache is required,
             # However, for JIT export, here we just fake one tensor instead of
             # None.
-            new_cache = paddle.zeros([1], dtype=x.dtype)
+            new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)

         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)

@@ -159,7 +161,7 @@ class ConvolutionModule(nn.Layer):
         x = self.pointwise_conv2(x)

         # mask batch padding
-        if mask_pad is not None:
+        if paddle.shape(mask_pad)[2] > 0:  # time > 0
             x = x.masked_fill(mask_pad, 0.0)

         x = x.transpose([0, 2, 1])  # [B, T, C]
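
Editor's note: for the conformer convolution the left context is carried the same way: on the first chunk the fake `(0, 0, 0)` cache triggers zero left-padding, and on later chunks the cached `lorder` frames are prepended. An illustrative, standalone sketch of that mechanism (shapes and `lorder` are example values, not the module's actual configuration):

    import paddle
    import paddle.nn.functional as F

    lorder = 6                       # kernel_size - 1 of the causal depthwise conv
    x = paddle.randn([1, 256, 16])   # current chunk as (B, C, T), i.e. after the transpose
    cache = paddle.zeros([0, 0, 0])  # fake cache: no history yet

    if paddle.shape(cache)[2] == 0:  # first chunk: left-pad with zeros
        x = F.pad(x, [lorder, 0], mode='constant', value=0.0, data_format='NCL')
    else:                            # later chunks: prepend the cached left context
        x = paddle.concat([cache, x], axis=2)
    new_cache = x[:, :, -lorder:]    # last `lorder` frames feed the next chunk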

@@ -121,11 +121,11 @@ class DecoderLayer(nn.Layer):
         if self.concat_after:
             tgt_concat = paddle.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
             x = residual + self.concat_linear1(tgt_concat)
         else:
             x = residual + self.dropout(
-                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
         if not self.normalize_before:
             x = self.norm1(x)

@@ -134,11 +134,11 @@ class DecoderLayer(nn.Layer):
             x = self.norm2(x)
         if self.concat_after:
             x_concat = paddle.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
+                (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
             x = residual + self.concat_linear2(x_concat)
         else:
             x = residual + self.dropout(
-                self.src_attn(x, memory, memory, memory_mask))
+                self.src_attn(x, memory, memory, memory_mask)[0])
         if not self.normalize_before:
             x = self.norm2(x)

@@ -131,7 +131,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             offset (int): start offset
             size (int): required size of position encoding
         Returns:
-            paddle.Tensor: Corresponding position encoding
+            paddle.Tensor: Corresponding position encoding, with shape (1, T, D).
         """
         assert offset + size < self.max_len
         return self.dropout(self.pe[:, offset:offset + size])

@@ -177,7 +177,7 @@ class BaseEncoder(nn.Layer):
                 decoding_chunk_size, self.static_chunk_size,
                 num_decoding_left_chunks)
         for layer in self.encoders:
-            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
@@ -190,30 +190,31 @@ class BaseEncoder(nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            subsampling_cache: Optional[paddle.Tensor]=None,
-            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
-            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
-            paddle.Tensor]]:
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            att_mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Forward just one chunk
         Args:
-            xs (paddle.Tensor): chunk input, [B=1, T, D]
+            xs (paddle.Tensor): chunk audio feat input, [B=1, T, D], where
+                `T == (chunk_size - 1) * subsampling_rate + subsample.right_context + 1`
             offset (int): current offset in encoder output time stamp
             required_cache_size (int): cache size required for next chunk
                 computation
                 >=0: actual cache size
                 <0: means all history cache is required
-            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
-            elayers_output_cache (Optional[List[paddle.Tensor]]):
-                transformer/conformer encoder layers output cache
-            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
-                cnn cache
+            att_cache (paddle.Tensor): cache tensor for key & val in
+                transformer/conformer attention. Shape is
+                (elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim`
+                and `cache_t1 == chunk_size * num_decoding_left_chunks`.
+            cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
+                (elayers, B=1, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1`
         Returns:
-            paddle.Tensor: output of current input xs
-            paddle.Tensor: subsampling cache required for next chunk computation
-            List[paddle.Tensor]: encoder layers output cache required for next
-                chunk computation
-            List[paddle.Tensor]: conformer cnn cache
+            paddle.Tensor: output of current input xs, (B=1, chunk_size, hidden-dim)
+            paddle.Tensor: new attention cache required for next chunk, dynamic shape
+                (elayers, head, T, d_k * 2) depending on required_cache_size
+            paddle.Tensor: new conformer cnn cache required for next chunk, with
+                same shape as the original cnn_cache
         """
         assert xs.shape[0] == 1  # batch size must be one
         # tmp_masks is just for interface compatibility
@@ -225,50 +226,49 @@ class BaseEncoder(nn.Layer):
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)

-        xs, pos_emb, _ = self.embed(
-            xs, tmp_masks, offset=offset)  #xs=(B, T, D), pos_emb=(B=1, T, D)
-
-        if subsampling_cache is not None:
-            cache_size = subsampling_cache.shape[1]  #T
-            xs = paddle.cat((subsampling_cache, xs), dim=1)
-        else:
-            cache_size = 0
+        # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+        # after embed, xs=(B=1, chunk_size, hidden-dim)
+
+        elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2]
+        chunk_size = paddle.shape(xs)[1]
+        attention_key_size = cache_t1 + chunk_size

         # only used when using `RelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
-            offset=offset - cache_size, size=xs.shape[1])
+            offset=offset - cache_t1, size=attention_key_size)

         if required_cache_size < 0:
             next_cache_start = 0
         elif required_cache_size == 0:
-            next_cache_start = xs.shape[1]
+            next_cache_start = attention_key_size
         else:
-            next_cache_start = xs.shape[1] - required_cache_size
-        r_subsampling_cache = xs[:, next_cache_start:, :]
-
-        # Real mask for transformer/conformer layers
-        masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)  #[B=1, L'=1, T]
-        r_elayers_output_cache = []
-        r_conformer_cnn_cache = []
+            next_cache_start = max(attention_key_size - required_cache_size, 0)
+
+        r_att_cache = []
+        r_cnn_cache = []
         for i, layer in enumerate(self.encoders):
-            attn_cache = None if elayers_output_cache is None else elayers_output_cache[
-                i]
-            cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[
-                i]
-            xs, _, new_cnn_cache = layer(
-                xs,
-                masks,
-                pos_emb,
-                output_cache=attn_cache,
-                cnn_cache=cnn_cache)
-            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
-            r_conformer_cnn_cache.append(new_cnn_cache)
+            # att_cache[i:i+1] = (1, head, cache_t1, d_k * 2)
+            # cnn_cache[i]     = (B=1, hidden-dim, cache_t2)
+            xs, _, new_att_cache, new_cnn_cache = layer(
+                xs, att_mask, pos_emb,
+                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+                cnn_cache=cnn_cache[i] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
+            )
+            # new_att_cache = (1, head, attention_key_size, d_k * 2)
+            # new_cnn_cache = (B=1, hidden-dim, cache_t2)
+            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))  # add elayer dim
+
         if self.normalize_before:
             xs = self.after_norm(xs)

-        return (xs[:, cache_size:, :], r_subsampling_cache,
-                r_elayers_output_cache, r_conformer_cnn_cache)
+        # r_att_cache: (elayers, head, T, d_k * 2)
+        # r_cnn_cache: (elayers, B=1, hidden-dim, cache_t2)
+        r_att_cache = paddle.concat(r_att_cache, axis=0)
+        r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
+        return xs, r_att_cache, r_cnn_cache

     def forward_chunk_by_chunk(
             self,
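
Editor's note: the cache bookkeeping above keeps the attention cache capped at `required_cache_size` time steps. A hypothetical walk-through of the arithmetic with example numbers (chunk_size=16, num_decoding_left_chunks=4):

    chunk_size, left_chunks = 16, 4
    required_cache_size = chunk_size * left_chunks   # 64

    cache_t1 = 0                                     # att_cache time dim entering each chunk
    for step in range(1, 7):
        attention_key_size = cache_t1 + chunk_size
        next_cache_start = max(attention_key_size - required_cache_size, 0)
        cache_t1 = attention_key_size - next_cache_start
        print(step, attention_key_size, next_cache_start, cache_t1)
    # the returned cache grows 16, 32, 48, 64 frames and then stays at 64
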
@@ -313,25 +313,24 @@ class BaseEncoder(nn.Layer):
         num_frames = xs.shape[1]
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        subsampling_cache: Optional[paddle.Tensor] = None
-        elayers_output_cache: Optional[List[paddle.Tensor]] = None
-        conformer_cnn_cache: Optional[List[paddle.Tensor]] = None
+
+        att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+        cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+
         outputs = []
         offset = 0
         # Feed forward overlap input step by step
         for cur in range(0, num_frames - context + 1, stride):
             end = min(cur + decoding_window, num_frames)
             chunk_xs = xs[:, cur:end, :]
-            (y, subsampling_cache, elayers_output_cache,
-             conformer_cnn_cache) = self.forward_chunk(
-                 chunk_xs, offset, required_cache_size, subsampling_cache,
-                 elayers_output_cache, conformer_cnn_cache)
+
+            (y, att_cache, cnn_cache) = self.forward_chunk(
+                chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
+
             outputs.append(y)
             offset += y.shape[1]
         ys = paddle.cat(outputs, 1)
-        # fake mask, just for jit script and compatibility with `forward` api
-        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)
+
+        masks = paddle.ones([1, 1, ys.shape[1]], dtype=paddle.bool)
         return ys, masks

@@ -75,49 +75,45 @@ class TransformerEncoderLayer(nn.Layer):
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
-            pos_emb: Optional[paddle.Tensor]=None,
-            mask_pad: Optional[paddle.Tensor]=None,
-            output_cache: Optional[paddle.Tensor]=None,
-            cnn_cache: Optional[paddle.Tensor]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+            pos_emb: paddle.Tensor,
+            mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
-            x (paddle.Tensor): Input tensor (#batch, time, size).
-            mask (paddle.Tensor): Mask tensor for the input (#batch, time).
+            x (paddle.Tensor): (#batch, time, size)
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time),
+                (0, 0, 0) means fake mask.
             pos_emb (paddle.Tensor): just for interface compatibility
                 to ConformerEncoderLayer
-            mask_pad (paddle.Tensor): not used here, it's for interface
-                compatibility to ConformerEncoderLayer
-            output_cache (paddle.Tensor): Cache tensor of the output
-                (#batch, time2, size), time2 < time in x.
-            cnn_cache (paddle.Tensor): not used here, it's for interface
-                compatibility to ConformerEncoderLayer
+            mask_pad (paddle.Tensor): not used in the transformer layer,
+                just for a unified api with conformer.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2), not used here, it's for interface
+                compatibility to ConformerEncoderLayer.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, size).
-            paddle.Tensor: Mask tensor (#batch, time).
-            paddle.Tensor: Fake cnn cache tensor for api compatibility with Conformer (#batch, channels, time').
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            paddle.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
         """
         residual = x
         if self.normalize_before:
             x = self.norm1(x)

-        if output_cache is None:
-            x_q = x
-        else:
-            assert output_cache.shape[0] == x.shape[0]
-            assert output_cache.shape[1] < x.shape[1]
-            assert output_cache.shape[2] == self.size
-            chunk = x.shape[1] - output_cache.shape[1]
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
+        x_att, new_att_cache = self.self_attn(
+            x, x, x, mask, cache=att_cache
+        )

         if self.concat_after:
-            x_concat = paddle.concat(
-                (x, self.self_attn(x_q, x, x, mask)), axis=-1)
+            x_concat = paddle.concat((x, x_att), axis=-1)
             x = residual + self.concat_linear(x_concat)
         else:
-            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
+            x = residual + self.dropout(x_att)

         if not self.normalize_before:
             x = self.norm1(x)

@@ -128,11 +124,8 @@ class TransformerEncoderLayer(nn.Layer):
         if not self.normalize_before:
             x = self.norm2(x)

-        if output_cache is not None:
-            x = paddle.concat([output_cache, x], axis=1)
-
-        fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
-        return x, mask, fake_cnn_cache
+        fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
+        return x, mask, new_att_cache, fake_cnn_cache


 class ConformerEncoderLayer(nn.Layer):
@@ -192,32 +185,41 @@ class ConformerEncoderLayer(nn.Layer):
         self.size = size
         self.normalize_before = normalize_before
         self.concat_after = concat_after
-        self.concat_linear = Linear(size + size, size)
+        if self.concat_after:
+            self.concat_linear = Linear(size + size, size)
+        else:
+            self.concat_linear = nn.Identity()

     def forward(
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
             pos_emb: paddle.Tensor,
-            mask_pad: Optional[paddle.Tensor]=None,
-            output_cache: Optional[paddle.Tensor]=None,
-            cnn_cache: Optional[paddle.Tensor]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+            mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
-            x (paddle.Tensor): (#batch, time, size)
-            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
-            pos_emb (paddle.Tensor): positional encoding, must not be None
-                for ConformerEncoderLayer.
-            mask_pad (paddle.Tensor): batch padding mask used for conv module, (B, 1, T).
-            output_cache (paddle.Tensor): Cache tensor of the encoder output
-                (#batch, time2, size), time2 < time in x.
+            x (paddle.Tensor): Input tensor (#batch, time, size).
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time),
+                (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): positional encoding, must not be None
+                for ConformerEncoderLayer
+            mask_pad (paddle.Tensor): batch padding mask used for conv module,
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
             cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2)
         Returns:
             paddle.Tensor: Output tensor (#batch, time, size).
-            paddle.Tensor: Mask tensor (#batch, time).
-            paddle.Tensor: New cnn cache tensor (#batch, channels, time').
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            paddle.Tensor: cnn_cache tensor (#batch, size, cache_t2).
         """
         # whether to use macaron style FFN
         if self.feed_forward_macaron is not None:
             residual = x
@@ -233,18 +235,8 @@ class ConformerEncoderLayer(nn.Layer):
         if self.normalize_before:
             x = self.norm_mha(x)

-        if output_cache is None:
-            x_q = x
-        else:
-            assert output_cache.shape[0] == x.shape[0]
-            assert output_cache.shape[1] < x.shape[1]
-            assert output_cache.shape[2] == self.size
-            chunk = x.shape[1] - output_cache.shape[1]
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
-
-        x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        x_att, new_att_cache = self.self_attn(
+            x, x, x, mask, pos_emb, cache=att_cache)

         if self.concat_after:
             x_concat = paddle.concat((x, x_att), axis=-1)
@@ -257,7 +249,7 @@ class ConformerEncoderLayer(nn.Layer):
         # convolution module
         # Fake new cnn cache here, and then change it in conv_module
-        new_cnn_cache = paddle.zeros([1], dtype=x.dtype)
+        new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
         if self.conv_module is not None:
             residual = x
             if self.normalize_before:

@@ -282,7 +274,4 @@ class ConformerEncoderLayer(nn.Layer):
         if self.conv_module is not None:
             x = self.norm_final(x)

-        if output_cache is not None:
-            x = paddle.concat([output_cache, x], axis=1)
-
-        return x, mask, new_cnn_cache
+        return x, mask, new_att_cache, new_cnn_cache

@@ -71,7 +71,8 @@ base = [
     "colorlog",
     "pathos == 0.2.8",
     "braceexpand",
-    "pyyaml"
+    "pyyaml",
+    "pybind11",
 ]

 server = [

@@ -90,7 +91,6 @@ requirements = {
     "gpustat",
     "paddlespeech_ctcdecoders",
     "phkit",
-    "pybind11",
     "pypi-kenlm",
     "snakeviz",
     "sox",
