parent
9cf8c1a5db
commit
b2bc6eb526
@ -0,0 +1,149 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""ConvolutionModule definition."""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle.nn import initializer as I
|
||||||
|
|
||||||
|
# init F.glu func
|
||||||
|
# TODO(Hui Zhang): remove this line
|
||||||
|
import deepspeech.modules.activation
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ['ConvolutionModule']
|
||||||
|
|
||||||
|
|
||||||
|
class ConvolutionModule(nn.Layer):
|
||||||
|
"""ConvolutionModule in Conformer model."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int=15,
|
||||||
|
activation: nn.Layer=nn.ReLU(),
|
||||||
|
norm: str="batch_norm",
|
||||||
|
causal: bool=False,
|
||||||
|
bias: bool=True):
|
||||||
|
"""Construct an ConvolutionModule object.
|
||||||
|
Args:
|
||||||
|
channels (int): The number of channels of conv layers.
|
||||||
|
kernel_size (int): Kernel size of conv layers.
|
||||||
|
activation (nn.Layer): Activation Layer.
|
||||||
|
norm (str): Normalization type, 'batch_norm' or 'layer_norm'
|
||||||
|
causal (bool): Whether use causal convolution or not
|
||||||
|
bias (bool): Whether Conv with bias or not
|
||||||
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
|
super().__init__()
|
||||||
|
self.pointwise_conv1 = nn.Conv1D(
|
||||||
|
channels,
|
||||||
|
2 * channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=None if bias else False, # None for True as default
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.lorder is used to distinguish if it's a causal convolution,
|
||||||
|
# if self.lorder > 0:
|
||||||
|
# it's a causal convolution, the input will be padded with
|
||||||
|
# `self.lorder` frames on the left in forward (causal conv impl).
|
||||||
|
# else: it's a symmetrical convolution
|
||||||
|
if causal:
|
||||||
|
padding = 0
|
||||||
|
self.lorder = kernel_size - 1
|
||||||
|
else:
|
||||||
|
# kernel_size should be an odd number for none causal convolution
|
||||||
|
assert (kernel_size - 1) % 2 == 0
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
self.lorder = 0
|
||||||
|
|
||||||
|
self.depthwise_conv = nn.Conv1D(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=padding,
|
||||||
|
groups=channels,
|
||||||
|
bias=None if bias else False, # None for True as default
|
||||||
|
)
|
||||||
|
|
||||||
|
assert norm in ['batch_norm', 'layer_norm']
|
||||||
|
if norm == "batch_norm":
|
||||||
|
self.use_layer_norm = False
|
||||||
|
self.norm = nn.BatchNorm1D(channels)
|
||||||
|
else:
|
||||||
|
self.use_layer_norm = True
|
||||||
|
self.norm = nn.LayerNorm(channels)
|
||||||
|
|
||||||
|
self.pointwise_conv2 = nn.Conv1D(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=None if bias else False, # None for True as default
|
||||||
|
)
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor, cache: Optional[paddle.Tensor]=None
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Compute convolution module.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input tensor (#batch, time, channels).
|
||||||
|
cache (paddle.Tensor): left context cache, it is only
|
||||||
|
used in causal convolution. (#batch, channels, time)
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Output tensor (#batch, time, channels).
|
||||||
|
paddle.Tensor: Output cache tensor (#batch, channels, time)
|
||||||
|
"""
|
||||||
|
# exchange the temporal dimension and the feature dimension
|
||||||
|
x = x.transpose([0, 2, 1]) # [B, C, T]
|
||||||
|
if self.lorder > 0:
|
||||||
|
if cache is None:
|
||||||
|
x = nn.functional.pad(
|
||||||
|
x, (self.lorder, 0), 'constant', 0.0, data_format='NCL')
|
||||||
|
else:
|
||||||
|
assert cache.shape[0] == x.shape[0] # B
|
||||||
|
assert cache.shape[1] == x.shape[1] # C
|
||||||
|
x = paddle.concat((cache, x), axis=2)
|
||||||
|
|
||||||
|
assert (x.shape[2] > self.lorder)
|
||||||
|
new_cache = x[:, :, -self.lorder:] #[B, C, T]
|
||||||
|
else:
|
||||||
|
# It's better we just return None if no cache is requried,
|
||||||
|
# However, for JIT export, here we just fake one tensor instead of
|
||||||
|
# None.
|
||||||
|
new_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
|
||||||
|
|
||||||
|
# GLU mechanism
|
||||||
|
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
||||||
|
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
||||||
|
|
||||||
|
# 1D Depthwise Conv
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
if self.use_layer_norm:
|
||||||
|
x = x.transpose([0, 2, 1]) # [B, T, C]
|
||||||
|
x = self.activation(self.norm(x))
|
||||||
|
if self.use_layer_norm:
|
||||||
|
x = x.transpose([0, 2, 1]) # [B, C, T]
|
||||||
|
x = self.pointwise_conv2(x)
|
||||||
|
x = x.transpose([0, 2, 1]) # [B, T, C]
|
||||||
|
return x, new_cache
|
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Positionwise feed forward layer definition."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle.nn import initializer as I
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["PositionwiseFeedForward"]
|
||||||
|
|
||||||
|
|
||||||
|
class PositionwiseFeedForward(nn.Layer):
|
||||||
|
"""Positionwise feed forward layer."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
idim: int,
|
||||||
|
hidden_units: int,
|
||||||
|
dropout_rate: float,
|
||||||
|
activation: nn.Layer=nn.ReLU()):
|
||||||
|
"""Construct a PositionwiseFeedForward object.
|
||||||
|
|
||||||
|
FeedForward are appied on each position of the sequence.
|
||||||
|
The output dim is same with the input dim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idim (int): Input dimenstion.
|
||||||
|
hidden_units (int): The number of hidden units.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
activation (paddle.nn.Layer): Activation function
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.w_1 = nn.Linear(idim, hidden_units)
|
||||||
|
self.activation = activation
|
||||||
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
|
self.w_2 = nn.Linear(hidden_units, idim)
|
||||||
|
|
||||||
|
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
xs: input tensor (B, Lmax, D)
|
||||||
|
Returns:
|
||||||
|
output tensor, (B, Lmax, D)
|
||||||
|
"""
|
||||||
|
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
@ -0,0 +1,235 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Subsampling layer definition."""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle.nn import initializer as I
|
||||||
|
|
||||||
|
from deepspeech.modules.embedding import PositionalEncoding
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
|
||||||
|
"Conv2dSubsampling8"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSubsampling(nn.Layer):
|
||||||
|
def __init__(self, pos_enc_class: PositionalEncoding):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_enc = pos_enc_class
|
||||||
|
self.right_context = 0
|
||||||
|
self.subsampling_rate = 1
|
||||||
|
|
||||||
|
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
|
||||||
|
return self.pos_enc.position_encoding(offset, size)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearNoSubsampling(BaseSubsampling):
|
||||||
|
"""Linear transform the input without subsampling."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
idim: int,
|
||||||
|
odim: int,
|
||||||
|
dropout_rate: float,
|
||||||
|
pos_enc_class: PositionalEncoding):
|
||||||
|
"""Construct an linear object.
|
||||||
|
Args:
|
||||||
|
idim (int): Input dimension.
|
||||||
|
odim (int): Output dimension.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
pos_enc_class (PositionalEncoding): position encoding class
|
||||||
|
"""
|
||||||
|
super().__init__(pos_enc_class)
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
nn.Linear(idim, odim),
|
||||||
|
nn.LayerNorm(odim, epsilon=1e-12),
|
||||||
|
nn.Dropout(dropout_rate), )
|
||||||
|
self.right_context = 0
|
||||||
|
self.subsampling_rate = 1
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Input x.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input tensor (#batch, time, idim).
|
||||||
|
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: linear input tensor (#batch, time', odim),
|
||||||
|
where time' = time .
|
||||||
|
paddle.Tensor: positional encoding
|
||||||
|
paddle.Tensor: linear input mask (#batch, 1, time'),
|
||||||
|
where time' = time .
|
||||||
|
"""
|
||||||
|
x = self.out(x)
|
||||||
|
x, pos_emb = self.pos_enc(x, offset)
|
||||||
|
return x, pos_emb, x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling4(BaseSubsampling):
|
||||||
|
"""Convolutional 2D subsampling (to 1/4 length)."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
idim: int,
|
||||||
|
odim: int,
|
||||||
|
dropout_rate: float,
|
||||||
|
pos_enc_class: PositionalEncoding):
|
||||||
|
"""Construct an Conv2dSubsampling4 object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idim (int): Input dimension.
|
||||||
|
odim (int): Output dimension.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
"""
|
||||||
|
super().__init__(pos_enc_class)
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2D(1, odim, 3, 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2D(odim, odim, 3, 2),
|
||||||
|
nn.ReLU(), )
|
||||||
|
self.linear = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||||
|
self.subsampling_rate = 4
|
||||||
|
# The right context for every conv layer is computed by:
|
||||||
|
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
|
||||||
|
# 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
|
||||||
|
self.right_context = 6
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Subsample x.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input tensor (#batch, time, idim).
|
||||||
|
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Subsampled tensor (#batch, time', odim),
|
||||||
|
where time' = time // 4.
|
||||||
|
paddle.Tensor: positional encoding
|
||||||
|
paddle.Tensor: Subsampled mask (#batch, 1, time'),
|
||||||
|
where time' = time // 4.
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c=1, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
b, c, t, f = paddle.shape(x)
|
||||||
|
x = self.linear(x.transpose([0, 1, 2, 3]).reshape([b, t, c * f]))
|
||||||
|
x, pos_emb = self.pos_enc(x, offset)
|
||||||
|
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling6(BaseSubsampling):
|
||||||
|
"""Convolutional 2D subsampling (to 1/6 length)."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
idim: int,
|
||||||
|
odim: int,
|
||||||
|
dropout_rate: float,
|
||||||
|
pos_enc_class: PositionalEncoding):
|
||||||
|
"""Construct an Conv2dSubsampling6 object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idim (int): Input dimension.
|
||||||
|
odim (int): Output dimension.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
pos_enc (PositionalEncoding): Custom position encoding layer.
|
||||||
|
"""
|
||||||
|
super().__init__(pos_enc_class)
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2D(1, odim, 3, 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2D(odim, odim, 5, 3),
|
||||||
|
nn.ReLU(), )
|
||||||
|
# O = (I - F + Pstart + Pend) // S + 1
|
||||||
|
# when Padding == 0, O = (I - F - S) // S
|
||||||
|
self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
|
||||||
|
# The right context for every conv layer is computed by:
|
||||||
|
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
|
||||||
|
# 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
|
||||||
|
self.subsampling_rate = 6
|
||||||
|
self.right_context = 14
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Subsample x.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input tensor (#batch, time, idim).
|
||||||
|
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Subsampled tensor (#batch, time', odim),
|
||||||
|
where time' = time // 6.
|
||||||
|
paddle.Tensor: positional encoding
|
||||||
|
paddle.Tensor: Subsampled mask (#batch, 1, time'),
|
||||||
|
where time' = time // 6.
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
b, c, t, f = paddle.shape(x)
|
||||||
|
x = self.linear(x.transpose([0, 1, 2, 3]).reshape([b, t, c * f]))
|
||||||
|
x, pos_emb = self.pos_enc(x, offset)
|
||||||
|
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling8(BaseSubsampling):
|
||||||
|
"""Convolutional 2D subsampling (to 1/8 length)."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
idim: int,
|
||||||
|
odim: int,
|
||||||
|
dropout_rate: float,
|
||||||
|
pos_enc_class: PositionalEncoding):
|
||||||
|
"""Construct an Conv2dSubsampling8 object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idim (int): Input dimension.
|
||||||
|
odim (int): Output dimension.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
"""
|
||||||
|
super().__init__(pos_enc_class)
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2D(1, odim, 3, 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2D(odim, odim, 3, 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2D(odim, odim, 3, 2),
|
||||||
|
nn.ReLU(), )
|
||||||
|
self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
|
||||||
|
odim)
|
||||||
|
self.subsampling_rate = 8
|
||||||
|
# The right context for every conv layer is computed by:
|
||||||
|
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
|
||||||
|
# 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
|
||||||
|
self.right_context = 14
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Subsample x.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input tensor (#batch, time, idim).
|
||||||
|
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Subsampled tensor (#batch, time', odim),
|
||||||
|
where time' = time // 8.
|
||||||
|
paddle.Tensor: positional encoding
|
||||||
|
paddle.Tensor: Subsampled mask (#batch, 1, time'),
|
||||||
|
where time' = time // 8.
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.linear(x.transpose([0, 1, 2, 3]).reshape([b, t, c * f]))
|
||||||
|
x, pos_emb = self.pos_enc(x, offset)
|
||||||
|
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
|
Loading…
Reference in new issue