add layer for transformer

pull/556/head
Hui Zhang 5 years ago
parent 9cf8c1a5db
commit b2bc6eb526

@@ -28,7 +28,7 @@ from paddle import distributed as dist
from paddle.io import DataLoader
from deepspeech.training import Trainer
from deepspeech.training.gradclip import MyClipGradByGlobalNorm
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils import mp_tools
from deepspeech.utils import layer_tools
@@ -125,7 +125,8 @@ class DeepSpeech2Trainer(Trainer):
layer_tools.print_params(model, self.logger.info)
grad_clip = MyClipGradByGlobalNorm(config.training.global_grad_clip)
grad_clip = ClipGradByGlobalNormWithLog(
config.training.global_grad_clip)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=config.training.lr,
gamma=config.training.lr_decay,

@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import logging
import numpy as np
import math
from collections import OrderedDict
import paddle
from paddle import nn
@@ -23,7 +25,7 @@ from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
__all__ = ['brelu', "softplus", "gelu_accurate", "gelu", 'Swish']
__all__ = ['brelu', "glu"]
def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -33,36 +35,180 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return x.maximum(t_min).minimum(t_max)
def softplus(x):
"""Softplus function."""
if hasattr(paddle.nn.functional, 'softplus'):
#return paddle.nn.functional.softplus(x.float()).type_as(x)
return paddle.nn.functional.softplus(x)
else:
raise NotImplementedError
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
def gelu_accurate(x):
"""Gaussian Error Linear Units (GELU) activation."""
# [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
(x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
def gelu(x):
"""Gaussian Error Linear Units (GELU) activation."""
if hasattr(nn.functional, 'gelu'):
#return nn.functional.gelu(x.float()).type_as(x)
return nn.functional.gelu(x)
else:
return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
# TODO(Hui Zhang): remove this activation
def glu(x, dim=-1):
"""The gated linear unit (GLU) activation."""
a, b = x.split(2, axis=dim)
act_b = F.sigmoid(b)
return a * act_b
# TODO(Hui Zhang): remove this activation
if not hasattr(nn.functional, 'glu'):
setattr(nn.functional, 'glu', glu)
# TODO(Hui Zhang): remove this activation
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return glu(xs, dim=self.dim)
class LinearGLUBlock(nn.Layer):
"""A linear Gated Linear Units (GLU) block."""
def __init__(self, idim: int):
""" GLU.
Args:
idim (int): input and output dimension
"""
super().__init__()
self.fc = nn.Linear(idim, idim * 2)
def forward(self, xs):
return glu(self.fc(xs), dim=-1)
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
def __init__(self, padding: Union[tuple, list, int], value: float):
"""
Args:
padding (tuple/list/int): the size of the padding.
If an int, uses the same padding on all boundaries.
If a 4-tuple/list, uses (padding_left, padding_right, padding_top, padding_bottom)
value (float): pad value
"""
super().__init__()
self.padding = padding if isinstance(padding,
(tuple, list)) else [padding] * 4
self.value = value
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
return nn.functional.pad(
xs,
self.padding,
mode='constant',
value=self.value,
data_format='NCHW')
class ConvGLUBlock(nn.Layer):
def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
dropout=0.):
"""A convolutional Gated Linear Units (GLU) block.
Args:
kernel_size (int): kernel size
in_ch (int): number of input channels
out_ch (int): number of output channels
bottlececk_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
dropout (float): dropout probability. Defaults to 0..
"""
super().__init__()
self.conv_residual = None
if in_ch != out_ch:
self.conv_residual = nn.utils.weight_norm(
nn.Conv2D(
in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
name='weight',
dim=0)
self.dropout_residual = nn.Dropout(p=dropout)
self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
layers = OrderedDict()
if bottlececk_dim == 0:
layers['conv'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=in_ch,
out_channels=out_ch * 2,
kernel_size=(kernel_size, 1)),
name='weight',
dim=0)
# TODO(hirofumi0810): padding?
layers['dropout'] = nn.Dropout(p=dropout)
layers['glu'] = GLU()
elif bottlececk_dim > 0:
layers['conv_in'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=in_ch,
out_channels=bottlececk_dim,
kernel_size=(1, 1)),
name='weight',
dim=0)
layers['dropout_in'] = nn.Dropout(p=dropout)
layers['conv_bottleneck'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=bottlececk_dim,
out_channels=bottlececk_dim,
kernel_size=(kernel_size, 1)),
name='weight',
dim=0)
layers['dropout'] = nn.Dropout(p=dropout)
layers['glu'] = GLU()
layers['conv_out'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=bottlececk_dim,
out_channels=out_ch * 2,
kernel_size=(1, 1)),
name='weight',
dim=0)
layers['dropout_out'] = nn.Dropout(p=dropout)
self.layers = nn.Sequential(layers)
def forward(self, xs):
"""Forward pass.
Args:
xs (FloatTensor): `[B, in_ch, T, feat_dim]`
Returns:
out (FloatTensor): `[B, out_ch, T, feat_dim]`
"""
residual = xs
if self.conv_residual is not None:
residual = self.dropout_residual(self.conv_residual(residual))
xs = self.pad_left(xs) # `[B, embed_dim, T+kernel-1, 1]`
xs = self.layers(xs) # `[B, out_ch * 2, T, 1]`
xs = xs + residual
return xs
class Swish(nn.Layer):
"""Construct a Swish object."""
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""Return Swish activation function."""
return x * F.sigmoid(x)
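A minimal usage sketch for the GLU additions above (not part of the commit); it assumes the module path deepspeech.modules.activation that this change imports elsewhere, and only checks that GLU halves the gated dimension:

# Hedged sketch: exercise glu / GLU / LinearGLUBlock from the hunk above.
import paddle
from deepspeech.modules.activation import GLU, LinearGLUBlock, glu

x = paddle.randn([4, 10, 16])                  # [B, T, 2*D] with D=8
assert glu(x, dim=-1).shape == [4, 10, 8]      # functional form halves the last dim
assert GLU(dim=-1)(x).shape == [4, 10, 8]      # layer wrapper, same behaviour
block = LinearGLUBlock(idim=16)                # Linear(16 -> 32) followed by a GLU
assert block(x).shape == [4, 10, 16]           # GLU halves the doubled width back to idim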

@@ -0,0 +1,149 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConvolutionModule definition."""
from typing import Optional, Tuple
from typeguard import check_argument_types
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
# init F.glu func
# TODO(Hui Zhang): remove this line
import deepspeech.modules.activation
logger = logging.getLogger(__name__)
__all__ = ['ConvolutionModule']
class ConvolutionModule(nn.Layer):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int=15,
activation: nn.Layer=nn.ReLU(),
norm: str="batch_norm",
causal: bool=False,
bias: bool=True):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
activation (nn.Layer): Activation Layer.
norm (str): Normalization type, 'batch_norm' or 'layer_norm'
causal (bool): Whether use causal convolution or not
bias (bool): Whether Conv with bias or not
"""
assert check_argument_types()
super().__init__()
self.pointwise_conv1 = nn.Conv1D(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias_attr=None if bias else False, # None for True as default
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# `self.lorder` frames on the left in forward (causal conv impl).
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for non-causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1D(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias_attr=None if bias else False, # None for True as default
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1D(channels)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1D(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias_attr=None if bias else False, # None for True as default
)
self.activation = activation
def forward(self, x: paddle.Tensor, cache: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
cache (paddle.Tensor): left context cache, it is only
used in causal convolution. (#batch, channels, time)
Returns:
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time)
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose([0, 2, 1]) # [B, C, T]
if self.lorder > 0:
if cache is None:
x = nn.functional.pad(
x, (self.lorder, 0), 'constant', 0.0, data_format='NCL')
else:
assert cache.shape[0] == x.shape[0] # B
assert cache.shape[1] == x.shape[1] # C
x = paddle.concat((cache, x), axis=2)
assert (x.shape[2] > self.lorder)
new_cache = x[:, :, -self.lorder:] #[B, C, T]
else:
# It would be better to just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose([0, 2, 1]) # [B, T, C]
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose([0, 2, 1]) # [B, C, T]
x = self.pointwise_conv2(x)
x = x.transpose([0, 2, 1]) # [B, T, C]
return x, new_cache
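The cache logic above amounts to: left-pad the first chunk with `lorder` zeros, and for later chunks prepend the cached `lorder` frames instead. A self-contained sketch of that equivalence with a plain depthwise Conv1D (names here are local to the sketch, not part of the commit):

# Illustrative check of the causal-cache idea used by ConvolutionModule above.
import paddle
from paddle import nn

paddle.seed(0)
channels, kernel_size = 4, 15
lorder = kernel_size - 1                               # left context kept in the cache
dwconv = nn.Conv1D(channels, channels, kernel_size, groups=channels)  # no padding

full = paddle.randn([1, channels, 32])                 # [B, C, T]
ref = dwconv(nn.functional.pad(full, [lorder, 0], data_format='NCL'))

chunk1, chunk2 = full[:, :, :16], full[:, :, 16:]
y1 = dwconv(nn.functional.pad(chunk1, [lorder, 0], data_format='NCL'))
cache = chunk1[:, :, -lorder:]                         # what forward() keeps as new_cache
y2 = dwconv(paddle.concat([cache, chunk2], axis=2))    # reuse cache instead of zero padding
assert paddle.allclose(paddle.concat([y1, y2], axis=2), ref)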

@@ -24,7 +24,32 @@ from deepspeech.modules.activation import brelu
logger = logging.getLogger(__name__)
__all__ = ['ConvStack']
__all__ = ['ConvStack', "conv_output_size"]
def conv_output_size(I, F, P, S):
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# Output size after Conv:
# By noting I the length of the input volume size,
# F the length of the filter,
# P the amount of zero padding on each side,
# S the stride,
# then the output size O of the feature map along that dimension is given by:
# O = (I - F + Pstart + Pend) // S + 1
# When Pstart == Pend == P, this becomes O = (I - F + 2P) // S + 1;
# with no padding (P == 0) it reduces to O = (I - F) // S + 1.
# https://iq.opengenus.org/output-size-of-convolution/
# Output height = (Input height + padding top + padding bottom - kernel height) // (stride height) + 1
# Output width = (Input width + padding left + padding right - kernel width) // (stride width) + 1
return (I - F + 2 * P) // S + 1
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
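# Worked example (illustrative, not from the original file): two 3x3 convs with
# stride 2 and no padding, as in the subsampling-by-4 frontend:
#   conv_output_size(80, 3, 0, 2) == (80 - 3) // 2 + 1 == 39
#   conv_output_size(39, 3, 0, 2) == (39 - 3) // 2 + 1 == 19
# which matches the ((idim - 1) // 2 - 1) // 2 arithmetic used in subsampling.py.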
class ConvBn(nn.Layer):

@@ -0,0 +1,277 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer"]
class TransformerEncoderLayer(nn.Layer):
"""Encoder layer module."""
def __init__(
self,
size: int,
self_attn: nn.Layer,
feed_forward: nn.Layer,
dropout_rate: float,
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: to use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
# concat_linear may not be used in the forward function,
# but it will still be saved in the checkpoint.
self.concat_linear = nn.Linear(size + size, size)
def forward(
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): Input tensor (#batch, time, size).
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface
compatibility to ConformerEncoderLayer
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if output_cache is None:
x_q = x
else:
assert output_cache.shape[0] == x.shape[0]
assert output_cache.shape[1] < x.shape[1]
assert output_cache.shape[2] == self.size
chunk = x.shape[1] - output_cache.shape[1]
x_q = x[:, -chunk:, :]
residual = residual[:, -chunk:, :]
mask = mask[:, -chunk:, :]
if self.concat_after:
x_concat = paddle.concat(
(x, self.self_attn(x_q, x, x, mask)), axis=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
return x, mask, fake_cnn_cache
class ConformerEncoderLayer(nn.Layer):
"""Encoder layer module."""
def __init__(
self,
size: int,
self_attn: int,
feed_forward: Optional[nn.Layer]=None,
feed_forward_macaron: Optional[nn.Layer]=None,
conv_module: Optional[nn.Layer]=None,
dropout_rate: float=0.1,
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (nn.Layer): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (nn.Layer): Convolution module instance.
`ConvolutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, epsilon=1e-12) # for the FNN module
self.norm_mha = nn.LayerNorm(size, epsilon=1e-12) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(
size, epsilon=1e-12) # for the CNN module
self.norm_final = nn.LayerNorm(
size, epsilon=1e-12) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.concat_linear = nn.Linear(size + size, size)
def forward(
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): (#batch, time, size)
mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
pos_emb (paddle.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
"""
# whether to use macaron style FFN
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
if output_cache is None:
x_q = x
else:
assert output_cache.shape[0] == x.shape[0]
assert output_cache.shape[1] < x.shape[1]
assert output_cache.shape[2] == self.size
chunk = x.shape[1] - output_cache.shape[1]
x_q = x[:, -chunk:, :]
residual = residual[:, -chunk:, :]
mask = mask[:, -chunk:, :]
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
return x, mask, new_cnn_cache
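A hedged wiring sketch for TransformerEncoderLayer (not part of the commit). The attention module path and its (n_head, n_feat, dropout_rate) constructor are assumptions based on the docstring's mention of MultiHeadedAttention; PositionwiseFeedForward is the layer added later in this change:

# Illustrative only: module paths and the MultiHeadedAttention signature are assumed.
import paddle
from deepspeech.modules.attention import MultiHeadedAttention                      # assumed
from deepspeech.modules.encoder_layer import TransformerEncoderLayer               # assumed
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward   # assumed

size, heads = 256, 4
layer = TransformerEncoderLayer(
    size=size,
    self_attn=MultiHeadedAttention(heads, size, 0.0),
    feed_forward=PositionwiseFeedForward(size, 2048, 0.0),
    dropout_rate=0.1,
    normalize_before=True)

x = paddle.randn([2, 50, size])                # (#batch, time, size)
mask = paddle.ones([2, 1, 50], dtype='bool')   # broadcastable attention mask
pos_emb = paddle.zeros([1, 50, size])          # unused here, interface compatibility
out, out_mask, _ = layer(x, mask, pos_emb)
print(out.shape)                               # [2, 50, 256]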

@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
__all__ = ["PositionwiseFeedForward"]
class PositionwiseFeedForward(nn.Layer):
"""Positionwise feed forward layer."""
def __init__(self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: nn.Layer=nn.ReLU()):
"""Construct a PositionwiseFeedForward object.
FeedForward is applied at each position of the sequence.
The output dim is the same as the input dim.
Args:
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (paddle.nn.Layer): Activation function
"""
super().__init__()
self.w_1 = nn.Linear(idim, hidden_units)
self.activation = activation
self.dropout = nn.Dropout(dropout_rate)
self.w_2 = nn.Linear(hidden_units, idim)
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
"""Forward function.
Args:
xs: input tensor (B, Lmax, D)
Returns:
output tensor, (B, Lmax, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
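A minimal shape check for the layer above (not part of the commit), assuming the module path deepspeech.modules.positionwise_feed_forward:

# Position-wise FFN keeps the last dimension: (B, Lmax, D) -> (B, Lmax, D).
import paddle
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward  # assumed path

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
xs = paddle.randn([2, 50, 256])
assert ffn(xs).shape == [2, 50, 256]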

@@ -0,0 +1,235 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Subsampling layer definition."""
from typing import Tuple
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.embedding import PositionalEncoding
logger = logging.getLogger(__name__)
__all__ = [
"LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
"Conv2dSubsampling8"
]
class BaseSubsampling(nn.Layer):
def __init__(self, pos_enc_class: PositionalEncoding):
super().__init__()
self.pos_enc = pos_enc_class
self.right_context = 0
self.subsampling_rate = 1
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
return self.pos_enc.position_encoding(offset, size)
class LinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an linear object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc_class (PositionalEncoding): position encoding class
"""
super().__init__(pos_enc_class)
self.out = nn.Sequential(
nn.Linear(idim, odim),
nn.LayerNorm(odim, epsilon=1e-12),
nn.Dropout(dropout_rate), )
self.right_context = 0
self.subsampling_rate = 1
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Input x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
paddle.Tensor: positional encoding
paddle.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv2dSubsampling4(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length)."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an Conv2dSubsampling4 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
super().__init__(pos_enc_class)
self.conv = nn.Sequential(
nn.Conv2D(1, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 3, 2),
nn.ReLU(), )
self.linear = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.subsampling_rate = 4
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
self.right_context = 6
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
class Conv2dSubsampling6(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/6 length)."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an Conv2dSubsampling6 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (PositionalEncoding): Custom position encoding layer.
"""
super().__init__(pos_enc_class)
self.conv = nn.Sequential(
nn.Conv2D(1, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 5, 3),
nn.ReLU(), )
# O = (I - F + Pstart + Pend) // S + 1
# when padding == 0, O = (I - F) // S + 1
self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
self.subsampling_rate = 6
self.right_context = 14
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/8 length)."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an Conv2dSubsampling8 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
super().__init__(pos_enc_class)
self.conv = nn.Sequential(
nn.Conv2D(1, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 3, 2),
nn.ReLU(), )
self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
odim)
self.subsampling_rate = 8
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
self.right_context = 14
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
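A hedged sketch of the subsampling-by-4 frontend above (not part of the commit); the PositionalEncoding constructor signature (d_model, dropout_rate) and the module path deepspeech.modules.subsampling are assumptions:

# Illustrative shape walk-through for Conv2dSubsampling4.
import paddle
from deepspeech.modules.embedding import PositionalEncoding      # constructor signature assumed
from deepspeech.modules.subsampling import Conv2dSubsampling4    # assumed module path

idim, odim = 80, 256
sub4 = Conv2dSubsampling4(idim, odim, dropout_rate=0.1,
                          pos_enc_class=PositionalEncoding(odim, 0.1))

x = paddle.randn([2, 100, idim])                 # (#batch, time, idim)
x_mask = paddle.ones([2, 1, 100], dtype='bool')  # (#batch, 1, time)
y, pos_emb, y_mask = sub4(x, x_mask)
# two (3, stride 2) convs: 100 -> 49 -> 24 frames, feature dim -> odim
print(y.shape, y_mask.shape)                     # [2, 24, 256] [2, 1, 24]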

@@ -21,8 +21,10 @@ from paddle.fluid import core
logger = logging.getLogger(__name__)
__all__ = ["ClipGradByGlobalNormWithLog"]
class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
def __init__(self, clip_norm):
super().__init__(clip_norm)
