@@ -24,6 +24,7 @@ from deepspeech.modules.activation import get_activation
 from deepspeech.modules.attention import MultiHeadedAttention
 from deepspeech.modules.attention import RelPositionMultiHeadedAttention
 from deepspeech.modules.conformer_convolution import ConvolutionModule
+from deepspeech.modules.embedding import NoPositionalEncoding
 from deepspeech.modules.embedding import PositionalEncoding
 from deepspeech.modules.embedding import RelPositionalEncoding
 from deepspeech.modules.encoder_layer import ConformerEncoderLayer
@@ -101,6 +102,8 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "None":
+            pos_enc_class = NoPositionalEncoding
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -155,11 +158,11 @@ class BaseEncoder(nn.Layer):
             encoder output tensor, lens and mask
         """
         masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)
 
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
         #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
         xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
+        #print("xs", xs)
         #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
         masks = masks.astype(paddle.bool)
         #TODO(Hui Zhang): mask_pad = ~masks
@@ -168,8 +171,15 @@ class BaseEncoder(nn.Layer):
             xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
             decoding_chunk_size, self.static_chunk_size,
             num_decoding_left_chunks)
+        #print ("chunk_masks", chunk_masks)
+        i = 0
         for layer in self.encoders:
-            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+            if i == 3:
+                xs, chunk_masks, _ = layer(
+                    xs, chunk_masks, pos_emb, mask_pad, is_print=True)
+            else:
+                xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+            i += 1
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
@@ -248,6 +258,8 @@ class BaseEncoder(nn.Layer):
                 i]
             cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[
                 i]
+            #print ("i", i)
+            #print ("xs", xs)
             xs, _, new_cnn_cache = layer(
                 xs,
                 masks,
@@ -370,6 +382,80 @@ class TransformerEncoder(BaseEncoder):
                 concat_after=concat_after) for _ in range(num_blocks)
         ])
 
+    def forward_one_step(
+            self,
+            xs: paddle.Tensor,
+            required_cache_size: int,
+            state=(None, None, 0),
+    ) -> Tuple[paddle.Tensor, Tuple[paddle.Tensor, List[paddle.Tensor], int]]:
+        """ Forward just one chunk.
+        Args:
+            xs (paddle.Tensor): chunk input, [B=1, T, D]
+            required_cache_size (int): cache size required for next chunk
+                computation
+                >=0: actual cache size
+                <0: means all history cache is required
+            state (Tuple): decoding state carried over from the previous chunk,
+                (subsampling_cache, elayers_output_cache, offset):
+                subsampling_cache (Optional[paddle.Tensor]): subsampling cache
+                elayers_output_cache (Optional[List[paddle.Tensor]]):
+                    transformer/conformer encoder layers output cache
+                offset (int): current offset in encoder output time stamp
+        Returns:
+            paddle.Tensor: output of current input xs
+            Tuple: new state for the next chunk computation,
+                (subsampling_cache, elayers_output_cache, offset)
+        """
+        assert xs.shape[0] == 1  # batch size must be one
+        # tmp_masks is just for interface compatibility
+        # TODO(Hui Zhang): stride_slice not support bool tensor
+        # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
+        subsampling_cache, elayers_output_cache, offset = state
+        tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
+        tmp_masks = tmp_masks.unsqueeze(1)  #[B=1, C=1, T]
+
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+
+        xs, pos_emb, _ = self.embed(
+            xs, tmp_masks, offset=offset)  #xs=(B, T, D), pos_emb=(B=1, T, D)
+
+        if subsampling_cache is not None:
+            cache_size = subsampling_cache.shape[1]  #T
+            xs = paddle.cat((subsampling_cache, xs), dim=1)
+        else:
+            cache_size = 0
+
+        # only used when using `RelPositionMultiHeadedAttention`
+        pos_emb = self.embed.position_encoding(
+            offset=offset - cache_size, size=xs.shape[1])
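        # NOTE (descriptive): the position encoding is regenerated here so that
        # it covers the cached frames plus the new chunk (hence offset -
        # cache_size and size=xs.shape[1]); per the comment above, it only
        # matters for the relative position attention variant.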
+
+        if required_cache_size < 0:
+            next_cache_start = 0
+        elif required_cache_size == 0:
+            next_cache_start = xs.shape[1]
+        else:
+            next_cache_start = xs.shape[1] - required_cache_size
+        r_subsampling_cache = xs[:, next_cache_start:, :]
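        # NOTE (illustrative example): with a concatenated length of
        # xs.shape[1] == 16 and required_cache_size == 4, next_cache_start is
        # 16 - 4 = 12, so r_subsampling_cache keeps the last 4 frames for the
        # next call; required_cache_size < 0 keeps the whole history, and == 0
        # keeps nothing.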
+
+        # Real mask for transformer/conformer layers
+        masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
+        masks = masks.unsqueeze(1)  #[B=1, L'=1, T]
+        r_elayers_output_cache = []
+        for i, layer in enumerate(self.encoders):
+            attn_cache = None if elayers_output_cache is None else elayers_output_cache[
+                i]
+            xs, _, _ = layer(
+                xs, masks, pos_emb, output_cache=attn_cache, cnn_cache=None)
+            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        new_state = (r_subsampling_cache, r_elayers_output_cache, offset + 1)
+        return (xs[:, cache_size:, :], new_state)
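
    # A minimal chunk-by-chunk decoding sketch for forward_one_step (an
    # illustrative example, not part of this module; `encoder`, `feats` and
    # `chunk_size` are assumed names, with feats shaped [B=1, T, D]):
    #
    #     state = (None, None, 0)
    #     outputs = []
    #     for start in range(0, feats.shape[1], chunk_size):
    #         chunk = feats[:, start:start + chunk_size, :]
    #         ys, state = encoder.forward_one_step(
    #             chunk, required_cache_size=16, state=state)
    #         outputs.append(ys)
    #     result = paddle.concat(outputs, axis=1)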
+
 
 class ConformerEncoder(BaseEncoder):
     """Conformer encoder module."""