att cache for streaming asr

pull/2124/head
Hui Zhang 2 years ago
parent 5ca05fea20
commit e81849277e

@ -38,4 +38,4 @@ if __name__ == '__main__':
T += m['T']
P += m['P']
print(f"RTF: {P/T}")
print(f"RTF: {P/T}, utts: {n}")

@ -401,29 +401,42 @@ class U2STBaseModel(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
subsampling_cache: Optional[paddle.Tensor]=None,
elayers_output_cache: Optional[List[paddle.Tensor]]=None,
conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
paddle.Tensor]]:
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
Args:
xs (paddle.Tensor): chunk input
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
`d_k * 2` for att key & value.
cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
paddle.Tensor: output, it ranges from time 0 to current chunk.
paddle.Tensor: subsampling cache
List[paddle.Tensor]: attention cache
List[paddle.Tensor]: conformer cnn cache
paddle.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
paddle.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, T(?), d_k * 2)
depending on required_cache_size.
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
xs, offset, required_cache_size, att_cache, cnn_cache)
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:

@ -181,8 +181,7 @@ class MultiHeadedAttention(nn.Layer):
# >>> torch.equal(d[0], d[1]) # True
if paddle.shape(cache)[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(
cache, paddle.shape(cache)[-1] // 2, axis=-1)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
@ -289,8 +288,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if paddle.shape(cache)[0] > 0:
key_cache, value_cache = paddle.split(
cache, paddle.shape(cache)[-1] // 2, axis=-1)
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's

@ -230,7 +230,8 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim)
elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2]
elayers = paddle.shape(att_cache)[0]
cache_t1 = paddle.shape(att_cache)[2]
chunk_size = paddle.shape(xs)[1]
attention_key_size = cache_t1 + chunk_size

@ -130,9 +130,9 @@ class PaddleASRConnectionHanddler:
## conformer
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.att_cache = paddle.zeros([0,0,0,0])
self.cnn_cache = paddle.zeros([0,0,0,0])
self.encoder_out = None
# conformer decoding state
self.offset = 0 # global offset in decoding frame unit
@ -474,11 +474,9 @@ class PaddleASRConnectionHanddler:
# cur chunk
chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
(y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
self.att_cache, self.cnn_cache)
outputs.append(y)
# update the global offset, in decoding frame unit

Loading…
Cancel
Save