Merge pull request #2168 from zh794390558/dy2st

[s2t] fix cnn cache dy2st shape
pull/2171/head
Jackwaterveg 2 years ago committed by GitHub
commit 7089c72a15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,6 +29,9 @@ import paddle
from paddle import jit from paddle import jit
from paddle import nn from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
@ -48,9 +51,6 @@ from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import log_add
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor, xs: paddle.Tensor,
offset: int, offset: int,
required_cache_size: int, required_cache_size: int,
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return """ Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk. output from time 0 to current chunk.
@ -625,13 +625,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
(elayers, head, cache_t1, d_k * 2), where (elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and `head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`. `cache_t1 == chunk_size * num_decoding_left_chunks`.
`d_k * 2` for att key & value. Default is 0-dims Tensor, `d_k * 2` for att key & value.
it is used for dy2st.
cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where (elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`. Default is 0-dims Tensor, `cache_t2 == cnn.lorder - 1`.
it is used for dy2st.
Returns: Returns:
paddle.Tensor: output of current input xs, paddle.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim). with shape (b=1, chunk_size, hidden-dim).
@ -641,8 +639,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
paddle.Tensor: new conformer cnn cache required for next chunk, with paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache. same shape as the original cnn_cache.
""" """
return self.encoder.forward_chunk( return self.encoder.forward_chunk(xs, offset, required_cache_size,
xs, offset, required_cache_size, att_cache, cnn_cache) att_cache, cnn_cache)
# @jit.to_static # @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:

@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool), mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features. """Compute encoded features.
Args: Args:
@ -105,9 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
x = self.norm1(x) x = self.norm1(x)
x_att, new_att_cache = self.self_attn( x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
x, x, x, mask, cache=att_cache
)
if self.concat_after: if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1) x_concat = paddle.concat((x, x_att), axis=-1)
@ -124,7 +122,7 @@ class TransformerEncoderLayer(nn.Layer):
if not self.normalize_before: if not self.normalize_before:
x = self.norm2(x) x = self.norm2(x)
fake_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype) fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
return x, mask, new_att_cache, fake_cnn_cache return x, mask, new_att_cache, fake_cnn_cache
@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool), mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features. """Compute encoded features.
Args: Args:
@ -211,7 +209,8 @@ class ConformerEncoderLayer(nn.Layer):
att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size. (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer cnn_cache (paddle.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2) (1, #batch=1, size, cache_t2). First dim will not be used, just
for dy2st.
Returns: Returns:
paddle.Tensor: Output tensor (#batch, time, size). paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time, time). paddle.Tensor: Mask tensor (#batch, time, time).
@ -219,6 +218,8 @@ class ConformerEncoderLayer(nn.Layer):
(#batch=1, head, cache_t1 + time, d_k * 2). (#batch=1, head, cache_t1 + time, d_k * 2).
paddle.Tensor: cnn_cahce tensor (#batch, size, cache_t2). paddle.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
""" """
# (1, #batch=1, size, cache_t2) -> (#batch=1, size, cache_t2)
cnn_cache = paddle.squeeze(cnn_cache, axis=0)
# whether to use macaron style FFN # whether to use macaron style FFN
if self.feed_forward_macaron is not None: if self.feed_forward_macaron is not None:
@ -249,8 +250,7 @@ class ConformerEncoderLayer(nn.Layer):
# convolution module # convolution module
# Fake new cnn cache here, and then change it in conv_module # Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype) new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
cnn_cache = paddle.squeeze(cnn_cache, axis=0)
if self.conv_module is not None: if self.conv_module is not None:
residual = x residual = x
if self.normalize_before: if self.normalize_before:
@ -275,4 +275,4 @@ class ConformerEncoderLayer(nn.Layer):
if self.conv_module is not None: if self.conv_module is not None:
x = self.norm_final(x) x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache return x, mask, new_att_cache, new_cnn_cache

Loading…
Cancel
Save