From b2bc6eb526009b682e3a9ed9881642ad30ad1337 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 16 Mar 2021 12:52:37 +0000 Subject: [PATCH] add layer for transformer --- deepspeech/exps/deepspeech2/model.py | 5 +- deepspeech/modules/activation.py | 198 +++++++++++-- deepspeech/modules/conformer_convolution.py | 149 ++++++++++ deepspeech/modules/conv.py | 27 +- deepspeech/modules/encoder_layer.py | 277 ++++++++++++++++++ .../modules/positionwise_feed_forward.py | 59 ++++ deepspeech/modules/subsampling.py | 235 +++++++++++++++ deepspeech/training/gradclip.py | 4 +- 8 files changed, 924 insertions(+), 30 deletions(-) create mode 100644 deepspeech/modules/conformer_convolution.py create mode 100644 deepspeech/modules/encoder_layer.py create mode 100644 deepspeech/modules/positionwise_feed_forward.py create mode 100644 deepspeech/modules/subsampling.py diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 46ef915c6..e6779be63 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -28,7 +28,7 @@ from paddle import distributed as dist from paddle.io import DataLoader from deepspeech.training import Trainer -from deepspeech.training.gradclip import MyClipGradByGlobalNorm +from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.utils import mp_tools from deepspeech.utils import layer_tools @@ -125,7 +125,8 @@ class DeepSpeech2Trainer(Trainer): layer_tools.print_params(model, self.logger.info) - grad_clip = MyClipGradByGlobalNorm(config.training.global_grad_clip) + grad_clip = ClipGradByGlobalNormWithLog( + config.training.global_grad_clip) lr_scheduler = paddle.optimizer.lr.ExponentialDecay( learning_rate=config.training.lr, gamma=config.training.lr_decay, diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py index a42bd1e74..72ccb5346 100644 --- a/deepspeech/modules/activation.py +++ b/deepspeech/modules/activation.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Union
 import logging
 import numpy as np
 import math
+from collections import OrderedDict
 
 import paddle
 from paddle import nn
@@ -23,7 +25,7 @@ from paddle.nn import initializer as I
 
 logger = logging.getLogger(__name__)
 
-__all__ = ['brelu', "softplus", "gelu_accurate", "gelu", 'Swish']
+__all__ = ['brelu', "glu"]
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -33,36 +35,180 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)
 
 
-def softplus(x):
-    """Softplus function."""
-    if hasattr(paddle.nn.functional, 'softplus'):
-        #return paddle.nn.functional.softplus(x.float()).type_as(x)
-        return paddle.nn.functional.softplus(x)
-    else:
-        raise NotImplementedError
+# def softplus(x):
+#     """Softplus function."""
+#     if hasattr(paddle.nn.functional, 'softplus'):
+#         #return paddle.nn.functional.softplus(x.float()).type_as(x)
+#         return paddle.nn.functional.softplus(x)
+#     else:
+#         raise NotImplementedError
 
+# def gelu_accurate(x):
+#     """Gaussian Error Linear Units (GELU) activation."""
+#     # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
+#     if not hasattr(gelu_accurate, "_a"):
+#         gelu_accurate._a = math.sqrt(2 / math.pi)
+#     return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
+#                                       (x + 0.044715 * paddle.pow(x, 3))))
 
-def gelu_accurate(x):
-    """Gaussian Error Linear Units (GELU) activation."""
-    # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
-    if not hasattr(gelu_accurate, "_a"):
-        gelu_accurate._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
-                                      (x + 0.044715 * paddle.pow(x, 3))))
 
+# def gelu(x):
+#     """Gaussian Error Linear Units (GELU) activation."""
+#     if hasattr(nn.functional, 'gelu'):
+#         #return nn.functional.gelu(x.float()).type_as(x)
+#         return nn.functional.gelu(x)
+#     else:
+#         return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
 
-def gelu(x):
-    """Gaussian Error Linear Units (GELU) activation."""
-    if hasattr(torch.nn.functional, 'gelu'):
-        #return torch.nn.functional.gelu(x.float()).type_as(x)
-        return torch.nn.functional.gelu(x)
+
+# TODO(Hui Zhang): remove this activation
+def glu(x, axis=-1):
+    """The gated linear unit (GLU) activation."""
+    if hasattr(nn.functional, 'glu'):
+        # native implementation available; forward the split axis to it
+        return nn.functional.glu(x, axis=axis)
     else:
-        return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
+        a, b = x.split(2, axis=axis)
+        act_b = F.sigmoid(b)
+        return a * act_b
+
+
+# TODO(Hui Zhang): remove this activation
+if not hasattr(nn.functional, 'glu'):
+    setattr(nn.functional, 'glu', glu)
+
+
+# TODO(Hui Zhang): remove this activation
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, axis: int=-1):
+        super().__init__()
+        self.axis = axis
+
+    def forward(self, xs):
+        return glu(xs, axis=self.axis)
+
+
+class LinearGLUBlock(nn.Layer):
+    """A linear Gated Linear Units (GLU) block."""
+
+    def __init__(self, idim: int):
+        """GLU.
+        Args:
+            idim (int): input and output dimension
+        """
+        super().__init__()
+        self.fc = nn.Linear(idim, idim * 2)
+
+    def forward(self, xs):
+        return glu(self.fc(xs), axis=-1)
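
For reference, a minimal usage sketch for the GLU helpers above (illustrative only, not part of the patch; shapes follow the definitions above):

    import paddle
    from deepspeech.modules.activation import glu, LinearGLUBlock

    x = paddle.randn([4, 10, 16])      # [B, T, D]
    y = glu(x, axis=-1)                # halves the gated axis: [4, 10, 8]
    block = LinearGLUBlock(idim=16)    # Linear(16 -> 32) followed by GLU
    z = block(x)                       # keeps the input dim: [4, 10, 16]
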
+
+
+# TODO(Hui Zhang): remove this Layer
+class ConstantPad2d(nn.Layer):
+    """Pads the input tensor boundaries with a constant value.
+    For N-dimensional padding, use paddle.nn.functional.pad().
+    """
+
+    def __init__(self, padding: Union[tuple, list, int], value: float):
+        """
+        Args:
+            padding (Union[tuple, list, int]): the size of the padding.
+                If an int, the same padding is used on all boundaries.
+                If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
+            value (float): pad value
+        """
+        super().__init__()
+        self.padding = padding if isinstance(
+            padding, (tuple, list)) else [padding] * 4
+        self.value = value
+
+    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
+        return nn.functional.pad(
+            xs,
+            self.padding,
+            mode='constant',
+            value=self.value,
+            data_format='NCHW')
+
+
+class ConvGLUBlock(nn.Layer):
+    def __init__(self, kernel_size, in_ch, out_ch, bottleneck_dim=0,
+                 dropout=0.):
+        """A convolutional Gated Linear Units (GLU) block.
+
+        Args:
+            kernel_size (int): kernel size
+            in_ch (int): number of input channels
+            out_ch (int): number of output channels
+            bottleneck_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
+            dropout (float): dropout probability. Defaults to 0.
+        """
+
+        super().__init__()
+
+        self.conv_residual = None
+        if in_ch != out_ch:
+            self.conv_residual = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch, out_channels=out_ch,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            self.dropout_residual = nn.Dropout(p=dropout)
+
+        self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
+
+        layers = OrderedDict()
+        if bottleneck_dim == 0:
+            layers['conv'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=out_ch * 2,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            # TODO(hirofumi0810): padding?
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU()
+        elif bottleneck_dim > 0:
+            layers['conv_in'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=bottleneck_dim,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_in'] = nn.Dropout(p=dropout)
+            layers['conv_bottleneck'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottleneck_dim,
+                    out_channels=bottleneck_dim,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU()
+            layers['conv_out'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottleneck_dim,
+                    out_channels=out_ch * 2,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_out'] = nn.Dropout(p=dropout)
 
-class Swish(nn.Layer):
-    """Construct an Swish object."""
+        # paddle.nn.Sequential takes Layers or (name, Layer) pairs,
+        # not an OrderedDict, so unpack the items
+        self.layers = nn.Sequential(*layers.items())
 
-    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
-        """Return Swish activation function."""
-        return x * F.sigmoid(x)
+    def forward(self, xs):
+        """Forward pass.
+        Args:
+            xs (FloatTensor): `[B, in_ch, T, feat_dim]`
+        Returns:
+            out (FloatTensor): `[B, out_ch, T, feat_dim]`
+        """
+        residual = xs
+        if self.conv_residual is not None:
+            residual = self.dropout_residual(self.conv_residual(residual))
+        xs = self.pad_left(xs)  # `[B, embed_dim, T+kernel-1, 1]`
+        xs = self.layers(xs)  # `[B, out_ch * 2, T, 1]`
+        xs = xs + residual
+        return xs
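
A quick shape check for ConvGLUBlock (illustrative sketch; shapes are taken from the forward docstring above, and with in_ch == out_ch the residual path is the identity):

    import paddle
    from deepspeech.modules.activation import ConvGLUBlock

    xs = paddle.randn([2, 8, 50, 1])   # [B, in_ch, T, feat_dim]
    block = ConvGLUBlock(kernel_size=3, in_ch=8, out_ch=8)
    out = block(xs)                    # [B, out_ch, T, feat_dim] = [2, 8, 50, 1]
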
diff --git a/deepspeech/modules/conformer_convolution.py b/deepspeech/modules/conformer_convolution.py
new file mode 100644
index 000000000..5416bd898
--- /dev/null
+++ b/deepspeech/modules/conformer_convolution.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvolutionModule definition."""
+from typing import Optional, Tuple
+from typeguard import check_argument_types
+
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+# ensure nn.functional.glu exists (monkey-patched in activation.py)
+# TODO(Hui Zhang): remove this line
+import deepspeech.modules.activation
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['ConvolutionModule']
+
+
+class ConvolutionModule(nn.Layer):
+    """ConvolutionModule in Conformer model."""
+
+    def __init__(self,
+                 channels: int,
+                 kernel_size: int=15,
+                 activation: nn.Layer=nn.ReLU(),
+                 norm: str="batch_norm",
+                 causal: bool=False,
+                 bias: bool=True):
+        """Construct a ConvolutionModule object.
+        Args:
+            channels (int): The number of channels of conv layers.
+            kernel_size (int): Kernel size of conv layers.
+            activation (nn.Layer): Activation Layer.
+            norm (str): Normalization type, 'batch_norm' or 'layer_norm'
+            causal (bool): Whether to use causal convolution or not
+            bias (bool): Whether the conv layers have a bias or not
+        """
+        assert check_argument_types()
+        super().__init__()
+        self.pointwise_conv1 = nn.Conv1D(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=None if bias else False,  # None means the default (with bias)
+        )
+
+        # self.lorder is used to distinguish if it's a causal convolution,
+        # if self.lorder > 0:
+        #    it's a causal convolution, the input will be padded with
+        #    `self.lorder` frames on the left in forward (causal conv impl).
+        # else: it's a symmetrical convolution
+        if causal:
+            padding = 0
+            self.lorder = kernel_size - 1
+        else:
+            # kernel_size should be an odd number for non-causal convolution
+            assert (kernel_size - 1) % 2 == 0
+            padding = (kernel_size - 1) // 2
+            self.lorder = 0
+
+        self.depthwise_conv = nn.Conv1D(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias_attr=None if bias else False,  # None means the default (with bias)
+        )
+
+        assert norm in ['batch_norm', 'layer_norm']
+        if norm == "batch_norm":
+            self.use_layer_norm = False
+            self.norm = nn.BatchNorm1D(channels)
+        else:
+            self.use_layer_norm = True
+            self.norm = nn.LayerNorm(channels)
+
+        self.pointwise_conv2 = nn.Conv1D(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=None if bias else False,  # None means the default (with bias)
+        )
+        self.activation = activation
+
+    def forward(self, x: paddle.Tensor, cache: Optional[paddle.Tensor]=None
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute convolution module.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, channels).
+            cache (paddle.Tensor): left context cache, it is only
+                used in causal convolution. (#batch, channels, time)
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time, channels).
+            paddle.Tensor: Output cache tensor (#batch, channels, time)
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.transpose([0, 2, 1])  # [B, C, T]
+        if self.lorder > 0:
+            if cache is None:
+                x = nn.functional.pad(
+                    x, (self.lorder, 0), 'constant', 0.0, data_format='NCL')
+            else:
+                assert cache.shape[0] == x.shape[0]  # B
+                assert cache.shape[1] == x.shape[1]  # C
+                x = paddle.concat((cache, x), axis=2)
+
+            assert (x.shape[2] > self.lorder)
+            new_cache = x[:, :, -self.lorder:]  # [B, C, T]
+        else:
+            # It's better to just return None if no cache is required;
+            # however, for JIT export we fake a tensor instead of None.
+            new_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # [B, 2*C, T]
+        x = nn.functional.glu(x, axis=1)  # [B, C, T]
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        if self.use_layer_norm:
+            x = x.transpose([0, 2, 1])  # [B, T, C]
+        x = self.activation(self.norm(x))
+        if self.use_layer_norm:
+            x = x.transpose([0, 2, 1])  # [B, C, T]
+        x = self.pointwise_conv2(x)
+        x = x.transpose([0, 2, 1])  # [B, T, C]
+        return x, new_cache
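
A minimal sketch of the causal-cache interface (illustrative, not part of the patch; it uses only the module defined above, with shapes from its docstring):

    import paddle
    from deepspeech.modules.conformer_convolution import ConvolutionModule

    conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
    x = paddle.randn([1, 20, 256])             # (#batch, time, channels)
    y, cache = conv(x)                         # first chunk: left side is zero-padded
    x2 = paddle.randn([1, 20, 256])
    y2, cache = conv(x2, cache)                # later chunks reuse the left context cache
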
diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py
index 7d64c963d..38134e0c2 100644
--- a/deepspeech/modules/conv.py
+++ b/deepspeech/modules/conv.py
@@ -24,7 +24,32 @@ from deepspeech.modules.activation import brelu
 
 logger = logging.getLogger(__name__)
 
-__all__ = ['ConvStack']
+__all__ = ['ConvStack', "conv_output_size"]
+
+
+def conv_output_size(I, F, P, S):
+    # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
+    # Output size after Conv:
+    #   By noting I the length of the input volume size,
+    #   F the length of the filter,
+    #   P the amount of zero padding,
+    #   S the stride,
+    # the output size O of the feature map along that dimension is given by:
+    #   O = (I - F + Pstart + Pend) // S + 1
+    # When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
+    # When Pstart == Pend == 0:
+    #   O = (I - F) // S + 1
+    # https://iq.opengenus.org/output-size-of-convolution/
+    # Output height = (Input height + padding top + padding bottom - kernel height) // (stride height) + 1
+    # Output width = (Input width + padding left + padding right - kernel width) // (stride width) + 1
+    return (I - F + 2 * P) // S + 1
+
+
+# receptive field calculator
+# https://fomoro.com/research/article/receptive-field-calculator
+# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
+# https://distill.pub/2019/computing-receptive-fields/
+# R_{l-1} = S_l * R_l + (K_l - S_l)
 
 
 class ConvBn(nn.Layer):
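
A quick worked example (sketch) showing how conv_output_size reproduces the frequency-axis sizes used by the subsampling layers later in this patch:

    from deepspeech.modules.conv import conv_output_size

    idim = 80                                # input feature dim
    # two Conv2D(kernel=3, stride=2, padding=0) layers, as in Conv2dSubsampling4
    f1 = conv_output_size(idim, 3, 0, 2)     # (80 - 3) // 2 + 1 = 39
    f2 = conv_output_size(f1, 3, 0, 2)       # (39 - 3) // 2 + 1 = 19
    assert f2 == ((idim - 1) // 2 - 1) // 2  # matches the Linear in-features factor
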
+"""Encoder self-attention layer definition.""" +from typing import Optional, Tuple +import logging + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +logger = logging.getLogger(__name__) + +__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer"] + + +class TransformerEncoderLayer(nn.Layer): + """Encoder layer module.""" + + def __init__( + self, + size: int, + self_attn: nn.Layer, + feed_forward: nn.Layer, + dropout_rate: float, + normalize_before: bool=True, + concat_after: bool=False, ): + """Construct an EncoderLayer object. + + Args: + size (int): Input dimension. + self_attn (nn.Layer): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (nn.Layer): Feed-forward module instance. + `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + concat_after (bool): Whether to concat attention layer's input and + output. + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, epsilon=1e-12) + self.norm2 = nn.LayerNorm(size, epsilon=1e-12) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + # concat_linear may be not used in forward fuction, + # but will be saved in the *.pt + self.concat_linear = nn.Linear(size + size, size) + + def forward( + self, + x: paddle.Tensor, + mask: paddle.Tensor, + pos_emb: paddle.Tensor, + output_cache: Optional[paddle.Tensor]=None, + cnn_cache: Optional[paddle.Tensor]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Compute encoded features. + Args: + x (paddle.Tensor): Input tensor (#batch, time, size). + mask (paddle.Tensor): Mask tensor for the input (#batch, time). + pos_emb (paddle.Tensor): just for interface compatibility + to ConformerEncoderLayer + output_cache (paddle.Tensor): Cache tensor of the output + (#batch, time2, size), time2 < time in x. + cnn_cache (paddle.Tensor): not used here, it's for interface + compatibility to ConformerEncoderLayer + Returns: + paddle.Tensor: Output tensor (#batch, time, size). + paddle.Tensor: Mask tensor (#batch, time). 
+        """
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if output_cache is None:
+            x_q = x
+        else:
+            assert output_cache.shape[0] == x.shape[0]
+            assert output_cache.shape[1] < x.shape[1]
+            assert output_cache.shape[2] == self.size
+            chunk = x.shape[1] - output_cache.shape[1]
+            x_q = x[:, -chunk:, :]
+            residual = residual[:, -chunk:, :]
+            mask = mask[:, -chunk:, :]
+
+        if self.concat_after:
+            x_concat = paddle.concat(
+                (x, self.self_attn(x_q, x, x, mask)), axis=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if output_cache is not None:
+            x = paddle.concat([output_cache, x], axis=1)
+
+        fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+        return x, mask, fake_cnn_cache
+
+
+class ConformerEncoderLayer(nn.Layer):
+    """Encoder layer module."""
+
+    def __init__(
+            self,
+            size: int,
+            self_attn: nn.Layer,
+            feed_forward: Optional[nn.Layer]=None,
+            feed_forward_macaron: Optional[nn.Layer]=None,
+            conv_module: Optional[nn.Layer]=None,
+            dropout_rate: float=0.1,
+            normalize_before: bool=True,
+            concat_after: bool=False, ):
+        """Construct an EncoderLayer object.
+
+        Args:
+            size (int): Input dimension.
+            self_attn (nn.Layer): Self-attention module instance.
+                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                instance can be used as the argument.
+            feed_forward (nn.Layer): Feed-forward module instance.
+                `PositionwiseFeedForward` instance can be used as the argument.
+            feed_forward_macaron (nn.Layer): Additional feed-forward module
+                instance.
+                `PositionwiseFeedForward` instance can be used as the argument.
+            conv_module (nn.Layer): Convolution module instance.
+                `ConvolutionModule` instance can be used as the argument.
+            dropout_rate (float): Dropout rate.
+            normalize_before (bool):
+                True: use layer_norm before each sub-block.
+                False: use layer_norm after each sub-block.
+            concat_after (bool): Whether to concat attention layer's input and
+                output.
+                True: x -> x + linear(concat(x, att(x)))
+                False: x -> x + att(x)
+        """
+        super().__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = nn.LayerNorm(size, epsilon=1e-12)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, epsilon=1e-12)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = nn.LayerNorm(
+                size, epsilon=1e-12)  # for the CNN module
+            self.norm_final = nn.LayerNorm(
+                size, epsilon=1e-12)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        self.concat_linear = nn.Linear(size + size, size)
+
+    def forward(
+            self,
+            x: paddle.Tensor,
+            mask: paddle.Tensor,
+            pos_emb: paddle.Tensor,
+            output_cache: Optional[paddle.Tensor]=None,
+            cnn_cache: Optional[paddle.Tensor]=None,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Compute encoded features.
+        Args:
+            x (paddle.Tensor): (#batch, time, size)
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
+            pos_emb (paddle.Tensor): positional encoding, must not be None
+                for ConformerEncoderLayer.
+            output_cache (paddle.Tensor): Cache tensor of the output
+                (#batch, time2, size), time2 < time in x.
+            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time, size).
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: New cnn cache tensor (#batch, channels, time).
+        """
+        # whether to use macaron style FFN
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(
+                self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+
+        if output_cache is None:
+            x_q = x
+        else:
+            assert output_cache.shape[0] == x.shape[0]
+            assert output_cache.shape[1] < x.shape[1]
+            assert output_cache.shape[2] == self.size
+            chunk = x.shape[1] - output_cache.shape[1]
+            x_q = x[:, -chunk:, :]
+            residual = residual[:, -chunk:, :]
+            mask = mask[:, -chunk:, :]
+
+        x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+
+        if self.concat_after:
+            x_concat = paddle.concat((x, x_att), axis=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(x_att)
+
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        # Fake new cnn cache here, and then change it in conv_module
+        new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+
+            x, new_cnn_cache = self.conv_module(x, cnn_cache)
+            x = residual + self.dropout(x)
+
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        if output_cache is not None:
+            x = paddle.concat([output_cache, x], axis=1)
+
+        return x, mask, new_cnn_cache
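
A wiring sketch for the encoder layers above (illustrative only; `MultiHeadedAttention` and its constructor signature are assumed to exist elsewhere in the codebase and are not part of this patch):

    import paddle
    from deepspeech.modules.encoder_layer import TransformerEncoderLayer
    from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
    # hypothetical import; this patch does not add the attention module
    from deepspeech.modules.attention import MultiHeadedAttention

    size = 256
    layer = TransformerEncoderLayer(
        size=size,
        self_attn=MultiHeadedAttention(4, size, 0.1),  # signature assumed
        feed_forward=PositionwiseFeedForward(size, 2048, 0.1),
        dropout_rate=0.1)
    x = paddle.randn([2, 50, size])                   # (#batch, time, size)
    mask = paddle.ones([2, 50, 50], dtype='bool')     # (#batch, time, time)
    pos_emb = paddle.zeros([1, 50, size])             # unused by the Transformer layer
    out, out_mask, _ = layer(x, mask, pos_emb)        # third value is the fake cnn cache
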
diff --git a/deepspeech/modules/positionwise_feed_forward.py b/deepspeech/modules/positionwise_feed_forward.py
new file mode 100644
index 000000000..89cf60331
--- /dev/null
+++ b/deepspeech/modules/positionwise_feed_forward.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Positionwise feed forward layer definition."""
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["PositionwiseFeedForward"]
+
+
+class PositionwiseFeedForward(nn.Layer):
+    """Positionwise feed forward layer."""
+
+    def __init__(self,
+                 idim: int,
+                 hidden_units: int,
+                 dropout_rate: float,
+                 activation: nn.Layer=nn.ReLU()):
+        """Construct a PositionwiseFeedForward object.
+
+        The feed-forward layer is applied to each position of the sequence.
+        The output dim is the same as the input dim.
+
+        Args:
+            idim (int): Input dimension.
+            hidden_units (int): The number of hidden units.
+            dropout_rate (float): Dropout rate.
+            activation (paddle.nn.Layer): Activation function
+        """
+        super().__init__()
+        self.w_1 = nn.Linear(idim, hidden_units)
+        self.activation = activation
+        self.dropout = nn.Dropout(dropout_rate)
+        self.w_2 = nn.Linear(hidden_units, idim)
+
+    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
+        """Forward function.
+        Args:
+            xs: input tensor (B, Lmax, D)
+        Returns:
+            output tensor, (B, Lmax, D)
+        """
+        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
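
A one-line usage sketch for the layer above (illustrative, not part of the patch):

    import paddle
    from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward

    ff = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
    xs = paddle.randn([2, 50, 256])   # (B, Lmax, D)
    ys = ff(xs)                       # (B, Lmax, D), dim preserved
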
diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py
new file mode 100644
index 000000000..a01374d71
--- /dev/null
+++ b/deepspeech/modules/subsampling.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Subsampling layer definition."""
+
+from typing import Tuple
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+from deepspeech.modules.embedding import PositionalEncoding
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
+    "Conv2dSubsampling8"
+]
+
+
+class BaseSubsampling(nn.Layer):
+    def __init__(self, pos_enc_class: PositionalEncoding):
+        super().__init__()
+        self.pos_enc = pos_enc_class
+        self.right_context = 0
+        self.subsampling_rate = 1
+
+    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
+        return self.pos_enc.position_encoding(offset, size)
+
+
+class LinearNoSubsampling(BaseSubsampling):
+    """Linear transform the input without subsampling."""
+
+    def __init__(self,
+                 idim: int,
+                 odim: int,
+                 dropout_rate: float,
+                 pos_enc_class: PositionalEncoding):
+        """Construct a LinearNoSubsampling object.
+        Args:
+            idim (int): Input dimension.
+            odim (int): Output dimension.
+            dropout_rate (float): Dropout rate.
+            pos_enc_class (PositionalEncoding): position encoding class
+        """
+        super().__init__(pos_enc_class)
+        self.out = nn.Sequential(
+            nn.Linear(idim, odim),
+            nn.LayerNorm(odim, epsilon=1e-12),
+            nn.Dropout(dropout_rate), )
+        self.right_context = 0
+        self.subsampling_rate = 1
+
+    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Input x.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, idim).
+            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
+        Returns:
+            paddle.Tensor: linear input tensor (#batch, time', odim),
+                where time' = time.
+            paddle.Tensor: positional encoding
+            paddle.Tensor: linear input mask (#batch, 1, time'),
+                where time' = time.
+        """
+        x = self.out(x)
+        x, pos_emb = self.pos_enc(x, offset)
+        return x, pos_emb, x_mask
+
+
+class Conv2dSubsampling4(BaseSubsampling):
+    """Convolutional 2D subsampling (to 1/4 length)."""
+
+    def __init__(self,
+                 idim: int,
+                 odim: int,
+                 dropout_rate: float,
+                 pos_enc_class: PositionalEncoding):
+        """Construct a Conv2dSubsampling4 object.
+
+        Args:
+            idim (int): Input dimension.
+            odim (int): Output dimension.
+            dropout_rate (float): Dropout rate.
+            pos_enc_class (PositionalEncoding): position encoding class
+        """
+        super().__init__(pos_enc_class)
+        self.conv = nn.Sequential(
+            nn.Conv2D(1, odim, 3, 2),
+            nn.ReLU(),
+            nn.Conv2D(odim, odim, 3, 2),
+            nn.ReLU(), )
+        self.linear = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.subsampling_rate = 4
+        # The right context for every conv layer is computed by:
+        # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
+        # 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
+        self.right_context = 6
+
+    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Subsample x.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, idim).
+            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
+        Returns:
+            paddle.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 4.
+            paddle.Tensor: positional encoding
+            paddle.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 4.
+        """
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
+        x = self.conv(x)
+        b, c, t, f = paddle.shape(x)
+        # [b, c, t, f] -> [b, t, c, f] before flattening channel and freq dims
+        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
+        x, pos_emb = self.pos_enc(x, offset)
+        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
+
+
+class Conv2dSubsampling6(BaseSubsampling):
+    """Convolutional 2D subsampling (to 1/6 length)."""
+
+    def __init__(self,
+                 idim: int,
+                 odim: int,
+                 dropout_rate: float,
+                 pos_enc_class: PositionalEncoding):
+        """Construct a Conv2dSubsampling6 object.
+
+        Args:
+            idim (int): Input dimension.
+            odim (int): Output dimension.
+            dropout_rate (float): Dropout rate.
+            pos_enc_class (PositionalEncoding): position encoding class
+        """
+        super().__init__(pos_enc_class)
+        self.conv = nn.Sequential(
+            nn.Conv2D(1, odim, 3, 2),
+            nn.ReLU(),
+            nn.Conv2D(odim, odim, 5, 3),
+            nn.ReLU(), )
+        # O = (I - F + Pstart + Pend) // S + 1
+        # when padding == 0: O = (I - F) // S + 1
+        self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
+        # The right context for every conv layer is computed by:
+        # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
+        # 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
+        self.subsampling_rate = 6
+        self.right_context = 14
+
+    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Subsample x.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, idim).
+            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
+        Returns:
+            paddle.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 6.
+            paddle.Tensor: positional encoding
+            paddle.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 6.
+        """
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
+        x = self.conv(x)
+        b, c, t, f = paddle.shape(x)
+        # [b, c, t, f] -> [b, t, c, f] before flattening channel and freq dims
+        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
+        x, pos_emb = self.pos_enc(x, offset)
+        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
+
+
+class Conv2dSubsampling8(BaseSubsampling):
+    """Convolutional 2D subsampling (to 1/8 length)."""
+
+    def __init__(self,
+                 idim: int,
+                 odim: int,
+                 dropout_rate: float,
+                 pos_enc_class: PositionalEncoding):
+        """Construct a Conv2dSubsampling8 object.
+
+        Args:
+            idim (int): Input dimension.
+            odim (int): Output dimension.
+            dropout_rate (float): Dropout rate.
+            pos_enc_class (PositionalEncoding): position encoding class
+        """
+        super().__init__(pos_enc_class)
+        self.conv = nn.Sequential(
+            nn.Conv2D(1, odim, 3, 2),
+            nn.ReLU(),
+            nn.Conv2D(odim, odim, 3, 2),
+            nn.ReLU(),
+            nn.Conv2D(odim, odim, 3, 2),
+            nn.ReLU(), )
+        self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
+                                odim)
+        self.subsampling_rate = 8
+        # The right context for every conv layer is computed by:
+        # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
+        # 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
+        self.right_context = 14
+
+    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Subsample x.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, idim).
+            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
+        Returns:
+            paddle.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 8.
+            paddle.Tensor: positional encoding
+            paddle.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 8.
+        """
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
+        x = self.conv(x)
+        b, c, t, f = paddle.shape(x)  # was missing; b/t/c/f are used below
+        # [b, c, t, f] -> [b, t, c, f] before flattening channel and freq dims
+        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
+        x, pos_emb = self.pos_enc(x, offset)
+        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py
index 1693b76df..5a090796f 100644
--- a/deepspeech/training/gradclip.py
+++ b/deepspeech/training/gradclip.py
@@ -21,8 +21,10 @@ from paddle.fluid import core
 
 logger = logging.getLogger(__name__)
 
+__all__ = ["ClipGradByGlobalNormWithLog"]
 
-class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
+
+class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
     def __init__(self, clip_norm):
         super().__init__(clip_norm)
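
To close, a shape sketch for the new subsampling layers (illustrative only; the PositionalEncoding constructor signature is assumed, see deepspeech/modules/embedding.py):

    import paddle
    from deepspeech.modules.embedding import PositionalEncoding
    from deepspeech.modules.subsampling import Conv2dSubsampling4

    sub = Conv2dSubsampling4(
        idim=80, odim=256, dropout_rate=0.1,
        pos_enc_class=PositionalEncoding(256, 0.1))   # signature assumed
    x = paddle.randn([2, 100, 80])                    # (#batch, time, idim)
    x_mask = paddle.ones([2, 1, 100], dtype='bool')   # (#batch, 1, time)
    out, pos_emb, out_mask = sub(x, x_mask)
    # two stride-2, kernel-3 convs: (100 - 3) // 2 + 1 = 49, (49 - 3) // 2 + 1 = 24
    # out: (2, 24, 256); out_mask: (2, 1, 24)
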