merge classes as required, test=asr

pull/2755/head
yeyupiaoling 3 years ago
parent 1c156bfe4d
commit c1df5b7985

@@ -4,7 +4,7 @@
 cmvn_file:
 cmvn_file_type: "json"
 # encoder related
-encoder: conformer
+encoder: squeezeformer
 encoder_conf:
     encoder_dim: 256    # dimension of attention
     output_size: 256    # dimension of output
@@ -21,7 +21,8 @@ encoder_conf:
     normalize_before: false
     activation_type: 'swish'
     pos_enc_layer_type: 'rel_pos'
-    time_reduction_layer_type: 'conv2d'
+    do_rel_shift: false
+    time_reduction_layer_type: 'stream'
     causal: true
     use_dynamic_chunk: true
     use_dynamic_left_chunk: false
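
Taken together, the recipe now builds a squeezeformer encoder with the streaming time-reduction layer and relative shift disabled. A minimal sketch, not part of this commit, of how the changed keys read back once the config is parsed (the YAML fragment is abbreviated to the fields touched above):

import yaml

cfg = yaml.safe_load("""
encoder: squeezeformer
encoder_conf:
    pos_enc_layer_type: 'rel_pos'
    do_rel_shift: false
    time_reduction_layer_type: 'stream'
    causal: true
""")
assert cfg["encoder"] == "squeezeformer"
assert cfg["encoder_conf"]["time_reduction_layer_type"] == "stream"
assert cfg["encoder_conf"]["do_rel_shift"] is False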

@@ -203,7 +203,10 @@ class MultiHeadedAttention(nn.Layer):
 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
     """Multi-Head Attention layer with relative position encoding."""

-    def __init__(self, n_head, n_feat, dropout_rate):
+    def __init__(self, n_head, n_feat, dropout_rate,
+                 do_rel_shift=False,
+                 adaptive_scale=False,
+                 init_weights=False):
         """Construct an RelPositionMultiHeadedAttention object.
         Paper: https://arxiv.org/abs/1901.02860
         Args:
@@ -226,151 +229,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         pos_bias_v = self.create_parameter(
             (self.h, self.d_k), default_initializer=I.XavierUniform())
         self.add_parameter('pos_bias_v', pos_bias_v)
-
-    def rel_shift(self, x, zero_triu: bool=False):
-        """Compute relative positinal encoding.
-        Args:
-            x (paddle.Tensor): Input tensor (batch, head, time1, time1).
-            zero_triu (bool): If true, return the lower triangular part of
-                the matrix.
-        Returns:
-            paddle.Tensor: Output tensor. (batch, head, time1, time1)
-        """
-        zero_pad = paddle.zeros(
-            (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
-        x_padded = paddle.cat([zero_pad, x], dim=-1)
-        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
-                                 x.shape[2])
-        x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]
-        if zero_triu:
-            ones = paddle.ones((x.shape[2], x.shape[3]))
-            x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
-        return x
-
-    def forward(self,
-                query: paddle.Tensor,
-                key: paddle.Tensor,
-                value: paddle.Tensor,
-                mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
-                pos_emb: paddle.Tensor=paddle.empty([0]),
-                cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
-                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
-        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-        Args:
-            query (paddle.Tensor): Query tensor (#batch, time1, size).
-            key (paddle.Tensor): Key tensor (#batch, time2, size).
-            value (paddle.Tensor): Value tensor (#batch, time2, size).
-            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2), (0, 0, 0) means fake mask.
-            pos_emb (paddle.Tensor): Positional embedding tensor
-                (#batch, time2, size).
-            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
-                where `cache_t == chunk_size * num_decoding_left_chunks`
-                and `head * d_k == size`
-        Returns:
-            paddle.Tensor: Output tensor (#batch, time1, d_model).
-            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
-                where `cache_t == chunk_size * num_decoding_left_chunks`
-                and `head * d_k == size`
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
-
-        # when export onnx model, for 1st chunk, we feed
-        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
-        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
-        #   In all modes, `if cache.size(0) > 0` will alwayse be `True`
-        #   and we will always do splitting and
-        #   concatnation(this will simplify onnx export). Note that
-        #   it's OK to concat & split zero-shaped tensors(see code below).
-        #   when export jit model, for 1st chunk, we always feed
-        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
-        # >>> a = torch.ones((1, 2, 0, 4))
-        # >>> b = torch.ones((1, 2, 3, 4))
-        # >>> c = torch.cat((a, b), dim=2)
-        # >>> torch.equal(b, c)        # True
-        # >>> d = torch.split(a, 2, dim=-1)
-        # >>> torch.equal(d[0], d[1])  # True
-        if cache.shape[0] > 0:
-            # last dim `d_k * 2` for (key, val)
-            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
-            k = paddle.concat([key_cache, k], axis=2)
-            v = paddle.concat([value_cache, v], axis=2)
-        # We do cache slicing in encoder.forward_chunk, since it's
-        #   non-trivial to calculate `next_cache_start` here.
-        new_cache = paddle.concat((k, v), axis=-1)
-
-        n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-        p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
-
-        # (batch, head, time1, d_k)
-        # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
-        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
-        # (batch, head, time1, d_k)
-        # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
-        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
-        matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, time2)
-        # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
-        matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
-        # Remove rel_shift since it is useless in speech recognition,
-        #   and it requires special attention for streaming.
-        # matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k)  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask), new_cache
-
-
-class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
-    """Multi-Head Attention layer with relative position encoding.
-    Paper: https://arxiv.org/abs/1901.02860
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-    """
-
-    def __init__(self,
-                 n_head,
-                 n_feat,
-                 dropout_rate,
-                 do_rel_shift=False,
-                 adaptive_scale=False,
-                 init_weights=False):
-        """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate)
-        # linear transformation for positional encoding
-        self.linear_pos = Linear(n_feat, n_feat)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         self.do_rel_shift = do_rel_shift
-        pos_bias_u = self.create_parameter(
-            [self.h, self.d_k], default_initializer=I.XavierUniform())
-        self.add_parameter('pos_bias_u', pos_bias_u)
-        pos_bias_v = self.create_parameter(
-            [self.h, self.d_k], default_initializer=I.XavierUniform())
-        self.add_parameter('pos_bias_v', pos_bias_v)
         self.adaptive_scale = adaptive_scale
-        ada_scale = self.create_parameter(
-            [1, 1, n_feat], default_initializer=I.Constant(1.0))
-        self.add_parameter('ada_scale', ada_scale)
-        ada_bias = self.create_parameter(
-            [1, 1, n_feat], default_initializer=I.Constant(0.0))
-        self.add_parameter('ada_bias', ada_bias)
+        if self.adaptive_scale:
+            ada_scale = self.create_parameter(
+                [1, 1, n_feat], default_initializer=I.Constant(1.0))
+            self.add_parameter('ada_scale', ada_scale)
+            ada_bias = self.create_parameter(
+                [1, 1, n_feat], default_initializer=I.Constant(0.0))
+            self.add_parameter('ada_bias', ada_bias)

         if init_weights:
             self.init_weights()
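
The adaptive_scale flag added above gates a learnable per-feature affine that the Squeezeformer-style modules apply to their input before the usual computation. A self-contained sketch of that pattern, with the illustrative class name AdaScale (not from the codebase):

import paddle
from paddle.nn import initializer as I

class AdaScale(paddle.nn.Layer):
    """Learnable per-feature scale and bias, identity at initialization."""

    def __init__(self, n_feat: int):
        super().__init__()
        ada_scale = self.create_parameter(
            [1, 1, n_feat], default_initializer=I.Constant(1.0))
        self.add_parameter('ada_scale', ada_scale)
        ada_bias = self.create_parameter(
            [1, 1, n_feat], default_initializer=I.Constant(0.0))
        self.add_parameter('ada_bias', ada_bias)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # x: (batch, time, n_feat); broadcasting applies one scale/bias per feature
        return self.ada_scale * x + self.ada_bias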
@@ -407,12 +274,12 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
             paddle.Tensor: Output tensor. (batch, head, time1, time1)
         """
         zero_pad = paddle.zeros(
-            [x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype)
-        x_padded = paddle.concat([zero_pad, x], axis=-1)
-        x_padded = x_padded.reshape(
-            [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
-        x = x_padded[:, :, 1:].reshape(paddle.shape(x))  # [B, H, T1, T1]
+            (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
+        x_padded = paddle.cat([zero_pad, x], dim=-1)
+        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
+                                 x.shape[2])
+        x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]

         if zero_triu:
             ones = paddle.ones((x.shape[2], x.shape[3]))
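
The rel_shift being edited here is the Transformer-XL relative-shift trick: pad one zero column, fold the matrix so each row shifts by one, and drop the padded row. A hedged standalone sketch using only stock Paddle ops (the production code appears to rely on PaddleSpeech's torch-style compatibility patches such as paddle.cat and Tensor.view); rel_shift_sketch is an illustrative name:

import paddle

def rel_shift_sketch(x: paddle.Tensor) -> paddle.Tensor:
    # x: (batch, head, time1, time1) attention scores indexed by relative position
    b, h, t1, t2 = x.shape
    zero_pad = paddle.zeros([b, h, t1, 1], dtype=x.dtype)
    x_padded = paddle.concat([zero_pad, x], axis=-1)   # (b, h, t1, t2 + 1)
    x_padded = x_padded.reshape([b, h, t2 + 1, t1])    # fold so each row shifts by one
    return x_padded[:, :, 1:].reshape([b, h, t1, t2])  # drop the padded row

print(rel_shift_sketch(paddle.randn([2, 4, 8, 8])).shape)  # [2, 4, 8, 8]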
@@ -424,10 +291,10 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
+                mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
                 pos_emb: paddle.Tensor=paddle.empty([0]),
-                cache: paddle.Tensor=paddle.zeros(
-                    (0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]:
+                cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
@@ -452,17 +319,34 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
             value = self.ada_scale * value + self.ada_bias
         q, k, v = self.forward_qkv(query, key, value)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+
+        # when export onnx model, for 1st chunk, we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will alwayse be `True`
+        #   and we will always do splitting and
+        #   concatnation(this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors(see code below).
+        #   when export jit model, for 1st chunk, we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
         if cache.shape[0] > 0:
+            # last dim `d_k * 2` for (key, val)
             key_cache, value_cache = paddle.split(cache, 2, axis=-1)
             k = paddle.concat([key_cache, k], axis=2)
             v = paddle.concat([value_cache, v], axis=2)
-        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+        # We do cache slicing in encoder.forward_chunk, since it's
         #   non-trivial to calculate `next_cache_start` here.
         new_cache = paddle.concat((k, v), axis=-1)

         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).reshape(
-            [n_batch_pos, -1, self.h, self.d_k])
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)

         # (batch, head, time1, d_k)
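
The comments added above explain why keys and values travel in one cache tensor and why splitting and concatenating a zero-length cache is harmless for the first chunk. A hedged sketch of that layout with made-up sizes:

import paddle

head, d_k, chunk = 4, 64, 16
cache = paddle.zeros([1, head, 0, d_k * 2])        # empty cache fed for the 1st chunk
k = paddle.randn([1, head, chunk, d_k])
v = paddle.randn([1, head, chunk, d_k])

key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)          # unchanged when cache_t == 0
v = paddle.concat([value_cache, v], axis=2)
new_cache = paddle.concat([k, v], axis=-1)
print(new_cache.shape)                             # [1, 4, 16, 128]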

@@ -18,6 +18,7 @@ from typing import Tuple

 import paddle
 from paddle import nn
+from paddle.nn import initializer as I
 from typeguard import check_argument_types

 from paddlespeech.s2t.modules.align import BatchNorm1D
@@ -39,7 +40,9 @@ class ConvolutionModule(nn.Layer):
                  activation: nn.Layer=nn.ReLU(),
                  norm: str="batch_norm",
                  causal: bool=False,
-                 bias: bool=True):
+                 bias: bool=True,
+                 adaptive_scale: bool=False,
+                 init_weights: bool=False):
         """Construct an ConvolutionModule object.
         Args:
             channels (int): The number of channels of conv layers.
@@ -51,6 +54,19 @@
         """
         assert check_argument_types()
         super().__init__()
+        self.bias = bias
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.adaptive_scale = adaptive_scale
+        if self.adaptive_scale:
+            ada_scale = self.create_parameter(
+                [1, 1, channels], default_initializer=I.Constant(1.0))
+            self.add_parameter('ada_scale', ada_scale)
+            ada_bias = self.create_parameter(
+                [1, 1, channels], default_initializer=I.Constant(0.0))
+            self.add_parameter('ada_bias', ada_bias)
+
         self.pointwise_conv1 = Conv1D(
             channels,
             2 * channels,
@@ -105,6 +121,28 @@
         )
         self.activation = activation

+        if init_weights:
+            self.init_weights()
+
+    def init_weights(self):
+        pw_max = self.channels**-0.5
+        dw_max = self.kernel_size**-0.5
+        self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(
+            low=-pw_max, high=pw_max)
+        if self.bias:
+            self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(
+                low=-pw_max, high=pw_max)
+        self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(
+            low=-dw_max, high=dw_max)
+        if self.bias:
+            self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(
+                low=-dw_max, high=dw_max)
+        self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(
+            low=-pw_max, high=pw_max)
+        if self.bias:
+            self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(
+                low=-pw_max, high=pw_max)
+
     def forward(
             self,
             x: paddle.Tensor,
@@ -123,6 +161,9 @@
             paddle.Tensor: Output tensor (#batch, time, channels).
             paddle.Tensor: Output cache tensor (#batch, channels, time')
         """
+        if self.adaptive_scale:
+            x = self.ada_scale * x + self.ada_bias
+
         # exchange the temporal dimension and the feature dimension
         x = x.transpose([0, 2, 1])  # [B, C, T]
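
The init_weights bounds above amount to uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)), with fan_in equal to channels for the 1x1 pointwise convs and kernel_size for the depthwise conv. A small sketch of the resulting ranges, with sizes chosen only for illustration:

channels, kernel_size = 256, 31
pw_max = channels ** -0.5     # 0.0625: bound for the pointwise conv weights
dw_max = kernel_size ** -0.5  # ~0.18: bound for the depthwise conv weights
print(pw_max, dw_max)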

@@ -1,181 +0,0 @@
-from typing import Tuple
-
-import paddle
-from paddle import nn
-from paddle.nn import initializer as I
-from typeguard import check_argument_types
-
-__all__ = ['ConvolutionModule2']
-
-from paddlespeech.s2t import masked_fill
-from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm
-
-
-class ConvolutionModule2(nn.Layer):
-    """ConvolutionModule in Conformer model."""
-
-    def __init__(self,
-                 channels: int,
-                 kernel_size: int=15,
-                 activation: nn.Layer=nn.ReLU(),
-                 norm: str="batch_norm",
-                 causal: bool=False,
-                 bias: bool=True,
-                 adaptive_scale: bool=False,
-                 init_weights: bool=False):
-        """Construct an ConvolutionModule object.
-        Args:
-            channels (int): The number of channels of conv layers.
-            kernel_size (int): Kernel size of conv layers.
-            causal (int): Whether use causal convolution or not
-        """
-        assert check_argument_types()
-        super().__init__()
-        self.bias = bias
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.adaptive_scale = adaptive_scale
-        ada_scale = self.create_parameter(
-            [1, 1, channels], default_initializer=I.Constant(1.0))
-        self.add_parameter('ada_scale', ada_scale)
-        ada_bias = self.create_parameter(
-            [1, 1, channels], default_initializer=I.Constant(0.0))
-        self.add_parameter('ada_bias', ada_bias)
-
-        self.pointwise_conv1 = Conv1D(
-            channels,
-            2 * channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias_attr=None
-            if bias else False,  # None for True, using bias as default config
-        )
-
-        # self.lorder is used to distinguish if it's a causal convolution,
-        # if self.lorder > 0: it's a causal convolution, the input will be
-        #    padded with self.lorder frames on the left in forward.
-        # else: it's a symmetrical convolution
-        if causal:
-            padding = 0
-            self.lorder = kernel_size - 1
-        else:
-            # kernel_size should be an odd number for none causal convolution
-            assert (kernel_size - 1) % 2 == 0
-            padding = (kernel_size - 1) // 2
-            self.lorder = 0
-
-        self.depthwise_conv = Conv1D(
-            channels,
-            channels,
-            kernel_size,
-            stride=1,
-            padding=padding,
-            groups=channels,
-            bias_attr=None
-            if bias else False,  # None for True, using bias as default config
-        )
-
-        assert norm in ['batch_norm', 'layer_norm']
-        if norm == "batch_norm":
-            self.use_layer_norm = False
-            self.norm = BatchNorm1D(channels)
-        else:
-            self.use_layer_norm = True
-            self.norm = LayerNorm(channels)
-
-        self.pointwise_conv2 = Conv1D(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias_attr=None
-            if bias else False,  # None for True, using bias as default config
-        )
-        self.activation = activation
-
-        if init_weights:
-            self.init_weights()
-
-    def init_weights(self):
-        pw_max = self.channels**-0.5
-        dw_max = self.kernel_size**-0.5
-        self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(
-            low=-pw_max, high=pw_max)
-        if self.bias:
-            self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(
-                low=-pw_max, high=pw_max)
-        self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(
-            low=-dw_max, high=dw_max)
-        if self.bias:
-            self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(
-                low=-dw_max, high=dw_max)
-        self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(
-            low=-pw_max, high=pw_max)
-        if self.bias:
-            self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(
-                low=-pw_max, high=pw_max)
-
-    def forward(
-            self,
-            x: paddle.Tensor,
-            mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
-            cache: paddle.Tensor=paddle.zeros([0, 0, 0]),
-    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
-        """Compute convolution module.
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, channels).
-            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
-                (0, 0, 0) means fake mask.
-            cache (torch.Tensor): left context cache, it is only
-                used in causal convolution (#batch, channels, cache_t),
-                (0, 0, 0) meas fake cache.
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, channels).
-        """
-        if self.adaptive_scale:
-            x = self.ada_scale * x + self.ada_bias
-
-        # exchange the temporal dimension and the feature dimension
-        x = x.transpose([0, 2, 1])  # [B, C, T]
-
-        # mask batch padding
-        if mask_pad.shape[2] > 0:  # time > 0
-            x = masked_fill(x, mask_pad, 0.0)
-
-        if self.lorder > 0:
-            if cache.shape[2] == 0:  # cache_t == 0
-                x = nn.functional.pad(
-                    x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
-            else:
-                assert cache.shape[0] == x.shape[0]  # B
-                assert cache.shape[1] == x.shape[1]  # C
-                x = paddle.concat((cache, x), axis=2)
-            assert (x.shape[2] > self.lorder)
-            new_cache = x[:, :, -self.lorder:]  # [B, C, T]
-        else:
-            # It's better we just return None if no cache is required,
-            # However, for JIT export, here we just fake one tensor instead of
-            # None.
-            new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
-
-        # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
-        x = nn.functional.glu(x, axis=1)  # (batch, channel, dim)
-
-        # 1D Depthwise Conv
-        x = self.depthwise_conv(x)
-        if self.use_layer_norm:
-            x = x.transpose([0, 2, 1])  # [B, T, C]
-        x = self.activation(self.norm(x))
-        if self.use_layer_norm:
-            x = x.transpose([0, 2, 1])  # [B, C, T]
-        x = self.pointwise_conv2(x)
-        # mask batch padding
-        if mask_pad.shape[2] > 0:  # time > 0
-            x = masked_fill(x, mask_pad, 0.0)
-
-        x = x.transpose([0, 2, 1])  # [B, T, C]
-        return x, new_cache
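
The deleted module's forward, like the kept ConvolutionModule, handles causal convolution by left-padding lorder = kernel_size - 1 frames for the first chunk and carrying the last lorder frames forward as the next chunk's cache. A hedged sketch of that bookkeeping with illustrative sizes:

import paddle
import paddle.nn.functional as F

kernel_size = 15
lorder = kernel_size - 1
x = paddle.randn([1, 256, 16])                    # [B, C, T], one 16-frame chunk
x = F.pad(x, [lorder, 0], mode='constant', value=0.0, data_format='NCL')
new_cache = x[:, :, -lorder:]                     # last lorder frames feed the next chunk
print(x.shape, new_cache.shape)                   # [1, 256, 30] [1, 256, 14]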

@@ -30,7 +30,6 @@ from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
-from paddlespeech.s2t.modules.convolution import ConvolutionModule2
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
 from paddlespeech.s2t.modules.embedding import RelPositionalEncoding
@@ -40,7 +39,6 @@ from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer
 from paddlespeech.s2t.modules.mask import add_optional_chunk_mask
 from paddlespeech.s2t.modules.mask import make_non_pad_mask
 from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
-from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward2
 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4
 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6
 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8
@@ -591,19 +589,19 @@ class SqueezeformerEncoder(nn.Layer):
             encoder_selfattn_layer_args = (attention_heads, output_size,
                                            attention_dropout_rate)
         else:
-            encoder_selfattn_layer = RelPositionMultiHeadedAttention2
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate, do_rel_shift,
                                            adaptive_scale, init_weights)

         # feed-forward module definition
-        positionwise_layer = PositionwiseFeedForward2
+        positionwise_layer = PositionwiseFeedForward
         positionwise_layer_args = (
             encoder_dim, encoder_dim * feed_forward_expansion_factor,
             feed_forward_dropout_rate, activation, adaptive_scale, init_weights)

         # convolution module definition
-        convolution_layer = ConvolutionModule2
+        convolution_layer = ConvolutionModule
         convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                   cnn_norm_type, causal, True, adaptive_scale,
                                   init_weights)
@@ -676,7 +674,7 @@ class SqueezeformerEncoder(nn.Layer):
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
         xs, pos_emb, masks = self.embed(xs, masks)
-        mask_pad = ~masks
+        mask_pad = masks
         chunk_masks = add_optional_chunk_mask(
             xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
             decoding_chunk_size, self.static_chunk_size,
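
With the *2 classes gone, the squeezeformer branch builds its sub-layers from the shared classes and simply passes the squeezeformer-only flags, which default to False at conformer call sites. A hedged usage sketch, assuming paddlespeech is installed; the argument values are illustrative:

from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention

# (attention_heads, encoder_dim, dropout, do_rel_shift, adaptive_scale, init_weights)
encoder_selfattn_layer_args = (4, 256, 0.1, False, True, True)
attn = RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args)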

@@ -23,7 +23,7 @@ from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

-__all__ = ["PositionwiseFeedForward", "PositionwiseFeedForward2"]
+__all__ = ["PositionwiseFeedForward"]


 class PositionwiseFeedForward(nn.Layer):
@@ -33,7 +33,9 @@ class PositionwiseFeedForward(nn.Layer):
                  idim: int,
                  hidden_units: int,
                  dropout_rate: float,
-                 activation: nn.Layer=nn.ReLU()):
+                 activation: nn.Layer=nn.ReLU(),
+                 adaptive_scale: bool=False,
+                 init_weights: bool=False):
         """Construct a PositionwiseFeedForward object.

         FeedForward are appied on each position of the sequence.
@@ -46,48 +48,11 @@ class PositionwiseFeedForward(nn.Layer):
             activation (paddle.nn.Layer): Activation function
         """
         super().__init__()
-        self.w_1 = Linear(idim, hidden_units)
-        self.activation = activation
-        self.dropout = nn.Dropout(dropout_rate)
-        self.w_2 = Linear(hidden_units, idim)
-
-    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
-        """Forward function.
-        Args:
-            xs: input tensor (B, Lmax, D)
-        Returns:
-            output tensor, (B, Lmax, D)
-        """
-        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
-
-
-class PositionwiseFeedForward2(paddle.nn.Layer):
-    """Positionwise feed forward layer.
-
-    FeedForward are appied on each position of the sequence.
-    The output dim is same with the input dim.
-
-    Args:
-        idim (int): Input dimenstion.
-        hidden_units (int): The number of hidden units.
-        dropout_rate (float): Dropout rate.
-        activation (paddle.nn.Layer): Activation function
-    """
-
-    def __init__(self,
-                 idim: int,
-                 hidden_units: int,
-                 dropout_rate: float,
-                 activation: paddle.nn.Layer=paddle.nn.ReLU(),
-                 adaptive_scale: bool=False,
-                 init_weights: bool=False):
-        """Construct a PositionwiseFeedForward object."""
-        super(PositionwiseFeedForward2, self).__init__()
         self.idim = idim
         self.hidden_units = hidden_units
         self.w_1 = Linear(idim, hidden_units)
         self.activation = activation
-        self.dropout = paddle.nn.Dropout(dropout_rate)
+        self.dropout = nn.Dropout(dropout_rate)
         self.w_2 = Linear(hidden_units, idim)
         self.adaptive_scale = adaptive_scale
         ada_scale = self.create_parameter(
@@ -114,12 +79,9 @@ class PositionwiseFeedForward2(paddle.nn.Layer):
     def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
         """Forward function.
         Args:
-            xs: input tensor (B, L, D)
+            xs: input tensor (B, Lmax, D)
         Returns:
-            output tensor, (B, L, D)
+            output tensor, (B, Lmax, D)
         """
-        if self.adaptive_scale:
-            xs = self.ada_scale * xs + self.ada_bias
         return self.w_2(self.dropout(self.activation(self.w_1(xs))))
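
The feed-forward path itself stays Linear -> activation -> dropout -> Linear; when the adaptive scale is used (as in the removed class), the learnable affine is applied to the input first. A hedged standalone sketch of that data flow with stock Paddle layers and illustrative sizes:

import paddle
import paddle.nn as nn

idim, hidden_units = 256, 1024
w_1, w_2 = nn.Linear(idim, hidden_units), nn.Linear(hidden_units, idim)
activation, dropout = nn.ReLU(), nn.Dropout(0.1)
ada_scale = paddle.ones([1, 1, idim])              # identity affine at init
ada_bias = paddle.zeros([1, 1, idim])

xs = paddle.randn([2, 50, idim])
xs = ada_scale * xs + ada_bias                     # adaptive scale step
out = w_2(dropout(activation(w_1(xs))))
print(out.shape)                                   # [2, 50, 256]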
