diff --git a/deepspeech/models/ds2/__init__.py b/deepspeech/models/ds2/__init__.py
index de78ebe9..39bea5bf 100644
--- a/deepspeech/models/ds2/__init__.py
+++ b/deepspeech/models/ds2/__init__.py
@@ -1,4 +1,17 @@
-from .deepspeech2 import DeepSpeech2Model
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from .deepspeech2 import DeepSpeech2InferModel
+from .deepspeech2 import DeepSpeech2Model
 
 __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py
index 4026c89a..8d737e80 100644
--- a/deepspeech/models/ds2/deepspeech2.py
+++ b/deepspeech/models/ds2/deepspeech2.py
@@ -19,8 +19,8 @@ from paddle import nn
 from yacs.config import CfgNode
 
 from deepspeech.models.ds2.conv import ConvStack
-from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.models.ds2.rnn import RNNStack
+from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.utils import layer_tools
 from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log
diff --git a/deepspeech/models/ds2_online/__init__.py b/deepspeech/models/ds2_online/__init__.py
index 88076667..255000ee 100644
--- a/deepspeech/models/ds2_online/__init__.py
+++ b/deepspeech/models/ds2_online/__init__.py
@@ -1,7 +1,17 @@
-from .deepspeech2 import DeepSpeech2ModelOnline
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from .deepspeech2 import DeepSpeech2InferModelOnline
+from .deepspeech2 import DeepSpeech2ModelOnline
 
 __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
-
-
-
diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py
index 8bf48b2c..13c3d330 100644
--- a/deepspeech/models/ds2_online/conv.py
+++ b/deepspeech/models/ds2_online/conv.py
@@ -11,162 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import paddle
 from paddle import nn
-from paddle.nn import functional as F
 
-from deepspeech.modules.activation import brelu
-from deepspeech.modules.mask import make_non_pad_mask
-from deepspeech.utils.log import Log
+from deepspeech.modules.embedding import PositionalEncoding
+from deepspeech.modules.subsampling import Conv2dSubsampling4
 
-logger = Log(__name__).getlog()
 
-__all__ = ['ConvStack', "conv_output_size"]
+class Conv2dSubsampling4Online(Conv2dSubsampling4):
+    def __init__(self,
+                 idim: int,
+                 odim: int,
+                 dropout_rate: float,
+                 pos_enc_class: nn.Layer=PositionalEncoding):
+        super().__init__(idim, odim, dropout_rate, pos_enc_class)
+        self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
 
-
-def conv_output_size(I, F, P, S):
-    # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
-    # Output size after Conv:
-    #   By noting I the length of the input volume size,
-    #   F the length of the filter,
-    #   P the amount of zero padding,
-    #   S the stride,
-    #   then the output size O of the feature map along that dimension is given by:
-    #       O = (I - F + Pstart + Pend) // S + 1
-    #   When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
-    #   When Pstart == Pend == 0
-    #       O = (I - F - S) // S
-    # https://iq.opengenus.org/output-size-of-convolution/
-    # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
-    # Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1
-    return (I - F + 2 * P - S) // S
-
-
-# receptive field calculator
-# https://fomoro.com/research/article/receptive-field-calculator
-# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
-# https://distill.pub/2019/computing-receptive-fields/
-# Rl-1 = Sl * Rl + (Kl - Sl)
-
-
-class ConvBn(nn.Layer):
-    """Convolution layer with batch normalization.
-
-    :param kernel_size: The x dimension of a filter kernel. Or input a tuple for
-                        two image dimension.
-    :type kernel_size: int|tuple|list
-    :param num_channels_in: Number of input channels.
-    :type num_channels_in: int
-    :param num_channels_out: Number of output channels.
-    :type num_channels_out: int
-    :param stride: The x dimension of the stride. Or input a tuple for two
-                   image dimension.
-    :type stride: int|tuple|list
-    :param padding: The x dimension of the padding. Or input a tuple for two
-                    image dimension.
-    :type padding: int|tuple|list
-    :param act: Activation type, relu|brelu
-    :type act: string
-    :return: Batch norm layer after convolution layer.
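
Note: the length arithmetic above can be checked in isolation. Assuming the parent Conv2dSubsampling4 applies two unpadded Conv2D layers with kernel 3 and stride 2 (the usual 4x subsampling front-end; this diff does not show the parent class), the closed form ((n - 1) // 2 - 1) // 2 used for output_dim and the x_len update matches the exact per-layer output size. A minimal sketch under that assumption:

    def conv_out(n, kernel=3, stride=2):
        # unpadded convolution: O = (I - K) // S + 1
        return (n - kernel) // stride + 1

    for n in (16, 80, 161, 200):
        exact = conv_out(conv_out(n))          # two stacked stride-2 convs
        closed_form = ((n - 1) // 2 - 1) // 2  # shortcut used in the diff
        assert exact == closed_form, (n, exact, closed_form)
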
-    :rtype: Variable
-
-    """
-
-    def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
-                 padding, act):
-
-        super().__init__()
-        assert len(kernel_size) == 2
-        assert len(stride) == 2
-        assert len(padding) == 2
-        self.kernel_size = kernel_size
-        self.stride = stride
-        self.padding = padding
-
-        self.conv = nn.Conv2D(
-            num_channels_in,
-            num_channels_out,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            weight_attr=None,
-            bias_attr=False,
-            data_format='NCHW')
-
-        self.bn = nn.BatchNorm2D(
-            num_channels_out,
-            weight_attr=None,
-            bias_attr=None,
-            data_format='NCHW')
-        self.act = F.relu if act == 'relu' else brelu
-
-    def forward(self, x, x_len):
-        """
-        x(Tensor): audio, shape [B, C, D, T]
-        """
+    def forward(self, x: paddle.Tensor,
+                x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
         x = self.conv(x)
-        x = self.bn(x)
-        x = self.act(x)
-
-        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
-                 ) // self.stride[1] + 1
-
-        # reset padding part to 0
-        masks = make_non_pad_mask(x_len)  #[B, T]
-        masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
-        # TODO(Hui Zhang): not support bool multiply
-        # masks = masks.type_as(x)
-        masks = masks.astype(x.dtype)
-        x = x.multiply(masks)
-
-        return x, x_len
-
-
-class ConvStack(nn.Layer):
-    """Convolution group with stacked convolution layers.
-
-    :param feat_size: audio feature dim.
-    :type feat_size: int
-    :param num_stacks: Number of stacked convolution layers.
-    :type num_stacks: int
-    """
-
-    def __init__(self, feat_size, num_stacks):
-        super().__init__()
-        self.feat_size = feat_size  # D
-        self.num_stacks = num_stacks
-
-        self.conv_in = ConvBn(
-            num_channels_in=1,
-            num_channels_out=32,
-            kernel_size=(41, 11),  #[D, T]
-            stride=(2, 3),
-            padding=(20, 5),
-            act='brelu')
-
-        out_channel = 32
-        convs = [
-            ConvBn(
-                num_channels_in=32,
-                num_channels_out=out_channel,
-                kernel_size=(21, 11),
-                stride=(2, 1),
-                padding=(10, 5),
-                act='brelu') for i in range(num_stacks - 1)
-        ]
-        self.conv_stack = nn.LayerList(convs)
-
-        # conv output feat_dim
-        output_height = (feat_size - 1) // 2 + 1
-        for i in range(self.num_stacks - 1):
-            output_height = (output_height - 1) // 2 + 1
-        self.output_height = out_channel * output_height
-
-    def forward(self, x, x_len):
-        """
-        x: shape [B, C, D, T]
-        x_len : shape [B]
-        """
-        x, x_len = self.conv_in(x, x_len)
-        for i, conv in enumerate(self.conv_stack):
-            x, x_len = conv(x, x_len)
+        b, c, t, f = paddle.shape(x)
+        x = x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])
+        x_len = ((x_len - 1) // 2 - 1) // 2
         return x, x_len
diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py
index 3c77209f..4fa6da0d 100644
--- a/deepspeech/models/ds2_online/deepspeech2.py
+++ b/deepspeech/models/ds2_online/deepspeech2.py
@@ -11,27 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Deepspeech2 ASR Model"""
+"""Deepspeech2 ASR Online Model"""
 from typing import Optional
 
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
-from paddle.fluid.layers import fc
-from paddle.nn import GRU
-from paddle.nn import LayerList
-from paddle.nn import LayerNorm
-from paddle.nn import Linear
-from paddle.nn import LSTM
 from yacs.config import CfgNode
 
-from deepspeech.models.ds2_online.conv import ConvStack
-from deepspeech.models.ds2_online.rnn import RNNStack
+from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
 from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.utils import layer_tools
 from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log
-
 logger = Log(__name__).getlog()
 
 __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
@@ -55,46 +47,48 @@ class CRNNEncoder(nn.Layer):
         self.num_rnn_layers = num_rnn_layers
         self.num_fc_layers = num_fc_layers
         self.fc_layers_size_list = fc_layers_size_list
-        self.conv = ConvStack(feat_size, num_conv_layers)
+        self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
 
-        i_size = self.conv.output_height  # H after conv stack
+        i_size = self.conv.output_dim
 
-        self.rnn = LayerList()
-        self.layernorm_list = LayerList()
-        self.fc_layers_list = LayerList()
+        self.rnn = nn.LayerList()
+        self.layernorm_list = nn.LayerList()
+        self.fc_layers_list = nn.LayerList()
         rnn_direction = 'forward'
         layernorm_size = rnn_size
         if use_gru == True:
             self.rnn.append(
-                GRU(input_size=i_size,
+                nn.GRU(
+                    input_size=i_size,
                     hidden_size=rnn_size,
                     num_layers=1,
                     direction=rnn_direction))
-            self.layernorm_list.append(LayerNorm(layernorm_size))
+            self.layernorm_list.append(nn.LayerNorm(layernorm_size))
             for i in range(1, num_rnn_layers):
                 self.rnn.append(
-                    GRU(input_size=layernorm_size,
+                    nn.GRU(
+                        input_size=layernorm_size,
                         hidden_size=rnn_size,
                         num_layers=1,
                         direction=rnn_direction))
-                self.layernorm_list.append(LayerNorm(layernorm_size))
+                self.layernorm_list.append(nn.LayerNorm(layernorm_size))
         else:
             self.rnn.append(
-                LSTM(
+                nn.LSTM(
                     input_size=i_size,
                     hidden_size=rnn_size,
                     num_layers=1,
                     direction=rnn_direction))
-            self.layernorm_list.append(LayerNorm(layernorm_size))
+            self.layernorm_list.append(nn.LayerNorm(layernorm_size))
             for i in range(1, num_rnn_layers):
                 self.rnn.append(
-                    LSTM(
+                    nn.LSTM(
                         input_size=layernorm_size,
                         hidden_size=rnn_size,
                         num_layers=1,
                         direction=rnn_direction))
-                self.layernorm_list.append(LayerNorm(layernorm_size))
+                self.layernorm_list.append(nn.LayerNorm(layernorm_size))
         fc_input_size = layernorm_size
         for i in range(self.num_fc_layers):
             self.fc_layers_list.append(
@@ -117,20 +111,16 @@ class CRNNEncoder(nn.Layer):
             x (Tensor): encoder outputs, [B, T, D]
             x_lens (Tensor): encoder length, [B]
         """
-        # [B, T, D] -> [B, D, T]
-        audio = audio.transpose([0, 2, 1])
-        # [B, D, T] -> [B, C=1, D, T]
-        x = audio.unsqueeze(1)
+        # [B, T, D]
+        x = audio
         x_lens = audio_len
-        # convolution group
         x, x_lens = self.conv(x, x_lens)
-        # convert data from convolution feature map to sequence of vectors
         #B, C, D, T = paddle.shape(x)  # not work under jit
-        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
+        #x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
         #x = x.reshape([B, T, C * D])  #[B, T, C*D]  # not work under jit
-        x = x.reshape([0, 0, -1])  #[B, T, C*D]
+        #x = x.reshape([0, 0, -1])  #[B, T, C*D]
         # remove padding part
         x, output_state = self.rnn[0](x, None, x_lens)
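
Note: the encoder now stacks unidirectional RNN layers with a LayerNorm after each, which is what keeps the model streamable. A standalone toy of that pattern (sizes are illustrative, and the real encoder also threads x_lens through each RNN call, which this sketch omits):

    import paddle
    from paddle import nn

    i_size, rnn_size, num_rnn_layers = 608, 1024, 3
    rnn = nn.LayerList()
    norms = nn.LayerList()
    for i in range(num_rnn_layers):
        rnn.append(
            nn.GRU(input_size=i_size if i == 0 else rnn_size,
                   hidden_size=rnn_size,
                   num_layers=1,
                   direction='forward'))  # forward-only, no future context
        norms.append(nn.LayerNorm(rnn_size))

    x = paddle.randn([4, 49, i_size])  # [B, T', i_size] after subsampling
    for gru, ln in zip(rnn, norms):
        x, _ = gru(x)  # returns (outputs, final_states); lengths omitted here
        x = ln(x)
    print(x.shape)  # [4, 49, 1024]
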
-"""Deepspeech2 ASR Model""" +"""Deepspeech2 ASR Online Model""" from typing import Optional import paddle import paddle.nn.functional as F from paddle import nn -from paddle.fluid.layers import fc -from paddle.nn import GRU -from paddle.nn import LayerList -from paddle.nn import LayerNorm -from paddle.nn import Linear -from paddle.nn import LSTM from yacs.config import CfgNode -from deepspeech.models.ds2_online.conv import ConvStack -from deepspeech.models.ds2_online.rnn import RNNStack +from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online from deepspeech.modules.ctc import CTCDecoder from deepspeech.utils import layer_tools from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log - logger = Log(__name__).getlog() __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline'] @@ -55,46 +47,48 @@ class CRNNEncoder(nn.Layer): self.num_rnn_layers = num_rnn_layers self.num_fc_layers = num_fc_layers self.fc_layers_size_list = fc_layers_size_list - self.conv = ConvStack(feat_size, num_conv_layers) + self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) - i_size = self.conv.output_height # H after conv stack + i_size = self.conv.output_dim - self.rnn = LayerList() - self.layernorm_list = LayerList() - self.fc_layers_list = LayerList() + self.rnn = nn.LayerList() + self.layernorm_list = nn.LayerList() + self.fc_layers_list = nn.LayerList() rnn_direction = 'forward' layernorm_size = rnn_size if use_gru == True: self.rnn.append( - GRU(input_size=i_size, + nn.GRU( + input_size=i_size, hidden_size=rnn_size, num_layers=1, direction=rnn_direction)) - self.layernorm_list.append(LayerNorm(layernorm_size)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) for i in range(1, num_rnn_layers): self.rnn.append( - GRU(input_size=layernorm_size, + nn.GRU( + input_size=layernorm_size, hidden_size=rnn_size, num_layers=1, direction=rnn_direction)) - self.layernorm_list.append(LayerNorm(layernorm_size)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) else: self.rnn.append( - LSTM( + nn.LSTM( input_size=i_size, hidden_size=rnn_size, num_layers=1, direction=rnn_direction)) - self.layernorm_list.append(LayerNorm(layernorm_size)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) for i in range(1, num_rnn_layers): self.rnn.append( - LSTM( + nn.LSTM( input_size=layernorm_size, hidden_size=rnn_size, num_layers=1, direction=rnn_direction)) - self.layernorm_list.append(LayerNorm(layernorm_size)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) fc_input_size = layernorm_size for i in range(self.num_fc_layers): self.fc_layers_list.append( @@ -117,20 +111,16 @@ class CRNNEncoder(nn.Layer): x (Tensor): encoder outputs, [B, T, D] x_lens (Tensor): encoder length, [B] """ - # [B, T, D] -> [B, D, T] - audio = audio.transpose([0, 2, 1]) - # [B, D, T] -> [B, C=1, D, T] - x = audio.unsqueeze(1) + # [B, T, D] + x = audio x_lens = audio_len - # convolution group x, x_lens = self.conv(x, x_lens) - # convert data from convolution feature map to sequence of vectors #B, C, D, T = paddle.shape(x) # not work under jit - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + #x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit - x = x.reshape([0, 0, -1]) #[B, T, C*D] + #x = x.reshape([0, 0, -1]) #[B, T, C*D] # remove padding part x, output_state = self.rnn[0](x, None, x_lens) diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml
index be1918d0..acee94c3 100644
--- a/examples/librispeech/s0/conf/deepspeech2.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2.yaml
@@ -40,7 +40,6 @@ model:
   rnn_layer_size: 2048
   use_gru: False
   share_rnn_weights: True
-  apply_online: False
 
 training:
   n_epoch: 50
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index 8c719e5c..ea433f34 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -41,7 +41,6 @@ model:
   rnn_layer_size: 2048
   use_gru: False
   share_rnn_weights: True
-  apply_online: True
 
 training:
   n_epoch: 10
diff --git a/tests/deepspeech2_model_test.py b/tests/deepspeech2_model_test.py
index bb40802d..00df8195 100644
--- a/tests/deepspeech2_model_test.py
+++ b/tests/deepspeech2_model_test.py
@@ -16,8 +16,8 @@ import unittest
 import numpy as np
 import paddle
 
-#from deepspeech.models.deepspeech2 import DeepSpeech2Model
-from deepspeech.models.ds2_online import DeepSpeech2ModelOnline as DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
+
 
 class TestDeepSpeech2Model(unittest.TestCase):
     def setUp(self):
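
Note: with the test module pointing back at the offline DeepSpeech2Model, a minimal construction smoke test could look like the sketch below. The keyword argument names follow the yaml model keys above and are an assumption, not something this diff shows:

    import unittest

    import paddle

    from deepspeech.models.ds2 import DeepSpeech2Model


    class TestDeepSpeech2Construct(unittest.TestCase):
        def test_build(self):
            # argument names assumed from the config keys above
            model = DeepSpeech2Model(
                feat_size=161,
                dict_size=29,
                num_conv_layers=2,
                num_rnn_layers=3,
                rnn_size=1024,
                use_gru=True,
                share_rnn_weights=False)
            self.assertIsInstance(model, paddle.nn.Layer)


    if __name__ == '__main__':
        unittest.main()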