From fb40602d94dd25dda92ed18df0129c42e195385f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 8 Jul 2022 08:08:45 +0000 Subject: [PATCH 1/3] refactor attention cache --- paddlespeech/s2t/models/u2/u2.py | 47 ++++--- paddlespeech/s2t/modules/attention.py | 130 +++++++++++++++--- .../s2t/modules/conformer_convolution.py | 18 +-- paddlespeech/s2t/modules/decoder_layer.py | 8 +- paddlespeech/s2t/modules/embedding.py | 2 +- paddlespeech/s2t/modules/encoder.py | 115 ++++++++-------- paddlespeech/s2t/modules/encoder_layer.py | 119 ++++++++-------- setup.py | 4 +- 8 files changed, 267 insertions(+), 176 deletions(-) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 100aca18b..3af353600 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -605,29 +605,42 @@ class U2BaseModel(ASRInterface, nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - subsampling_cache: Optional[paddle.Tensor]=None, - elayers_output_cache: Optional[List[paddle.Tensor]]=None, - conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, - ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ - paddle.Tensor]]: + att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. + Args: - xs (paddle.Tensor): chunk input - subsampling_cache (Optional[paddle.Tensor]): subsampling cache - elayers_output_cache (Optional[List[paddle.Tensor]]): - transformer/conformer encoder layers output cache - conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer - cnn cache + xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (paddle.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + `d_k * 2` for att key & value. + cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + Returns: - paddle.Tensor: output, it ranges from time 0 to current chunk. - paddle.Tensor: subsampling cache - List[paddle.Tensor]: attention cache - List[paddle.Tensor]: conformer cnn cache + paddle.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + paddle.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, T(?), d_k * 2) + depending on required_cache_size. + paddle.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
""" return self.encoder.forward_chunk( - xs, offset, required_cache_size, subsampling_cache, - elayers_output_cache, conformer_cnn_cache) + xs, offset, required_cache_size, att_cache, cnn_cache) # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 438efd2a1..c0b76f08e 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -84,9 +84,10 @@ class MultiHeadedAttention(nn.Layer): return q, k, v def forward_attention(self, - value: paddle.Tensor, - scores: paddle.Tensor, - mask: Optional[paddle.Tensor]) -> paddle.Tensor: + value: paddle.Tensor, + scores: paddle.Tensor, + mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + ) -> paddle.Tensor: """Compute attention context vector. Args: value (paddle.Tensor): Transformed value, size @@ -94,14 +95,23 @@ class MultiHeadedAttention(nn.Layer): scores (paddle.Tensor): Attention score, size (#batch, n_head, time1, time2). mask (paddle.Tensor): Mask, size (#batch, 1, time2) or - (#batch, time1, time2). + (#batch, time1, time2), (0, 0, 0) means fake mask. Returns: - paddle.Tensor: Transformed value weighted - by the attention score, (#batch, time1, d_model). + paddle.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). """ n_batch = value.shape[0] - if mask is not None: + + # When `if mask.size(2) > 0` be True: + # 1. training. + # 2. oonx(16/4, chunk_size/history_size), feed real cache and real mask for the 1st chunk. + # When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + if paddle.shape(mask)[2] > 0: # time2 > 0 mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # for last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :paddle.shape(scores)[-1]] scores = scores.masked_fill(mask, -float('inf')) attn = paddle.softmax( scores, axis=-1).masked_fill(mask, @@ -121,21 +131,67 @@ class MultiHeadedAttention(nn.Layer): query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, - mask: Optional[paddle.Tensor]) -> paddle.Tensor: + mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), + pos_emb: paddle.Tensor = paddle.empty([0]), + cache: paddle.Tensor = paddle.zeros([0,0,0,0]) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute scaled dot product attention. - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + Wenet. 
+ cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). + paddle.Tensor: Output tensor (#batch, time1, d_model). + paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ q, k, v = self.forward_qkv(query, key, value) + + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if paddle.shape(cache)[0] > 0: + # last dim `d_k * 2` for (key, val) + key_cache, value_cache = paddle.split( + cache, paddle.shape(cache)[-1] // 2, axis=-1) + k = paddle.concat([key_cache, k], axis=2) + v = paddle.concat([value_cache, v], axis=2) + # We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = paddle.concat((k, v), axis=-1) + scores = paddle.matmul(q, k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask) + return self.forward_attention(v, scores, mask), new_cache class RelPositionMultiHeadedAttention(MultiHeadedAttention): @@ -192,23 +248,55 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, - pos_emb: paddle.Tensor, - mask: Optional[paddle.Tensor]): + mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), + pos_emb: paddle.Tensor = paddle.empty([0]), + cache: paddle.Tensor = paddle.zeros([0,0,0,0]) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (paddle.Tensor): Query tensor (#batch, time1, size). key (paddle.Tensor): Key tensor (#batch, time2, size). value (paddle.Tensor): Value tensor (#batch, time2, size). - pos_emb (paddle.Tensor): Positional embedding tensor - (#batch, time1, size). mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (paddle.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` Returns: paddle.Tensor: Output tensor (#batch, time1, d_model). + paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` """ q, k, v = self.forward_qkv(query, key, value) q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). 
+        # In all modes, `if cache.size(0) > 0` will always be `True`
+        # and we will always do splitting and
+        # concatenation (this simplifies onnx export). Note that
+        # it's OK to concat & split zero-shaped tensors (see code below).
+        # When exporting a jit model, for the 1st chunk we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branches.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            key_cache, value_cache = paddle.split(
+                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
         n_batch_pos = pos_emb.shape[0]
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
@@ -234,4 +322,4 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         scores = (matrix_ac + matrix_bd) / math.sqrt(
             self.d_k)  # (batch, head, time1, time2)
 
-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache
\ No newline at end of file
diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py
index 89e652688..c384b9c78 100644
--- a/paddlespeech/s2t/modules/conformer_convolution.py
+++ b/paddlespeech/s2t/modules/conformer_convolution.py
@@ -108,15 +108,17 @@ class ConvolutionModule(nn.Layer):
     def forward(self,
                 x: paddle.Tensor,
-                mask_pad: Optional[paddle.Tensor]=None,
-                cache: Optional[paddle.Tensor]=None
+                mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
+                cache: paddle.Tensor= paddle.zeros([0,0,0]),
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute convolution module.
         Args:
             x (paddle.Tensor): Input tensor (#batch, time, channels).
-            mask_pad (paddle.Tensor): used for batch padding, (#batch, channels, time).
+            mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask.
             cache (paddle.Tensor): left context cache, it is only
-                used in causal convolution. (#batch, channels, time')
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, channels).
             paddle.Tensor: Output cache tensor (#batch, channels, time')
@@ -125,11 +127,11 @@ class ConvolutionModule(nn.Layer):
         x = x.transpose([0, 2, 1])  # [B, C, T]
 
         # mask batch padding
-        if mask_pad is not None:
+        if paddle.shape(mask_pad)[2] > 0:  # time > 0
             x = x.masked_fill(mask_pad, 0.0)
 
         if self.lorder > 0:
-            if cache is None:
+            if paddle.shape(cache)[2] == 0:  # cache_t == 0
                 x = nn.functional.pad(
                     x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
             else:
@@ -143,7 +145,7 @@ class ConvolutionModule(nn.Layer):
             # It's better we just return None if no cache is required,
             # However, for JIT export, here we just fake one tensor instead of
             # None.
- new_cache = paddle.zeros([1], dtype=x.dtype) + new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channel, dim) @@ -159,7 +161,7 @@ class ConvolutionModule(nn.Layer): x = self.pointwise_conv2(x) # mask batch padding - if mask_pad is not None: + if paddle.shape(mask_pad)[2] > 0: # time > 0 x = x.masked_fill(mask_pad, 0.0) x = x.transpose([0, 2, 1]) # [B, T, C] diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index b7f8694c1..37b124e84 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -121,11 +121,11 @@ class DecoderLayer(nn.Layer): if self.concat_after: tgt_concat = paddle.cat( - (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1) + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1) x = residual + self.concat_linear1(tgt_concat) else: x = residual + self.dropout( - self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) if not self.normalize_before: x = self.norm1(x) @@ -134,11 +134,11 @@ class DecoderLayer(nn.Layer): x = self.norm2(x) if self.concat_after: x_concat = paddle.cat( - (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1) + (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1) x = residual + self.concat_linear2(x_concat) else: x = residual + self.dropout( - self.src_attn(x, memory, memory, memory_mask)) + self.src_attn(x, memory, memory, memory_mask)[0]) if not self.normalize_before: x = self.norm2(x) diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py index 51e558eb8..3aeebd29b 100644 --- a/paddlespeech/s2t/modules/embedding.py +++ b/paddlespeech/s2t/modules/embedding.py @@ -131,7 +131,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface): offset (int): start offset size (int): requried size of position encoding Returns: - paddle.Tensor: Corresponding position encoding + paddle.Tensor: Corresponding position encoding, #[1, T, D]. 
""" assert offset + size < self.max_len return self.dropout(self.pe[:, offset:offset + size]) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 4d31acf1a..e05d0cc45 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -177,7 +177,7 @@ class BaseEncoder(nn.Layer): decoding_chunk_size, self.static_chunk_size, num_decoding_left_chunks) for layer in self.encoders: - xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) if self.normalize_before: xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so just @@ -190,30 +190,31 @@ class BaseEncoder(nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - subsampling_cache: Optional[paddle.Tensor]=None, - elayers_output_cache: Optional[List[paddle.Tensor]]=None, - conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, - ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ - paddle.Tensor]]: + att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]), + cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]), + att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Forward just one chunk Args: - xs (paddle.Tensor): chunk input, [B=1, T, D] + xs (paddle.Tensor): chunk audio feat input, [B=1, T, D], where + `T==(chunk_size-1)*subsampling_rate + subsample.right_context + 1` offset (int): current offset in encoder output time stamp required_cache_size (int): cache size required for next chunk compuation >=0: actual cache size <0: means all history cache is required - subsampling_cache (Optional[paddle.Tensor]): subsampling cache - elayers_output_cache (Optional[List[paddle.Tensor]]): - transformer/conformer encoder layers output cache - conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer - cnn cache + att_cache(paddle.Tensor): cache tensor for key & val in + transformer/conformer attention. Shape is + (elayers, head, cache_t1, d_k * 2), where`head * d_k == hidden-dim` + and `cache_t1 == chunk_size * num_decoding_left_chunks`. 
+ cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, + (elayers, B=1, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1` Returns: - paddle.Tensor: output of current input xs - paddle.Tensor: subsampling cache required for next chunk computation - List[paddle.Tensor]: encoder layers output cache required for next - chunk computation - List[paddle.Tensor]: conformer cnn cache + paddle.Tensor: output of current input xs, (B=1, chunk_size, hidden-dim) + paddle.Tensor: new attention cache required for next chunk, dyanmic shape + (elayers, head, T, d_k*2) depending on required_cache_size + paddle.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache """ assert xs.shape[0] == 1 # batch size must be one # tmp_masks is just for interface compatibility @@ -225,50 +226,49 @@ class BaseEncoder(nn.Layer): if self.global_cmvn is not None: xs = self.global_cmvn(xs) - xs, pos_emb, _ = self.embed( - xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D) + # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) + # after embed, xs=(B=1, chunk_size, hidden-dim) - if subsampling_cache is not None: - cache_size = subsampling_cache.shape[1] #T - xs = paddle.cat((subsampling_cache, xs), dim=1) - else: - cache_size = 0 + elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2] + chunk_size = paddle.shape(xs)[1] + attention_key_size = cache_t1 + chunk_size # only used when using `RelPositionMultiHeadedAttention` pos_emb = self.embed.position_encoding( - offset=offset - cache_size, size=xs.shape[1]) + offset=offset - cache_t1, size=attention_key_size) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: - next_cache_start = xs.shape[1] + next_cache_start = attention_key_size else: - next_cache_start = xs.shape[1] - required_cache_size - r_subsampling_cache = xs[:, next_cache_start:, :] - - # Real mask for transformer/conformer layers - masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool) - masks = masks.unsqueeze(1) #[B=1, L'=1, T] - r_elayers_output_cache = [] - r_conformer_cnn_cache = [] + next_cache_start = max(attention_key_size - required_cache_size, 0) + + r_att_cache = [] + r_cnn_cache = [] for i, layer in enumerate(self.encoders): - attn_cache = None if elayers_output_cache is None else elayers_output_cache[ - i] - cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[ - i] - xs, _, new_cnn_cache = layer( - xs, - masks, - pos_emb, - output_cache=attn_cache, - cnn_cache=cnn_cache) - r_elayers_output_cache.append(xs[:, next_cache_start:, :]) - r_conformer_cnn_cache.append(new_cnn_cache) + # att_cache[i:i+1] = (1, head, cache_t1, d_k*2) + # cnn_cache[i] = (B=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, att_mask, pos_emb, + att_cache=att_cache[i:i+1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, + ) + # new_att_cache = (1, head, attention_key_size, d_k*2) + # new_cnn_cache = (B=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:,:, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim + if self.normalize_before: xs = self.after_norm(xs) - return (xs[:, cache_size:, :], r_subsampling_cache, - r_elayers_output_cache, r_conformer_cnn_cache) + # r_att_cache (elayers, head, T, d_k*2) + # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2) + r_att_cache = paddle.concat(r_att_cache, 
axis=0)
+        r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
+        return xs, r_att_cache, r_cnn_cache
+
     def forward_chunk_by_chunk(
             self,
@@ -313,25 +313,24 @@ class BaseEncoder(nn.Layer):
         num_frames = xs.shape[1]
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        subsampling_cache: Optional[paddle.Tensor] = None
-        elayers_output_cache: Optional[List[paddle.Tensor]] = None
-        conformer_cnn_cache: Optional[List[paddle.Tensor]] = None
+
+        att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+        cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+
         outputs = []
         offset = 0
         # Feed forward overlap input step by step
         for cur in range(0, num_frames - context + 1, stride):
             end = min(cur + decoding_window, num_frames)
             chunk_xs = xs[:, cur:end, :]
-            (y, subsampling_cache, elayers_output_cache,
-             conformer_cnn_cache) = self.forward_chunk(
-                 chunk_xs, offset, required_cache_size, subsampling_cache,
-                 elayers_output_cache, conformer_cnn_cache)
+
+            (y, att_cache, cnn_cache) = self.forward_chunk(
+                chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
+
             outputs.append(y)
             offset += y.shape[1]
         ys = paddle.cat(outputs, 1)
-        # fake mask, just for jit script and compatibility with `forward` api
-        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)
+        masks = paddle.ones([1, 1, ys.shape[1]], dtype=paddle.bool)
         return ys, masks
diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py
index e80a298d6..d91e3f6ef 100644
--- a/paddlespeech/s2t/modules/encoder_layer.py
+++ b/paddlespeech/s2t/modules/encoder_layer.py
@@ -75,49 +75,45 @@ class TransformerEncoderLayer(nn.Layer):
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
-            pos_emb: Optional[paddle.Tensor]=None,
-            mask_pad: Optional[paddle.Tensor]=None,
-            output_cache: Optional[paddle.Tensor]=None,
-            cnn_cache: Optional[paddle.Tensor]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+            pos_emb: paddle.Tensor,
+            mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
-            x (paddle.Tensor): Input tensor (#batch, time, size).
-            mask (paddle.Tensor): Mask tensor for the input (#batch, time).
+            x (paddle.Tensor): (#batch, time, size)
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time),
+                (0, 0, 0) means fake mask.
             pos_emb (paddle.Tensor): just for interface compatibility
                 to ConformerEncoderLayer
-            mask_pad (paddle.Tensor): not used here, it's for interface
-                compatibility to ConformerEncoderLayer
-            output_cache (paddle.Tensor): Cache tensor of the output
-                (#batch, time2, size), time2 < time in x.
-            cnn_cache (paddle.Tensor): not used here, it's for interface
-                compatibility to ConformerEncoderLayer
+            mask_pad (paddle.Tensor): not used in the transformer layer,
+                just for a unified api with conformer.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2), not used here, it's for interface
+                compatibility to ConformerEncoderLayer.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, size).
-            paddle.Tensor: Mask tensor (#batch, time).
-            paddle.Tensor: Fake cnn cache tensor for api compatibility with Conformer (#batch, channels, time').
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            paddle.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
         """
         residual = x
         if self.normalize_before:
             x = self.norm1(x)
 
-        if output_cache is None:
-            x_q = x
-        else:
-            assert output_cache.shape[0] == x.shape[0]
-            assert output_cache.shape[1] < x.shape[1]
-            assert output_cache.shape[2] == self.size
-            chunk = x.shape[1] - output_cache.shape[1]
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
+        x_att, new_att_cache = self.self_attn(
+            x, x, x, mask, cache=att_cache)
 
         if self.concat_after:
-            x_concat = paddle.concat(
-                (x, self.self_attn(x_q, x, x, mask)), axis=-1)
+            x_concat = paddle.concat((x, x_att), axis=-1)
             x = residual + self.concat_linear(x_concat)
         else:
-            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
+            x = residual + self.dropout(x_att)
         if not self.normalize_before:
             x = self.norm1(x)
 
@@ -128,11 +124,8 @@ class TransformerEncoderLayer(nn.Layer):
         if not self.normalize_before:
             x = self.norm2(x)
 
-        if output_cache is not None:
-            x = paddle.concat([output_cache, x], axis=1)
-
-        fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
-        return x, mask, fake_cnn_cache
+        fake_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
+        return x, mask, new_att_cache, fake_cnn_cache
 
 
 class ConformerEncoderLayer(nn.Layer):
@@ -192,32 +185,41 @@ class ConformerEncoderLayer(nn.Layer):
         self.size = size
         self.normalize_before = normalize_before
         self.concat_after = concat_after
-        self.concat_linear = Linear(size + size, size)
+        if self.concat_after:
+            self.concat_linear = Linear(size + size, size)
+        else:
+            self.concat_linear = nn.Identity()
 
     def forward(
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
             pos_emb: paddle.Tensor,
-            mask_pad: Optional[paddle.Tensor]=None,
-            output_cache: Optional[paddle.Tensor]=None,
-            cnn_cache: Optional[paddle.Tensor]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+            mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
-            x (paddle.Tensor): (#batch, time, size)
-            mask (paddle.Tensor): Mask tensor for the input (#batch, time,time).
-            pos_emb (paddle.Tensor): positional encoding, must not be None
-                for ConformerEncoderLayer.
-            mask_pad (paddle.Tensor): batch padding mask used for conv module, (B, 1, T).
-            output_cache (paddle.Tensor): Cache tensor of the encoder output
-                (#batch, time2, size), time2 < time in x.
+            x (paddle.Tensor): Input tensor (#batch, time, size).
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
+                (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): positional encoding, must not be None
+                for ConformerEncoderLayer
+            mask_pad (paddle.Tensor): batch padding mask used for conv module.
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
             cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2)
         Returns:
-            paddle.Tensor: Output tensor (#batch, time, size).
-            paddle.Tensor: Mask tensor (#batch, time).
-            paddle.Tensor: New cnn cache tensor (#batch, channels, time').
+            paddle.Tensor: Output tensor (#batch, time, size).
+            paddle.Tensor: Mask tensor (#batch, time, time).
+ paddle.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + paddle.Tensor: cnn_cahce tensor (#batch, size, cache_t2). """ + # whether to use macaron style FFN if self.feed_forward_macaron is not None: residual = x @@ -233,18 +235,8 @@ class ConformerEncoderLayer(nn.Layer): if self.normalize_before: x = self.norm_mha(x) - if output_cache is None: - x_q = x - else: - assert output_cache.shape[0] == x.shape[0] - assert output_cache.shape[1] < x.shape[1] - assert output_cache.shape[2] == self.size - chunk = x.shape[1] - output_cache.shape[1] - x_q = x[:, -chunk:, :] - residual = residual[:, -chunk:, :] - mask = mask[:, -chunk:, :] - - x_att = self.self_attn(x_q, x, x, pos_emb, mask) + x_att, new_att_cache = self.self_attn( + x, x, x, mask, pos_emb, cache=att_cache) if self.concat_after: x_concat = paddle.concat((x, x_att), axis=-1) @@ -257,7 +249,7 @@ class ConformerEncoderLayer(nn.Layer): # convolution module # Fake new cnn cache here, and then change it in conv_module - new_cnn_cache = paddle.zeros([1], dtype=x.dtype) + new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype) if self.conv_module is not None: residual = x if self.normalize_before: @@ -282,7 +274,4 @@ class ConformerEncoderLayer(nn.Layer): if self.conv_module is not None: x = self.norm_final(x) - if output_cache is not None: - x = paddle.concat([output_cache, x], axis=1) - - return x, mask, new_cnn_cache + return x, mask, new_att_cache, new_cnn_cache \ No newline at end of file diff --git a/setup.py b/setup.py index a3ef753a0..63e571b9e 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,8 @@ base = [ "colorlog", "pathos == 0.2.8", "braceexpand", - "pyyaml" + "pyyaml", + "pybind11", ] server = [ @@ -90,7 +91,6 @@ requirements = { "gpustat", "paddlespeech_ctcdecoders", "phkit", - "pybind11", "pypi-kenlm", "snakeviz", "sox", From 5ca05fea20600a9949e680be62d31d5b614e911d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 8 Jul 2022 08:28:50 +0000 Subject: [PATCH 2/3] cli batch process support \t --- paddlespeech/cli/asr/infer.py | 2 +- paddlespeech/cli/executor.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 76dfafb92..f9b4439ec 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -365,7 +365,7 @@ class ASRExecutor(BaseExecutor): except Exception as e: logger.exception(e) logger.error( - "can not open the audio file, please check the audio file format is 'wav'. \n \ + f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \ you can try to use sox to change the file format.\n \ For example: \n \ sample rate: 16k \n \ diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index d4187a514..3800c36db 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -108,19 +108,20 @@ class BaseExecutor(ABC): Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs. """ if self._is_job_input(input_): + # .job/.scp/.txt file ret = self._get_job_contents(input_) else: + # job from stdin ret = OrderedDict() - if input_ is None: # Take input from stdin if not sys.stdin.isatty( ): # Avoid getting stuck when stdin is empty. 
for i, line in enumerate(sys.stdin): line = line.strip() - if len(line.split(' ')) == 1: + if len(line.split()) == 1: ret[str(i + 1)] = line - elif len(line.split(' ')) == 2: - id_, info = line.split(' ') + elif len(line.split()) == 2: + id_, info = line.split() ret[id_] = info else: # No valid input info from one line. continue @@ -170,7 +171,8 @@ class BaseExecutor(ABC): bool: return `True` for job input, `False` otherwise. """ return input_ and os.path.isfile(input_) and (input_.endswith('.job') or - input_.endswith('.txt')) + input_.endswith('.txt') or + input_.endswith('.scp')) def _get_job_contents( self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: @@ -189,7 +191,7 @@ class BaseExecutor(ABC): line = line.strip() if not line: continue - k, v = line.split(' ') + k, v = line.split() # space or \t job_contents[k] = v return job_contents From e81849277ede018e575007820c7573c5db13c480 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 8 Jul 2022 09:36:26 +0000 Subject: [PATCH 3/3] att cache for streaming asr --- .../local/rtf_from_log.py | 2 +- paddlespeech/s2t/models/u2_st/u2_st.py | 47 ++++++++++++------- paddlespeech/s2t/modules/attention.py | 7 ++- paddlespeech/s2t/modules/encoder.py | 3 +- .../engine/asr/online/python/asr_engine.py | 12 ++--- 5 files changed, 41 insertions(+), 30 deletions(-) diff --git a/demos/streaming_asr_server/local/rtf_from_log.py b/demos/streaming_asr_server/local/rtf_from_log.py index 4f30d6400..4b89b48fd 100755 --- a/demos/streaming_asr_server/local/rtf_from_log.py +++ b/demos/streaming_asr_server/local/rtf_from_log.py @@ -38,4 +38,4 @@ if __name__ == '__main__': T += m['T'] P += m['P'] - print(f"RTF: {P/T}") + print(f"RTF: {P/T}, utts: {n}") diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index 00ded9125..e86bbedfa 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -401,29 +401,42 @@ class U2STBaseModel(nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - subsampling_cache: Optional[paddle.Tensor]=None, - elayers_output_cache: Optional[List[paddle.Tensor]]=None, - conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, - ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ - paddle.Tensor]]: + att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. + Args: - xs (paddle.Tensor): chunk input - subsampling_cache (Optional[paddle.Tensor]): subsampling cache - elayers_output_cache (Optional[List[paddle.Tensor]]): - transformer/conformer encoder layers output cache - conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer - cnn cache + xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (paddle.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + `d_k * 2` for att key & value. 
+ cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + Returns: - paddle.Tensor: output, it ranges from time 0 to current chunk. - paddle.Tensor: subsampling cache - List[paddle.Tensor]: attention cache - List[paddle.Tensor]: conformer cnn cache + paddle.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + paddle.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, T(?), d_k * 2) + depending on required_cache_size. + paddle.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. """ return self.encoder.forward_chunk( - xs, offset, required_cache_size, subsampling_cache, - elayers_output_cache, conformer_cnn_cache) + xs, offset, required_cache_size, att_cache, cnn_cache) # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index c0b76f08e..454f9c147 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -181,8 +181,7 @@ class MultiHeadedAttention(nn.Layer): # >>> torch.equal(d[0], d[1]) # True if paddle.shape(cache)[0] > 0: # last dim `d_k * 2` for (key, val) - key_cache, value_cache = paddle.split( - cache, paddle.shape(cache)[-1] // 2, axis=-1) + key_cache, value_cache = paddle.split(cache, 2, axis=-1) k = paddle.concat([key_cache, k], axis=2) v = paddle.concat([value_cache, v], axis=2) # We do cache slicing in encoder.forward_chunk, since it's @@ -289,8 +288,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True if paddle.shape(cache)[0] > 0: - key_cache, value_cache = paddle.split( - cache, paddle.shape(cache)[-1] // 2, axis=-1) + # last dim `d_k * 2` for (key, val) + key_cache, value_cache = paddle.split(cache, 2, axis=-1) k = paddle.concat([key_cache, k], axis=2) v = paddle.concat([value_cache, v], axis=2) # We do cache slicing in encoder.forward_chunk, since it's diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index e05d0cc45..72300579f 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -230,7 +230,8 @@ class BaseEncoder(nn.Layer): xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) # after embed, xs=(B=1, chunk_size, hidden-dim) - elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2] + elayers = paddle.shape(att_cache)[0] + cache_t1 = paddle.shape(att_cache)[2] chunk_size = paddle.shape(xs)[1] attention_key_size = cache_t1 + chunk_size diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 2bacfecd6..4df38f09d 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -130,9 +130,9 @@ class PaddleASRConnectionHanddler: ## conformer # cache for conformer online - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None + self.att_cache = paddle.zeros([0,0,0,0]) + self.cnn_cache = paddle.zeros([0,0,0,0]) + self.encoder_out = None # conformer decoding state self.offset = 0 # global offset in decoding frame unit @@ -474,11 +474,9 @@ class PaddleASRConnectionHanddler: # cur chunk chunk_xs = self.cached_feat[:, cur:end, :] # forward 
chunk - (y, self.subsampling_cache, self.elayers_output_cache, - self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk( chunk_xs, self.offset, required_cache_size, - self.subsampling_cache, self.elayers_output_cache, - self.conformer_cnn_cache) + self.att_cache, self.cnn_cache) outputs.append(y) # update the global offset, in decoding frame unit