@@ -14,36 +14,49 @@
 # limitations under the License.
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """Encoder definition."""
-from typing import Tuple, Union, Optional, List
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import paddle
 from paddle import nn
 from typeguard import check_argument_types
 
 from paddlespeech.s2t.modules.activation import get_activation
-from paddlespeech.s2t.modules.align import LayerNorm, Linear
-from paddlespeech.s2t.modules.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention2
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Linear
+from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
+from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.convolution import ConvolutionModule2
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
 from paddlespeech.s2t.modules.embedding import RelPositionalEncoding
-from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer, SqueezeformerEncoderLayer
+from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer
+from paddlespeech.s2t.modules.encoder_layer import SqueezeformerEncoderLayer
 from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer
 from paddlespeech.s2t.modules.mask import add_optional_chunk_mask
 from paddlespeech.s2t.modules.mask import make_non_pad_mask
-from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward, PositionwiseFeedForward2
-from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4, TimeReductionLayerStream, TimeReductionLayer1D, \
-    DepthwiseConv2DSubsampling4, TimeReductionLayer2D
+from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
+from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward2
+from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4
 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6
 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8
+from paddlespeech.s2t.modules.subsampling import DepthwiseConv2DSubsampling4
 from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling
+from paddlespeech.s2t.modules.subsampling import TimeReductionLayer1D
+from paddlespeech.s2t.modules.subsampling import TimeReductionLayer2D
+from paddlespeech.s2t.modules.subsampling import TimeReductionLayerStream
 from paddlespeech.s2t.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder", "SqueezeformerEncoder"]
+__all__ = [
+    "BaseEncoder", 'TransformerEncoder', "ConformerEncoder",
+    "SqueezeformerEncoder"
+]
 
 
 class BaseEncoder(nn.Layer):
@@ -492,37 +505,35 @@ class ConformerEncoder(BaseEncoder):
 
 
 class SqueezeformerEncoder(nn.Layer):
-    def __init__(
-            self,
-            input_size: int,
-            encoder_dim: int = 256,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            num_blocks: int = 12,
-            reduce_idx: Optional[Union[int, List[int]]] = 5,
-            recover_idx: Optional[Union[int, List[int]]] = 11,
-            feed_forward_expansion_factor: int = 4,
-            dw_stride: bool = False,
-            input_dropout_rate: float = 0.1,
-            pos_enc_layer_type: str = "rel_pos",
-            time_reduction_layer_type: str = "conv1d",
-            do_rel_shift: bool = True,
-            feed_forward_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.1,
-            cnn_module_kernel: int = 31,
-            cnn_norm_type: str = "layer_norm",
-            dropout: float = 0.1,
-            causal: bool = False,
-            adaptive_scale: bool = True,
-            activation_type: str = "swish",
-            init_weights: bool = True,
-            global_cmvn: paddle.nn.Layer = None,
-            normalize_before: bool = False,
-            use_dynamic_chunk: bool = False,
-            concat_after: bool = False,
-            static_chunk_size: int = 0,
-            use_dynamic_left_chunk: bool = False
-    ):
+    def __init__(self,
+                 input_size: int,
+                 encoder_dim: int=256,
+                 output_size: int=256,
+                 attention_heads: int=4,
+                 num_blocks: int=12,
+                 reduce_idx: Optional[Union[int, List[int]]]=5,
+                 recover_idx: Optional[Union[int, List[int]]]=11,
+                 feed_forward_expansion_factor: int=4,
+                 dw_stride: bool=False,
+                 input_dropout_rate: float=0.1,
+                 pos_enc_layer_type: str="rel_pos",
+                 time_reduction_layer_type: str="conv1d",
+                 do_rel_shift: bool=True,
+                 feed_forward_dropout_rate: float=0.1,
+                 attention_dropout_rate: float=0.1,
+                 cnn_module_kernel: int=31,
+                 cnn_norm_type: str="layer_norm",
+                 dropout: float=0.1,
+                 causal: bool=False,
+                 adaptive_scale: bool=True,
+                 activation_type: str="swish",
+                 init_weights: bool=True,
+                 global_cmvn: paddle.nn.Layer=None,
+                 normalize_before: bool=False,
+                 use_dynamic_chunk: bool=False,
+                 concat_after: bool=False,
+                 static_chunk_size: int=0,
+                 use_dynamic_left_chunk: bool=False):
         """Construct SqueezeformerEncoder
 
         Args:
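
For orientation, a minimal usage sketch of the constructor shown above (the import path, 80-dim input features, and dummy shapes are illustrative assumptions, not part of this change):

    import paddle
    from paddlespeech.s2t.modules.encoder import SqueezeformerEncoder

    # Illustrative only: rely on the defaults from the signature above.
    encoder = SqueezeformerEncoder(input_size=80, encoder_dim=256, num_blocks=12)
    feats = paddle.randn([2, 100, 80])                        # (batch, time, feat) dummy fbank
    feats_lens = paddle.to_tensor([100, 80], dtype='int64')   # valid frames per utterance
    xs, masks = encoder(feats, feats_lens)                    # encoded frames and padding masks, per the forward signature below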
@@ -577,49 +588,40 @@ class SqueezeformerEncoder(nn.Layer):
         # self-attention module definition
         if pos_enc_layer_type != "rel_pos":
             encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (attention_heads,
-                                           output_size,
+            encoder_selfattn_layer_args = (attention_heads, output_size,
                                            attention_dropout_rate)
         else:
             encoder_selfattn_layer = RelPositionMultiHeadedAttention2
-            encoder_selfattn_layer_args = (attention_heads,
-                                           encoder_dim,
-                                           attention_dropout_rate,
-                                           do_rel_shift,
-                                           adaptive_scale,
-                                           init_weights)
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate, do_rel_shift,
+                                           adaptive_scale, init_weights)
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward2
-        positionwise_layer_args = (encoder_dim,
-                                   encoder_dim * feed_forward_expansion_factor,
-                                   feed_forward_dropout_rate,
-                                   activation,
-                                   adaptive_scale,
-                                   init_weights)
+        positionwise_layer_args = (
+            encoder_dim, encoder_dim * feed_forward_expansion_factor,
+            feed_forward_dropout_rate, activation, adaptive_scale, init_weights)
 
         # convolution module definition
         convolution_layer = ConvolutionModule2
         convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
-                                  cnn_norm_type, causal, True, adaptive_scale, init_weights)
+                                  cnn_norm_type, causal, True, adaptive_scale,
+                                  init_weights)
 
-        self.embed = DepthwiseConv2DSubsampling4(1, encoder_dim,
-                                                 RelPositionalEncoding(encoder_dim, dropout_rate=0.1),
-                                                 dw_stride,
-                                                 input_size,
-                                                 input_dropout_rate,
-                                                 init_weights)
+        self.embed = DepthwiseConv2DSubsampling4(
+            1, encoder_dim,
+            RelPositionalEncoding(encoder_dim, dropout_rate=0.1), dw_stride,
+            input_size, input_dropout_rate, init_weights)
 
         self.preln = LayerNorm(encoder_dim)
-        self.encoders = paddle.nn.LayerList([SqueezeformerEncoderLayer(
-            encoder_dim,
-            encoder_selfattn_layer(*encoder_selfattn_layer_args),
-            positionwise_layer(*positionwise_layer_args),
-            convolution_layer(*convolution_layer_args),
-            positionwise_layer(*positionwise_layer_args),
-            normalize_before,
-            dropout,
-            concat_after) for _ in range(num_blocks)
+        self.encoders = paddle.nn.LayerList([
+            SqueezeformerEncoderLayer(
+                encoder_dim,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                convolution_layer(*convolution_layer_args),
+                positionwise_layer(*positionwise_layer_args), normalize_before,
+                dropout, concat_after) for _ in range(num_blocks)
         ])
         if time_reduction_layer_type == 'conv1d':
             time_reduction_layer = TimeReductionLayer1D
@@ -637,7 +639,8 @@ class SqueezeformerEncoder(nn.Layer):
             time_reduction_layer = TimeReductionLayer2D
             time_reduction_layer_args = {'encoder_dim': encoder_dim}
 
-        self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args)
+        self.time_reduction_layer = time_reduction_layer(
+            **time_reduction_layer_args)
         self.time_recover_layer = Linear(encoder_dim, encoder_dim)
         self.final_proj = None
         if output_size != encoder_dim:
@@ -650,8 +653,8 @@ class SqueezeformerEncoder(nn.Layer):
             self,
             xs: paddle.Tensor,
             xs_lens: paddle.Tensor,
-            decoding_chunk_size: int = 0,
-            num_decoding_left_chunks: int = -1,
+            decoding_chunk_size: int=0,
+            num_decoding_left_chunks: int=-1,
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Embed positions in tensor.
         Args:
@@ -674,12 +677,10 @@ class SqueezeformerEncoder(nn.Layer):
             xs = self.global_cmvn(xs)
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = ~masks
-        chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+        chunk_masks = add_optional_chunk_mask(
+            xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
+            decoding_chunk_size, self.static_chunk_size,
+            num_decoding_left_chunks)
         xs_lens = chunk_masks.squeeze(1).sum(1)
         xs = self.preln(xs)
         recover_activations: \
@@ -688,15 +689,18 @@ class SqueezeformerEncoder(nn.Layer):
         for i, layer in enumerate(self.encoders):
             if self.reduce_idx is not None:
                 if self.time_reduce is not None and i in self.reduce_idx:
-                    recover_activations.append((xs, chunk_masks, pos_emb, mask_pad))
-                    xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad)
+                    recover_activations.append(
+                        (xs, chunk_masks, pos_emb, mask_pad))
+                    xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(
+                        xs, xs_lens, chunk_masks, mask_pad)
                     pos_emb = pos_emb[:, ::2, :]
                     index += 1
 
             if self.recover_idx is not None:
                 if self.time_reduce == 'recover' and i in self.recover_idx:
                     index -= 1
-                    recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[index]
+                    recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[
+                        index]
                     # recover output length for ctc decode
                     xs = paddle.repeat_interleave(xs, repeats=2, axis=1)
                     xs = self.time_recover_layer(xs)
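
The reduce/recover bookkeeping above halves the frame rate after the configured reduce layer and later restores it; a rough shape sketch of that round trip (sizes and the slicing stand-in are illustrative assumptions, not taken from this change):

    import paddle

    # Illustrative only: time reduction halves T, the recover branch repeats frames back.
    B, T, D = 2, 100, 256
    xs = paddle.randn([B, T, D])
    xs_half = xs[:, ::2, :]                                         # stand-in for the time reduction layer, (B, T//2, D)
    xs_back = paddle.repeat_interleave(xs_half, repeats=2, axis=1)  # (B, T, D) again, as in the recover branch
    assert xs_back.shape == [B, T, D]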
@@ -732,16 +736,16 @@ class SqueezeformerEncoder(nn.Layer):
                 for exp, rc_idx in enumerate(self.recover_idx):
                     if i >= rc_idx:
                         recover_exp = exp + 1
-            return int(2 ** (reduce_exp - recover_exp))
+            return int(2**(reduce_exp - recover_exp))
 
     def forward_chunk(
             self,
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
-            att_mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Forward just one chunk
 
@@ -786,7 +790,8 @@ class SqueezeformerEncoder(nn.Layer):
         elayers, cache_t1 = att_cache.shape[0], att_cache.shape[2]
         chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size
-        pos_emb = self.embed.position_encoding(offset=offset - cache_t1, size=attention_key_size)
+        pos_emb = self.embed.position_encoding(
+            offset=offset - cache_t1, size=attention_key_size)
         if required_cache_size < 0:
             next_cache_start = 0
         elif required_cache_size == 0:
@@ -811,15 +816,18 @@ class SqueezeformerEncoder(nn.Layer):
             # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
             if self.reduce_idx is not None:
                 if self.time_reduce is not None and i in self.reduce_idx:
-                    recover_activations.append((xs, att_mask, pos_emb, mask_pad))
-                    xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(xs, xs_lens, att_mask, mask_pad)
+                    recover_activations.append(
+                        (xs, att_mask, pos_emb, mask_pad))
+                    xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(
+                        xs, xs_lens, att_mask, mask_pad)
                     pos_emb = pos_emb[:, ::2, :]
                     index += 1
 
             if self.recover_idx is not None:
                 if self.time_reduce == 'recover' and i in self.recover_idx:
                     index -= 1
-                    recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[index]
+                    recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[
+                        index]
                     # recover output length for ctc decode
                     xs = paddle.repeat_interleave(xs, repeats=2, axis=1)
                     xs = self.time_recover_layer(xs)
@@ -830,7 +838,9 @@ class SqueezeformerEncoder(nn.Layer):
                     mask_pad = recover_mask_pad
 
             factor = self.calculate_downsampling_factor(i)
-            att_cache1 = att_cache[i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[1], :]
+            att_cache1 = att_cache[
+                i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[
+                    1], :]
             cnn_cache1 = cnn_cache[i] if cnn_cache.shape[0] > 0 else cnn_cache
             xs, _, new_att_cache, new_cnn_cache = layer(
                 xs,