merge classes as required, test=asr

pull/2755/head
yeyupiaoling 3 years ago
parent 1c156bfe4d
commit c1df5b7985

@@ -4,7 +4,7 @@
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder: squeezeformer
encoder_conf:
encoder_dim: 256 # dimension of attention
output_size: 256 # dimension of output
@@ -21,7 +21,8 @@ encoder_conf:
normalize_before: false
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
time_reduction_layer_type: 'conv2d'
do_rel_shift: false
time_reduction_layer_type: 'stream'
causal: true
use_dynamic_chunk: true
use_dynamic_left_chunk: false
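For orientation: the recipe now selects the Squeezeformer encoder and its streaming options (`time_reduction_layer_type: 'stream'`, `causal: true`). Below is a minimal, hypothetical sketch of loading such a YAML block and dispatching on the `encoder` key; the `ENCODERS` map and the class names in it are illustrative placeholders, not the PaddleSpeech API.

```python
# Hypothetical sketch: load a config like the one above and pick an encoder
# class by name. Only the keys shown in the diff are assumed to exist.
import yaml

config_text = """
encoder: squeezeformer
encoder_conf:
  encoder_dim: 256            # dimension of attention
  output_size: 256            # dimension of output
  time_reduction_layer_type: stream
  causal: true
  use_dynamic_chunk: true
  use_dynamic_left_chunk: false
"""

cfg = yaml.safe_load(config_text)

ENCODERS = {
    "conformer": "ConformerEncoder",        # placeholder names only
    "squeezeformer": "SqueezeformerEncoder",
}

encoder_cls_name = ENCODERS[cfg["encoder"]]
print(encoder_cls_name, cfg["encoder_conf"])
```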

@@ -203,7 +203,10 @@ class MultiHeadedAttention(nn.Layer):
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding."""
def __init__(self, n_head, n_feat, dropout_rate):
def __init__(self, n_head, n_feat, dropout_rate,
do_rel_shift=False,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
@@ -226,151 +229,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
pos_bias_v = self.create_parameter(
(self.h, self.d_k), default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_v', pos_bias_v)
def rel_shift(self, x, zero_triu: bool=False):
"""Compute relative positinal encoding.
Args:
x (paddle.Tensor): Input tensor (batch, head, time1, time1).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
zero_pad = paddle.zeros(
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu:
ones = paddle.ones((x.shape[2], x.shape[3]))
x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
return x
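Because the pad/reshape/slice in `rel_shift` is easy to misread, here is a small self-contained illustration on a toy tensor, using plain Paddle ops; the printed values are only for demonstration.

```python
# Toy demonstration of the rel_shift pad/reshape/slice trick.
# After the shift, out[b, h, i, j] == x[b, h, i, j - i + (T - 1)] for j <= i;
# entries above the diagonal are artifacts, which zero_triu would mask away.
import paddle

T = 3
x = paddle.arange(1, T * T + 1).reshape([1, 1, T, T])

zero_pad = paddle.zeros([1, 1, T, 1], dtype=x.dtype)
x_padded = paddle.concat([zero_pad, x], axis=-1)   # (1, 1, T, T + 1)
x_padded = x_padded.reshape([1, 1, T + 1, T])      # skewed view
out = x_padded[:, :, 1:].reshape([1, 1, T, T])     # drop the first row

print(x.numpy()[0, 0])    # rows: [1 2 3], [4 5 6], [7 8 9]
print(out.numpy()[0, 0])  # rows: [3 0 4], [5 6 0], [7 8 9]
```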
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors (see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
# q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
# (batch, head, time1, d_k)
# q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
# matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
# compute matrix b and matrix d
# (batch, head, time1, time2)
# matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
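The comment block above keeps the original torch snippets. Here is a minimal Paddle sketch of the same point, assuming Paddle handles zero-sized cache dimensions the way the ONNX-export comment claims, which is exactly what the `if cache.shape[0] > 0` path relies on for the first chunk.

```python
# Sketch: concatenating with a zero-length cache is a no-op, and splitting the
# last dim in two recovers the (key, value) halves -- the behaviour the ONNX
# export comment above describes with torch.
import numpy as np
import paddle

empty_cache = paddle.zeros([1, 2, 0, 4])      # (batch, head, cache_t=0, d_k*2)
kv = paddle.ones([1, 2, 3, 4])                # pretend new (key|value) chunk

merged = paddle.concat([empty_cache, kv], axis=2)
print(merged.shape)                           # [1, 2, 3, 4] -> unchanged

k, v = paddle.split(merged, 2, axis=-1)       # last dim holds key and value
print(k.shape, v.shape)                       # [1, 2, 3, 2] [1, 2, 3, 2]
print(np.array_equal(merged.numpy(), kv.numpy()))  # True
```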
class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head,
n_feat,
dropout_rate,
do_rel_shift=False,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = Linear(n_feat, n_feat)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.do_rel_shift = do_rel_shift
pos_bias_u = self.create_parameter(
[self.h, self.d_k], default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_u', pos_bias_u)
pos_bias_v = self.create_parameter(
[self.h, self.d_k], default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_v', pos_bias_v)
self.adaptive_scale = adaptive_scale
ada_scale = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
if init_weights:
self.init_weights()
@@ -407,12 +274,12 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
zero_pad = paddle.zeros(
[x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype)
x_padded = paddle.concat([zero_pad, x], axis=-1)
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.reshape(
[x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1]
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu:
ones = paddle.ones((x.shape[2], x.shape[3]))
@@ -424,10 +291,10 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros(
(0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]:
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
@@ -452,17 +319,34 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
value = self.ada_scale * value + self.ada_bias
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors (see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).reshape(
[n_batch_pos, -1, self.h, self.d_k])
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
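To make the retained score computation concrete: the merged class keeps the Transformer-XL style `matrix_ac + matrix_bd` decomposition. A small sketch with random tensors, using only ops that appear in the diff; the shapes are illustrative, not the recipe's.

```python
# Sketch of the relative-position attention score used above:
# scores = ((q + u) k^T + (q + v) p^T) / sqrt(d_k), per head.
import math
import paddle

batch, head, t, d_k = 2, 4, 5, 16
q = paddle.randn([batch, head, t, d_k])       # queries after forward_qkv
k = paddle.randn([batch, head, t, d_k])       # keys
p = paddle.randn([batch, head, t, d_k])       # projected positional embedding
u = paddle.randn([head, d_k])                 # pos_bias_u
v_bias = paddle.randn([head, d_k])            # pos_bias_v

q_with_bias_u = q + u.unsqueeze(1)            # broadcast over time
q_with_bias_v = q + v_bias.unsqueeze(1)

matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
scores = (matrix_ac + matrix_bd) / math.sqrt(d_k)
print(scores.shape)                           # [2, 4, 5, 5]
```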

@@ -18,6 +18,7 @@ from typing import Tuple
import paddle
from paddle import nn
from paddle.nn import initializer as I
from typeguard import check_argument_types
from paddlespeech.s2t.modules.align import BatchNorm1D
@@ -39,7 +40,9 @@ class ConvolutionModule(nn.Layer):
activation: nn.Layer=nn.ReLU(),
norm: str="batch_norm",
causal: bool=False,
bias: bool=True):
bias: bool=True,
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
@@ -51,6 +54,19 @@
"""
assert check_argument_types()
super().__init__()
self.bias = bias
self.channels = channels
self.kernel_size = kernel_size
self.adaptive_scale = adaptive_scale
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
self.pointwise_conv1 = Conv1D(
channels,
2 * channels,
@@ -105,6 +121,28 @@
)
self.activation = activation
if init_weights:
self.init_weights()
def init_weights(self):
pw_max = self.channels**-0.5
dw_max = self.kernel_size**-0.5
self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
if self.bias:
self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
def forward(
self,
x: paddle.Tensor,
@@ -123,6 +161,9 @@ class ConvolutionModule(nn.Layer):
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time')
"""
if self.adaptive_scale:
x = self.ada_scale * x + self.ada_bias
# exchange the temporal dimension and the feature dimension
x = x.transpose([0, 2, 1]) # [B, C, T]
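The additions to `paddlespeech.s2t.modules.conformer_convolution` are an optional learnable pre-scale (`ada_scale`/`ada_bias`, initialised to 1 and 0 so the module starts as an identity on its input) and a uniform re-initialisation bounded by `channels**-0.5` and `kernel_size**-0.5`. A stand-alone sketch of just the adaptive-scale part:

```python
# Minimal sketch of the adaptive-scale preprocessing added above:
# y = ada_scale * x + ada_bias, with per-channel parameters that start
# at (1, 0) so the untrained module behaves like an identity.
import paddle
from paddle import nn
from paddle.nn import initializer as I


class AdaScale(nn.Layer):
    def __init__(self, channels: int):
        super().__init__()
        ada_scale = self.create_parameter(
            [1, 1, channels], default_initializer=I.Constant(1.0))
        self.add_parameter('ada_scale', ada_scale)
        ada_bias = self.create_parameter(
            [1, 1, channels], default_initializer=I.Constant(0.0))
        self.add_parameter('ada_bias', ada_bias)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # x: (batch, time, channels)
        return self.ada_scale * x + self.ada_bias


x = paddle.randn([2, 10, 8])
print(bool(paddle.allclose(AdaScale(8)(x), x)))   # True at initialisation
```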

@@ -1,181 +0,0 @@
from typing import Tuple
import paddle
from paddle import nn
from paddle.nn import initializer as I
from typeguard import check_argument_types
__all__ = ['ConvolutionModule2']
from paddlespeech.s2t import masked_fill
from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm
class ConvolutionModule2(nn.Layer):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int=15,
activation: nn.Layer=nn.ReLU(),
norm: str="batch_norm",
causal: bool=False,
bias: bool=True,
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (bool): Whether to use causal convolution or not
"""
assert check_argument_types()
super().__init__()
self.bias = bias
self.channels = channels
self.kernel_size = kernel_size
self.adaptive_scale = adaptive_scale
ada_scale = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
self.pointwise_conv1 = Conv1D(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias_attr=None
if bias else False, # None for True, using bias as default config
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = Conv1D(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias_attr=None
if bias else False, # None for True, using bias as default config
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = BatchNorm1D(channels)
else:
self.use_layer_norm = True
self.norm = LayerNorm(channels)
self.pointwise_conv2 = Conv1D(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias_attr=None
if bias else False, # None for True, using bias as default config
)
self.activation = activation
if init_weights:
self.init_weights()
def init_weights(self):
pw_max = self.channels**-0.5
dw_max = self.kernel_size**-0.5
self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
if self.bias:
self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
def forward(
self,
x: paddle.Tensor,
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
cache: paddle.Tensor=paddle.zeros([0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) means fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
if self.adaptive_scale:
x = self.ada_scale * x + self.ada_bias
# exchange the temporal dimension and the feature dimension
x = x.transpose([0, 2, 1]) # [B, C, T]
# mask batch padding
if mask_pad.shape[2] > 0: # time > 0
x = masked_fill(x, mask_pad, 0.0)
if self.lorder > 0:
if cache.shape[2] == 0: # cache_t == 0
x = nn.functional.pad(
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else:
assert cache.shape[0] == x.shape[0] # B
assert cache.shape[1] == x.shape[1] # C
x = paddle.concat((cache, x), axis=2)
assert (x.shape[2] > self.lorder)
new_cache = x[:, :, -self.lorder:] # [B, C, T]
else:
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, axis=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose([0, 2, 1]) # [B, T, C]
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose([0, 2, 1]) # [B, C, T]
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.shape[2] > 0: # time > 0
x = masked_fill(x, mask_pad, 0.0)
x = x.transpose([0, 2, 1]) # [B, T, C]
return x, new_cache
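The deleted `ConvolutionModule2` duplicates logic that now lives in `ConvolutionModule`; the piece worth keeping in mind is the causal left-context cache (`lorder = kernel_size - 1`). Below is a toy sketch of that bookkeeping with a plain depthwise `Conv1D`; names and shapes are illustrative only, not the project's module.

```python
# Toy sketch of the causal left-context cache used by the convolution module:
# pad the first chunk with zeros on the left, then carry the last
# (kernel_size - 1) frames forward as the cache for the next chunk.
import paddle
from paddle import nn

channels, kernel_size = 4, 5
lorder = kernel_size - 1                      # causal: all padding on the left
conv = nn.Conv1D(channels, channels, kernel_size, groups=channels, padding=0)

def forward_chunk(x, cache):
    # x: (batch, channels, time); cache: (batch, channels, cache_t)
    if cache.shape[2] == 0:                   # first chunk: zero left-padding
        x = nn.functional.pad(x, [lorder, 0], 'constant', 0.0, data_format='NCL')
    else:                                     # later chunks: prepend the cache
        x = paddle.concat([cache, x], axis=2)
    new_cache = x[:, :, -lorder:]             # left context for the next chunk
    return conv(x), new_cache

chunk1 = paddle.randn([1, channels, 8])
chunk2 = paddle.randn([1, channels, 8])
y1, cache = forward_chunk(chunk1, paddle.zeros([0, 0, 0]))
y2, cache = forward_chunk(chunk2, cache)
print(y1.shape, y2.shape, cache.shape)        # [1, 4, 8] [1, 4, 8] [1, 4, 4]
```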

@@ -30,7 +30,6 @@ from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
from paddlespeech.s2t.modules.convolution import ConvolutionModule2
from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding
from paddlespeech.s2t.modules.embedding import RelPositionalEncoding
@@ -40,7 +39,6 @@ from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer
from paddlespeech.s2t.modules.mask import add_optional_chunk_mask
from paddlespeech.s2t.modules.mask import make_non_pad_mask
from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward2
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8
@@ -591,19 +589,19 @@ class SqueezeformerEncoder(nn.Layer):
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
else:
encoder_selfattn_layer = RelPositionMultiHeadedAttention2
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate, do_rel_shift,
adaptive_scale, init_weights)
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward2
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
encoder_dim, encoder_dim * feed_forward_expansion_factor,
feed_forward_dropout_rate, activation, adaptive_scale, init_weights)
# convolution module definition
convolution_layer = ConvolutionModule2
convolution_layer = ConvolutionModule
convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
cnn_norm_type, causal, True, adaptive_scale,
init_weights)
@@ -676,7 +674,7 @@ class SqueezeformerEncoder(nn.Layer):
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = ~masks
mask_pad = masks
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,
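The net effect of this hunk is that the `*2` variants disappear and the surviving classes take the extra Squeezeformer options directly, so the usual layer-class-plus-args-tuple pattern is unchanged. A schematic sketch of that construction follows; it requires paddlespeech to be installed, and the argument values are placeholders rather than the recipe's settings.

```python
# Sketch of the construction pattern used by SqueezeformerEncoder after the
# merge: one class per role, Squeezeformer options passed alongside the
# original arguments. Values are placeholders.
from paddle import nn
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward

attention_heads, encoder_dim, dropout = 4, 256, 0.1
do_rel_shift, adaptive_scale, init_weights = False, False, False

encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim, dropout,
                               do_rel_shift, adaptive_scale, init_weights)

positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (encoder_dim, encoder_dim * 4, dropout,
                           nn.ReLU(), adaptive_scale, init_weights)

self_attn = encoder_selfattn_layer(*encoder_selfattn_layer_args)
ffn = positionwise_layer(*positionwise_layer_args)
print(type(self_attn).__name__, type(ffn).__name__)
```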

@@ -23,7 +23,7 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["PositionwiseFeedForward", "PositionwiseFeedForward2"]
__all__ = ["PositionwiseFeedForward"]
class PositionwiseFeedForward(nn.Layer):
@@ -33,7 +33,9 @@ class PositionwiseFeedForward(nn.Layer):
idim: int,
hidden_units: int,
dropout_rate: float,
activation: nn.Layer=nn.ReLU()):
activation: nn.Layer=nn.ReLU(),
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct a PositionwiseFeedForward object.
FeedForward is applied to each position of the sequence.
@@ -46,48 +48,11 @@ class PositionwiseFeedForward(nn.Layer):
activation (paddle.nn.Layer): Activation function
"""
super().__init__()
self.w_1 = Linear(idim, hidden_units)
self.activation = activation
self.dropout = nn.Dropout(dropout_rate)
self.w_2 = Linear(hidden_units, idim)
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
"""Forward function.
Args:
xs: input tensor (B, Lmax, D)
Returns:
output tensor, (B, Lmax, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
class PositionwiseFeedForward2(paddle.nn.Layer):
"""Positionwise feed forward layer.
FeedForward is applied to each position of the sequence.
The output dim is the same as the input dim.
Args:
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (paddle.nn.Layer): Activation function
"""
def __init__(self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: paddle.nn.Layer=paddle.nn.ReLU(),
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward2, self).__init__()
self.idim = idim
self.hidden_units = hidden_units
self.w_1 = Linear(idim, hidden_units)
self.activation = activation
self.dropout = paddle.nn.Dropout(dropout_rate)
self.dropout = nn.Dropout(dropout_rate)
self.w_2 = Linear(hidden_units, idim)
self.adaptive_scale = adaptive_scale
ada_scale = self.create_parameter(
@@ -114,12 +79,9 @@ class PositionwiseFeedForward2(paddle.nn.Layer):
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
xs: input tensor (B, Lmax, D)
Returns:
output tensor, (B, L, D)
output tensor, (B, Lmax, D)
"""
if self.adaptive_scale:
xs = self.ada_scale * xs + self.ada_bias
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
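After the merge there is a single feed-forward class whose forward is still `w_2(dropout(activation(w_1(x))))`, optionally preceded by the same adaptive scale. A minimal free-standing sketch of that data flow with plain Paddle layers (not the project's `Linear` wrapper):

```python
# Minimal sketch of the merged position-wise feed-forward path:
# optional adaptive scale, then w_2(dropout(activation(w_1(x)))).
import paddle
from paddle import nn


class TinyFeedForward(nn.Layer):
    def __init__(self, idim, hidden_units, dropout_rate, adaptive_scale=False):
        super().__init__()
        self.adaptive_scale = adaptive_scale
        self.ada_scale = paddle.create_parameter(
            [1, 1, idim], dtype='float32',
            default_initializer=nn.initializer.Constant(1.0))
        self.ada_bias = paddle.create_parameter(
            [1, 1, idim], dtype='float32',
            default_initializer=nn.initializer.Constant(0.0))
        self.w_1 = nn.Linear(idim, hidden_units)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.w_2 = nn.Linear(hidden_units, idim)

    def forward(self, xs):
        if self.adaptive_scale:
            xs = self.ada_scale * xs + self.ada_bias
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))


xs = paddle.randn([2, 7, 16])                 # (B, Lmax, D)
print(TinyFeedForward(16, 64, 0.1, adaptive_scale=True)(xs).shape)  # [2, 7, 16]
```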
