add layer for transformer

pull/556/head
Hui Zhang 5 years ago
parent 9cf8c1a5db
commit b2bc6eb526

@@ -28,7 +28,7 @@ from paddle import distributed as dist
from paddle.io import DataLoader

from deepspeech.training import Trainer
-from deepspeech.training.gradclip import MyClipGradByGlobalNorm
+from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils import mp_tools
from deepspeech.utils import layer_tools

@@ -125,7 +125,8 @@ class DeepSpeech2Trainer(Trainer):
        layer_tools.print_params(model, self.logger.info)

-        grad_clip = MyClipGradByGlobalNorm(config.training.global_grad_clip)
+        grad_clip = ClipGradByGlobalNormWithLog(
+            config.training.global_grad_clip)
        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=config.training.lr,
            gamma=config.training.lr_decay,

@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
import logging
import numpy as np
import math
+from collections import OrderedDict

import paddle
from paddle import nn

@@ -23,7 +25,7 @@ from paddle.nn import initializer as I
logger = logging.getLogger(__name__)

-__all__ = ['brelu', "softplus", "gelu_accurate", "gelu", 'Swish']
+__all__ = ['brelu', "glu"]


def brelu(x, t_min=0.0, t_max=24.0, name=None):

@@ -33,36 +35,180 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
    return x.maximum(t_min).minimum(t_max)

-def softplus(x):
-    """Softplus function."""
-    if hasattr(paddle.nn.functional, 'softplus'):
-        #return paddle.nn.functional.softplus(x.float()).type_as(x)
-        return paddle.nn.functional.softplus(x)
-    else:
-        raise NotImplementedError
-
-def gelu_accurate(x):
-    """Gaussian Error Linear Units (GELU) activation."""
-    # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
-    if not hasattr(gelu_accurate, "_a"):
-        gelu_accurate._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
-                                      (x + 0.044715 * paddle.pow(x, 3))))
-
-def gelu(x):
-    """Gaussian Error Linear Units (GELU) activation."""
-    if hasattr(torch.nn.functional, 'gelu'):
-        #return torch.nn.functional.gelu(x.float()).type_as(x)
-        return torch.nn.functional.gelu(x)
-    else:
-        return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))

# def softplus(x):
#     """Softplus function."""
#     if hasattr(paddle.nn.functional, 'softplus'):
#         #return paddle.nn.functional.softplus(x.float()).type_as(x)
#         return paddle.nn.functional.softplus(x)
#     else:
#         raise NotImplementedError

# def gelu_accurate(x):
#     """Gaussian Error Linear Units (GELU) activation."""
#     # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
#     if not hasattr(gelu_accurate, "_a"):
#         gelu_accurate._a = math.sqrt(2 / math.pi)
#     return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
#                                       (x + 0.044715 * paddle.pow(x, 3))))

# def gelu(x):
#     """Gaussian Error Linear Units (GELU) activation."""
#     if hasattr(nn.functional, 'gelu'):
#         #return nn.functional.gelu(x.float()).type_as(x)
#         return nn.functional.gelu(x)
#     else:
#         return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))


# TODO(Hui Zhang): remove this activation
def glu(x, dim=-1):
    """The gated linear unit (GLU) activation."""
    if hasattr(nn.functional, 'glu'):
        return nn.functional.glu(x)
    else:
        a, b = x.split(2, axis=dim)
        act_b = F.sigmoid(b)
        return a * act_b
# TODO(Hui Zhang): remove this activation
if not hasattr(nn.functional, 'glu'):
setattr(nn.functional, 'glu', glu)
# TODO(Hui Zhang): remove this activation
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return glu(xs, dim=self.dim)
class LinearGLUBlock(nn.Layer):
"""A linear Gated Linear Units (GLU) block."""
def __init__(self, idim: int):
""" GLU.
Args:
idim (int): input and output dimension
"""
super().__init__()
self.fc = nn.Linear(idim, idim * 2)
def forward(self, xs):
return glu(self.fc(xs), dim=-1)
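Not part of the commit: a minimal shape check of the gating used by glu/GLU/LinearGLUBlock above (input sizes are arbitrary).

import paddle
from paddle.nn import functional as F

x = paddle.randn([4, 10, 16])      # [B, T, 2*D] with D = 8
a, b = x.split(2, axis=-1)         # two halves along the gated dim
y = a * F.sigmoid(b)               # gated linear unit: a * sigmoid(b)
print(y.shape)                     # [4, 10, 8] -- GLU halves the gated dimension,
                                   # so Linear(idim, idim * 2) + GLU maps idim back to idim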
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
    def __init__(self, padding: Union[tuple, list, int], value: float):
        """
        Args:
            padding (Union[tuple, list, int]): the size of the padding.
                If it is an int, the same padding is used on all boundaries.
                If it is a 4-element tuple/list, it is
                (padding_left, padding_right, padding_top, padding_bottom).
            value (float): pad value
        """
        super().__init__()
        self.padding = padding if isinstance(
            padding, (tuple, list)) else [padding] * 4
        self.value = value
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
return nn.functional.pad(
xs,
self.padding,
mode='constant',
value=self.value,
data_format='NCHW')
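Not part of the commit: a quick check of the padding convention ConstantPad2d relies on. With data_format='NCHW', a 4-element pad is (left, right, top, bottom) over the last two axes, so (0, 0, kernel_size - 1, 0) pads only the time axis (dim 2) on its left, which is what the causal ConvGLUBlock below uses.

import paddle
from paddle.nn import functional as F

kernel_size = 3
xs = paddle.randn([2, 8, 10, 1])   # [B, C, T, 1]
ys = F.pad(xs, [0, 0, kernel_size - 1, 0],
           mode='constant', value=0.0, data_format='NCHW')
print(xs.shape, ys.shape)          # [2, 8, 10, 1] [2, 8, 12, 1]: T grows by kernel_size - 1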
class ConvGLUBlock(nn.Layer):
def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
dropout=0.):
"""A convolutional Gated Linear Units (GLU) block.
Args:
kernel_size (int): kernel size
in_ch (int): number of input channels
out_ch (int): number of output channels
bottlececk_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
dropout (float): dropout probability. Defaults to 0..
"""
super().__init__()
self.conv_residual = None
if in_ch != out_ch:
self.conv_residual = nn.utils.weight_norm(
nn.Conv2D(
in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
name='weight',
dim=0)
self.dropout_residual = nn.Dropout(p=dropout)
self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
layers = OrderedDict()
if bottlececk_dim == 0:
layers['conv'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=in_ch,
out_channels=out_ch * 2,
kernel_size=(kernel_size, 1)),
name='weight',
dim=0)
# TODO(hirofumi0810): padding?
layers['dropout'] = nn.Dropout(p=dropout)
layers['glu'] = GLU()
elif bottlececk_dim > 0:
layers['conv_in'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=in_ch,
out_channels=bottlececk_dim,
kernel_size=(1, 1)),
name='weight',
dim=0)
layers['dropout_in'] = nn.Dropout(p=dropout)
layers['conv_bottleneck'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=bottlececk_dim,
out_channels=bottlececk_dim,
kernel_size=(kernel_size, 1)),
name='weight',
dim=0)
layers['dropout'] = nn.Dropout(p=dropout)
layers['glu'] = GLU()
layers['conv_out'] = nn.utils.weight_norm(
nn.Conv2D(
in_channels=bottlececk_dim,
out_channels=out_ch * 2,
kernel_size=(1, 1)),
name='weight',
dim=0)
layers['dropout_out'] = nn.Dropout(p=dropout)
        self.layers = nn.Sequential(layers)

    def forward(self, xs):
        """Forward pass.
        Args:
            xs (FloatTensor): `[B, in_ch, T, feat_dim]`
        Returns:
            out (FloatTensor): `[B, out_ch, T, feat_dim]`
        """
        residual = xs
        if self.conv_residual is not None:
            residual = self.dropout_residual(self.conv_residual(residual))
        xs = self.pad_left(xs)  # `[B, embed_dim, T+kernel-1, 1]`
        xs = self.layers(xs)  # `[B, out_ch * 2, T, 1]`
        xs = xs + residual
        return xs

-class Swish(nn.Layer):
-    """Construct an Swish object."""
-
-    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
-        """Return Swish activation function."""
-        return x * F.sigmoid(x)
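Not part of the commit: a sketch of ConvGLUBlock's non-bottleneck path written with plain Paddle ops (all sizes are made up). The weight-normalized conv doubles the channels and the GLU over the channel axis halves them again, so time and feature dims are preserved.

import paddle
from paddle import nn
from paddle.nn import functional as F

B, in_ch, out_ch, T, k = 2, 8, 16, 10, 3
xs = paddle.randn([B, in_ch, T, 1])
conv = nn.utils.weight_norm(
    nn.Conv2D(in_channels=in_ch, out_channels=out_ch * 2, kernel_size=(k, 1)),
    name='weight', dim=0)

xs = F.pad(xs, [0, 0, k - 1, 0], mode='constant', value=0.0,
           data_format='NCHW')     # causal pad on the time axis
xs = conv(xs)                      # [B, 2*out_ch, T, 1]
a, b = xs.split(2, axis=1)
xs = a * F.sigmoid(b)              # [B, out_ch, T, 1]
print(xs.shape)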

@@ -0,0 +1,149 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConvolutionModule definition."""
from typing import Optional, Tuple
from typeguard import check_argument_types
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
# init F.glu func
# TODO(Hui Zhang): remove this line
import deepspeech.modules.activation
logger = logging.getLogger(__name__)
__all__ = ['ConvolutionModule']
class ConvolutionModule(nn.Layer):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int=15,
activation: nn.Layer=nn.ReLU(),
norm: str="batch_norm",
causal: bool=False,
bias: bool=True):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
activation (nn.Layer): Activation Layer.
norm (str): Normalization type, 'batch_norm' or 'layer_norm'
causal (bool): Whether use causal convolution or not
bias (bool): Whether Conv with bias or not
"""
assert check_argument_types()
super().__init__()
self.pointwise_conv1 = nn.Conv1D(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=None if bias else False, # None for True as default
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# `self.lorder` frames on the left in forward (causal conv impl).
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
            # kernel_size should be an odd number for non-causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1D(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=None if bias else False, # None for True as default
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1D(channels)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1D(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=None if bias else False, # None for True as default
)
self.activation = activation
def forward(self, x: paddle.Tensor, cache: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
cache (paddle.Tensor): left context cache, it is only
used in causal convolution. (#batch, channels, time)
Returns:
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time)
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose([0, 2, 1]) # [B, C, T]
if self.lorder > 0:
if cache is None:
x = nn.functional.pad(
x, (self.lorder, 0), 'constant', 0.0, data_format='NCL')
else:
assert cache.shape[0] == x.shape[0] # B
assert cache.shape[1] == x.shape[1] # C
x = paddle.concat((cache, x), axis=2)
assert (x.shape[2] > self.lorder)
new_cache = x[:, :, -self.lorder:] #[B, C, T]
else:
            # It would be better to just return None if no cache is required;
            # however, for JIT export, we fake one tensor here instead of
            # None.
new_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose([0, 2, 1]) # [B, T, C]
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose([0, 2, 1]) # [B, C, T]
x = self.pointwise_conv2(x)
x = x.transpose([0, 2, 1]) # [B, T, C]
return x, new_cache
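Not part of the commit: a shape-only sketch of the causal-cache bookkeeping in forward() above, with made-up batch/chunk sizes. The first chunk is zero-padded on the left by lorder frames, and the last lorder frames are kept as the cache that replaces that padding for the next chunk.

import paddle
from paddle.nn import functional as F

B, C, kernel_size = 2, 4, 15
lorder = kernel_size - 1                         # left context of the depthwise conv

chunk1 = paddle.randn([B, 16, C]).transpose([0, 2, 1])   # [B, C, T]
x = F.pad(chunk1, [lorder, 0], mode='constant', value=0.0, data_format='NCL')
new_cache = x[:, :, -lorder:]                    # [B, C, lorder]

chunk2 = paddle.randn([B, 16, C]).transpose([0, 2, 1])
x2 = paddle.concat([new_cache, chunk2], axis=2)  # cache stands in for the zero padding
print(x.shape, new_cache.shape, x2.shape)        # [2, 4, 30] [2, 4, 14] [2, 4, 30]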

@@ -24,7 +24,32 @@ from deepspeech.modules.activation import brelu
logger = logging.getLogger(__name__)

-__all__ = ['ConvStack']
+__all__ = ['ConvStack', "conv_output_size"]

def conv_output_size(I, F, P, S):
    # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
    # Output size after a convolution:
    # with I the length of the input along this dimension,
    # F the length of the filter,
    # P the amount of zero padding on each side,
    # S the stride,
    # the output size O of the feature map along that dimension is:
    #     O = (I - F + Pstart + Pend) // S + 1
    # When Pstart == Pend == P, Pstart + Pend can be replaced by 2P.
    # When Pstart == Pend == 0, this reduces to O = (I - F) // S + 1.
    # https://iq.opengenus.org/output-size-of-convolution/
    # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
    # Output width = (Input width + padding width right + padding width left - kernel width) / (stride width) + 1
    return (I - F + 2 * P) // S + 1
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
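A worked example (not in the diff) of the output-size formula above, tied to the frequency-axis term odim * (((idim - 1) // 2 - 1) // 2) used by Conv2dSubsampling4 later in this commit: two kernel-3, stride-2, pad-0 convs on 80 mel bins.

idim, F, P, S = 80, 3, 0, 2
o1 = (idim - F + 2 * P) // S + 1      # 39 == (80 - 1) // 2
o2 = (o1 - F + 2 * P) // S + 1        # 19 == ((80 - 1) // 2 - 1) // 2
print(o1, o2)                         # 39 19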
class ConvBn(nn.Layer):

@@ -0,0 +1,277 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer"]
class TransformerEncoderLayer(nn.Layer):
"""Encoder layer module."""
def __init__(
self,
size: int,
self_attn: nn.Layer,
feed_forward: nn.Layer,
dropout_rate: float,
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
                False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
        # concat_linear may not be used in the forward function,
        # but it will still be saved in the checkpoint.
self.concat_linear = nn.Linear(size + size, size)
def forward(
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): Input tensor (#batch, time, size).
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface
compatibility to ConformerEncoderLayer
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if output_cache is None:
x_q = x
else:
assert output_cache.shape[0] == x.shape[0]
assert output_cache.shape[1] < x.shape[1]
assert output_cache.shape[2] == self.size
chunk = x.shape[1] - output_cache.shape[1]
x_q = x[:, -chunk:, :]
residual = residual[:, -chunk:, :]
mask = mask[:, -chunk:, :]
if self.concat_after:
x_concat = paddle.concat(
(x, self.self_attn(x_q, x, x, mask)), axis=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
return x, mask, fake_cnn_cache
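Not part of the commit: a shape-only sketch of the output_cache logic above, with made-up sizes and a random tensor standing in for the attention/FFN output. Cached outputs cover time2 frames, only the chunk = time - time2 new frames are used as queries, and the cache is concatenated back so callers always see the full time axis.

import paddle

B, size, time, time2 = 2, 256, 20, 16
x = paddle.randn([B, time, size])
output_cache = paddle.randn([B, time2, size])

chunk = time - time2                          # 4 new frames
x_q = x[:, -chunk:, :]                        # queries restricted to the new frames
new_out = paddle.randn([B, chunk, size])      # stand-in for self_attn/feed_forward output
full = paddle.concat([output_cache, new_out], axis=1)
print(x_q.shape, full.shape)                  # [2, 4, 256] [2, 20, 256]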
class ConformerEncoderLayer(nn.Layer):
"""Encoder layer module."""
def __init__(
self,
size: int,
self_attn: int,
feed_forward: Optional[nn.Layer]=None,
feed_forward_macaron: Optional[nn.Layer]=None,
conv_module: Optional[nn.Layer]=None,
dropout_rate: float=0.1,
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (nn.Layer): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (nn.Layer): Convolution module instance.
                `ConvolutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, epsilon=1e-12) # for the FNN module
self.norm_mha = nn.LayerNorm(size, epsilon=1e-12) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(
size, epsilon=1e-12) # for the CNN module
self.norm_final = nn.LayerNorm(
size, epsilon=1e-12) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.concat_linear = nn.Linear(size + size, size)
def forward(
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): (#batch, time, size)
            mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
"""
# whether to use macaron style FFN
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
if output_cache is None:
x_q = x
else:
assert output_cache.shape[0] == x.shape[0]
assert output_cache.shape[1] < x.shape[1]
assert output_cache.shape[2] == self.size
chunk = x.shape[1] - output_cache.shape[1]
x_q = x[:, -chunk:, :]
residual = residual[:, -chunk:, :]
mask = mask[:, -chunk:, :]
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
return x, mask, new_cnn_cache
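Not part of the commit: a toy trace of the macaron-style half-step FFN residuals used above (ff_scale = 0.5 when feed_forward_macaron is set), with stand-in Linear layers in place of the real PositionwiseFeedForward modules.

import paddle
from paddle import nn

size = 8
x = paddle.randn([2, 5, size])
ffn_macaron, ffn = nn.Linear(size, size), nn.Linear(size, size)
norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
norm_ff = nn.LayerNorm(size, epsilon=1e-12)

x = x + 0.5 * ffn_macaron(norm_ff_macaron(x))   # half-step FFN before attention/conv
# ... self-attention and convolution sub-blocks would go here ...
x = x + 0.5 * ffn(norm_ff(x))                   # half-step FFN after them
print(x.shape)                                  # [2, 5, 8]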

@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
__all__ = ["PositionwiseFeedForward"]
class PositionwiseFeedForward(nn.Layer):
"""Positionwise feed forward layer."""
def __init__(self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: nn.Layer=nn.ReLU()):
"""Construct a PositionwiseFeedForward object.
        FeedForward is applied on each position of the sequence.
        The output dim is the same as the input dim.
        Args:
            idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (paddle.nn.Layer): Activation function
"""
super().__init__()
self.w_1 = nn.Linear(idim, hidden_units)
self.activation = activation
self.dropout = nn.Dropout(dropout_rate)
self.w_2 = nn.Linear(hidden_units, idim)
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
"""Forward function.
Args:
xs: input tensor (B, Lmax, D)
Returns:
output tensor, (B, Lmax, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
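Not part of the commit: a minimal shape check for PositionwiseFeedForward as defined above (the import path is assumed); it expands idim to hidden_units and projects back, so the output shape equals the input shape.

import paddle
# Assumed import path for the class defined above:
# from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
xs = paddle.randn([4, 50, 256])   # (B, Lmax, D)
print(ffn(xs).shape)              # [4, 50, 256]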

@@ -0,0 +1,235 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Subsampling layer definition."""
from typing import Tuple
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.embedding import PositionalEncoding
logger = logging.getLogger(__name__)
__all__ = [
"LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
"Conv2dSubsampling8"
]
class BaseSubsampling(nn.Layer):
def __init__(self, pos_enc_class: PositionalEncoding):
super().__init__()
self.pos_enc = pos_enc_class
self.right_context = 0
self.subsampling_rate = 1
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
return self.pos_enc.position_encoding(offset, size)
class LinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an linear object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc_class (PositionalEncoding): position encoding class
"""
super().__init__(pos_enc_class)
self.out = nn.Sequential(
nn.Linear(idim, odim),
nn.LayerNorm(odim, epsilon=1e-12),
nn.Dropout(dropout_rate), )
self.right_context = 0
self.subsampling_rate = 1
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Input x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
paddle.Tensor: positional encoding
paddle.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv2dSubsampling4(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length)."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an Conv2dSubsampling4 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
super().__init__(pos_enc_class)
self.conv = nn.Sequential(
nn.Conv2D(1, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 3, 2),
nn.ReLU(), )
self.linear = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.subsampling_rate = 4
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
self.right_context = 6
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
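Not part of the commit: checking that the mask slicing above matches the 1/4 time subsampling. Two (kernel=3, stride=2, pad=0) convs give time' = ((time - 1) // 2 - 1) // 2, and [:, :, :-2:2] applied twice selects exactly that many positions.

import paddle

T = 100
t_sub = ((T - 1) // 2 - 1) // 2              # 24
mask = paddle.ones([1, 1, T], dtype='int32')
mask_sub = mask[:, :, :-2:2][:, :, :-2:2]
print(t_sub, mask_sub.shape[-1])             # 24 24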
class Conv2dSubsampling6(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/6 length)."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an Conv2dSubsampling6 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (PositionalEncoding): Custom position encoding layer.
"""
super().__init__(pos_enc_class)
self.conv = nn.Sequential(
nn.Conv2D(1, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 5, 3),
nn.ReLU(), )
# O = (I - F + Pstart + Pend) // S + 1
# when Padding == 0, O = (I - F - S) // S
self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
self.subsampling_rate = 6
self.right_context = 14
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/8 length)."""
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: PositionalEncoding):
"""Construct an Conv2dSubsampling8 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
super().__init__(pos_enc_class)
self.conv = nn.Sequential(
nn.Conv2D(1, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 3, 2),
nn.ReLU(),
nn.Conv2D(odim, odim, 3, 2),
nn.ReLU(), )
self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
odim)
self.subsampling_rate = 8
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
self.right_context = 14
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]

@@ -21,8 +21,10 @@ from paddle.fluid import core
logger = logging.getLogger(__name__)

+__all__ = ["ClipGradByGlobalNormWithLog"]

-class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
+class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
    def __init__(self, clip_norm):
        super().__init__(clip_norm)
