diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py new file mode 100644 index 000000000..8bf48b2c8 --- /dev/null +++ b/deepspeech/models/ds2/conv.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddle import nn +from paddle.nn import functional as F + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ['ConvStack', "conv_output_size"] + + +def conv_output_size(I, F, P, S): + # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters + # Output size after Conv: + # By noting I the length of the input volume size, + # F the length of the filter, + # P the amount of zero padding, + # S the stride, + # then the output size O of the feature map along that dimension is given by: + # O = (I - F + Pstart + Pend) // S + 1 + # When Pstart == Pend == P, we can replace Pstart + Pend by 2P. + # When Pstart == Pend == 0 + # O = (I - F - S) // S + # https://iq.opengenus.org/output-size-of-convolution/ + # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1 + # Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1 + return (I - F + 2 * P - S) // S + + +# receptive field calculator +# https://fomoro.com/research/article/receptive-field-calculator +# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters +# https://distill.pub/2019/computing-receptive-fields/ +# Rl-1 = Sl * Rl + (Kl - Sl) + + +class ConvBn(nn.Layer): + """Convolution layer with batch normalization. + + :param kernel_size: The x dimension of a filter kernel. Or input a tuple for + two image dimension. + :type kernel_size: int|tuple|list + :param num_channels_in: Number of input channels. + :type num_channels_in: int + :param num_channels_out: Number of output channels. + :type num_channels_out: int + :param stride: The x dimension of the stride. Or input a tuple for two + image dimension. + :type stride: int|tuple|list + :param padding: The x dimension of the padding. Or input a tuple for two + image dimension. + :type padding: int|tuple|list + :param act: Activation type, relu|brelu + :type act: string + :return: Batch norm layer after convolution layer. + :rtype: Variable + + """ + + def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, + padding, act): + + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + self.conv = nn.Conv2D( + num_channels_in, + num_channels_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_attr=None, + bias_attr=False, + data_format='NCHW') + + self.bn = nn.BatchNorm2D( + num_channels_out, + weight_attr=None, + bias_attr=None, + data_format='NCHW') + self.act = F.relu if act == 'relu' else brelu + + def forward(self, x, x_len): + """ + x(Tensor): audio, shape [B, C, D, T] + """ + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + + x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] + ) // self.stride[1] + 1 + + # reset padding part to 0 + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] + # TODO(Hui Zhang): not support bool multiply + # masks = masks.type_as(x) + masks = masks.astype(x.dtype) + x = x.multiply(masks) + + return x, x_len + + +class ConvStack(nn.Layer): + """Convolution group with stacked convolution layers. + + :param feat_size: audio feature dim. + :type feat_size: int + :param num_stacks: Number of stacked convolution layers. + :type num_stacks: int + """ + + def __init__(self, feat_size, num_stacks): + super().__init__() + self.feat_size = feat_size # D + self.num_stacks = num_stacks + + self.conv_in = ConvBn( + num_channels_in=1, + num_channels_out=32, + kernel_size=(41, 11), #[D, T] + stride=(2, 3), + padding=(20, 5), + act='brelu') + + out_channel = 32 + convs = [ + ConvBn( + num_channels_in=32, + num_channels_out=out_channel, + kernel_size=(21, 11), + stride=(2, 1), + padding=(10, 5), + act='brelu') for i in range(num_stacks - 1) + ] + self.conv_stack = nn.LayerList(convs) + + # conv output feat_dim + output_height = (feat_size - 1) // 2 + 1 + for i in range(self.num_stacks - 1): + output_height = (output_height - 1) // 2 + 1 + self.output_height = out_channel * output_height + + def forward(self, x, x_len): + """ + x: shape [B, C, D, T] + x_len : shape [B] + """ + x, x_len = self.conv_in(x, x_len) + for i, conv in enumerate(self.conv_stack): + x, x_len = conv(x, x_len) + return x, x_len diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py deleted file mode 100644 index 7f173ce29..000000000 --- a/deepspeech/models/ds2/deepspeech2.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Deepspeech2 ASR Model""" -from typing import Optional - -import paddle -from paddle import nn -from yacs.config import CfgNode - -from deepspeech.models.ds2.conv import ConvStack -from deepspeech.modules.ctc import CTCDecoder -from deepspeech.models.ds2.rnn import RNNStack -from deepspeech.utils import layer_tools -from deepspeech.utils.checkpoint import Checkpoint -from deepspeech.utils.log import Log - -from paddle.nn import LSTM, GRU -from paddle.nn import LayerNorm -from paddle.nn import LayerList - - -logger = Log(__name__).getlog() - -__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode'] - - -class CRNNEncoder(nn.Layer): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - apply_online=True): - super().__init__() - self.rnn_size = rnn_size - self.feat_size = feat_size # 161 for linear - self.dict_size = dict_size - self.num_rnn_layers = num_rnn_layers - self.apply_online = apply_online - self.conv = ConvStack(feat_size, num_conv_layers) - - i_size = self.conv.output_height # H after conv stack - - - self.rnn = LayerList() - self.layernorm_list = LayerList() - - if (apply_online == True): - rnn_direction = 'forward' - else: - rnn_direction = 'bidirect' - - if use_gru == True: - self.rnn.append(GRU(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) - self.layernorm_list.append(LayerNorm(rnn_size)) - for i in range(1, num_rnn_layers): - self.rnn.append(GRU(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) - self.layernorm_list.append(LayerNorm(rnn_size)) - else: - self.rnn.append(LSTM(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) - self.layernorm_list.append(LayerNorm(rnn_size)) - for i in range(1, num_rnn_layers): - self.rnn.append(LSTM(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) - self.layernorm_list.append(LayerNorm(rnn_size)) - """ - self.rnn = RNNStack( - i_size=i_size, - h_size=rnn_size, - num_stacks=num_rnn_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - """ - @property - def output_size(self): - return self.rnn_size - - def forward(self, audio, audio_len): - """Compute Encoder outputs - - Args: - audio (Tensor): [B, Tmax, D] - text (Tensor): [B, Umax] - audio_len (Tensor): [B] - text_len (Tensor): [B] - Returns: - x (Tensor): encoder outputs, [B, T, D] - x_lens (Tensor): encoder length, [B] - """ - # [B, T, D] -> [B, D, T] - audio = audio.transpose([0, 2, 1]) - # [B, D, T] -> [B, C=1, D, T] - x = audio.unsqueeze(1) - x_lens = audio_len - - # convolution group - x, x_lens = self.conv(x, x_lens) - - # convert data from convolution feature map to sequence of vectors - #B, C, D, T = paddle.shape(x) # not work under jit - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit - x = x.reshape([0, 0, -1]) #[B, T, C*D] - - # remove padding part - print ("x.shape:", x.shape) - x, output_state = self.rnn[0](x, None, x_lens) - x = self.layernorm_list[0](x) - for i in range(1, self.num_rnn_layers): - x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D] - x = self.layernorm_list[i](x) - """ - x, x_lens = self.rnn(x, x_lens) - """ - return x, x_lens - - -class DeepSpeech2Model(nn.Layer): - """The DeepSpeech2 network structure. - - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_size: RNN layer size (dimension of RNN cells). - :type rnn_size: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward direction RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: A tuple of an output unnormalized log probability layer ( - before softmax) and a ctc cost layer. - :rtype: tuple of LayerOutput - """ - - @classmethod - def params(cls, config: Optional[CfgNode]=None) -> CfgNode: - default = CfgNode( - dict( - num_conv_layers=2, #Number of stacking convolution layers. - num_rnn_layers=3, #Number of stacking RNN layers. - rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) - if config is not None: - config.merge_from_other_cfg(default) - return default - - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - apply_online = True): - super().__init__() - self.encoder = CRNNEncoder( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights, - apply_online=apply_online) - assert (self.encoder.output_size == rnn_size) - - self.decoder = CTCDecoder( - odim=dict_size, # is in vocab - enc_n_units=self.encoder.output_size, - blank_id=0, # first token is - dropout_rate=0.0, - reduction=True, # sum - batch_average=True) # sum / batch_size - - def forward(self, audio, audio_len, text, text_len): - """Compute Model loss - - Args: - audio (Tenosr): [B, T, D] - audio_len (Tensor): [B] - text (Tensor): [B, U] - text_len (Tensor): [B] - - Returns: - loss (Tenosr): [1] - """ - eouts, eouts_len = self.encoder(audio, audio_len) - loss = self.decoder(eouts, eouts_len, text, text_len) - return loss - - @paddle.no_grad() - def decode(self, audio, audio_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - # init once - # decoders only accept string encoded in utf-8 - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) - - eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.softmax(eouts) - return self.decoder.decode_probs( - probs.numpy(), eouts_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) - - @classmethod - def from_pretrained(cls, dataloader, config, checkpoint_path): - """Build a DeepSpeech2Model model from a pretrained model. - Parameters - ---------- - dataloader: paddle.io.DataLoader - - config: yacs.config.CfgNode - model configs - - checkpoint_path: Path or str - the path of pretrained model checkpoint, without extension name - - Returns - ------- - DeepSpeech2Model - The model built from pretrained result. - """ - model = cls(feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights, - apply_online=config.model.apply_online) - infos = Checkpoint().load_parameters( - model, checkpoint_path=checkpoint_path) - logger.info(f"checkpoint info: {infos}") - layer_tools.summary(model) - return model - - -class DeepSpeech2InferModel(DeepSpeech2Model): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - apply_online = True): - super().__init__( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights, - apply_online=apply_online) - - def forward(self, audio, audio_len): - """export model function - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - - Returns: - probs: probs after softmax - """ - eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.softmax(eouts) - return probs diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py new file mode 100644 index 000000000..01b55c4a2 --- /dev/null +++ b/deepspeech/models/ds2/rnn.py @@ -0,0 +1,314 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ['RNNStack'] + + +class RNNCell(nn.RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + The formula used is as follows: + .. math:: + h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`act` is for :attr:`activation`. + """ + + def __init__(self, + hidden_size: int, + activation="tanh", + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + if activation not in ["tanh", "relu", "brelu"]: + raise ValueError( + "activation for SimpleRNNCell should be tanh or relu, " + "but get {}".format(activation)) + self.activation = activation + self._activation_fn = paddle.tanh \ + if activation == "tanh" \ + else F.relu + if activation == 'brelu': + self._activation_fn = brelu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_h = states + i2h = inputs + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self._activation_fn(i2h + h2h) + return h, h + + @property + def state_shape(self): + return (self.hidden_size, ) + + +class GRUCell(nn.RNNCellBase): + r""" + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + it computes the outputs and updates states. + The formula for GRU used is as follows: + .. math:: + r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) + z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) + \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) + h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + multiplication operator. + """ + + def __init__(self, + input_size: int, + hidden_size: int, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.tanh + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) + h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) + + r = self._gate_activation(x_r + h_r) + z = self._gate_activation(x_z + h_z) + c = self._activation(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c + # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru + + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param size: Dimension of RNN cells. + :type size: int + :param share_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + :type share_weights: bool + :return: Bidirectional simple rnn layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int, share_weights: bool): + super().__init__() + self.share_weights = share_weights + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. + self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + # batch norm is only performed on input-state projection + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = self.fw_fc + self.bw_bn = self.fw_bn + else: + self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + + self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class BiGRUWithBN(nn.Layer): + """Bidirectonal gru layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param name: Name of the layer. + :type name: string + :param input: Input layer. + :type input: Variable + :param size: Dimension of GRU cells. + :type size: int + :param act: Activation type. + :type act: string + :return: Bidirectional GRU layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int): + super().__init__() + hidden_size = h_size * 3 + + self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + + self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x, x_len): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class RNNStack(nn.Layer): + """RNN group with stacked bidirectional simple RNN or GRU layers. + + :param input: Input layer. + :type input: Variable + :param size: Dimension of RNN cells in each layer. + :type size: int + :param num_stacks: Number of stacked rnn layers. + :type num_stacks: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: Output layer of the RNN group. + :rtype: Variable + """ + + def __init__(self, + i_size: int, + h_size: int, + num_stacks: int, + use_gru: bool, + share_rnn_weights: bool): + super().__init__() + rnn_stacks = [] + for i in range(num_stacks): + if use_gru: + #default:GRU using tanh + rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) + else: + rnn_stacks.append( + BiRNNWithBN( + i_size=i_size, + h_size=h_size, + share_weights=share_rnn_weights)) + i_size = h_size * 2 + + self.rnn_stacks = nn.ModuleList(rnn_stacks) + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + """ + x: shape [B, T, D] + x_len: shpae [B] + """ + for i, rnn in enumerate(self.rnn_stacks): + x, x_len = rnn(x, x_len) + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(-1) # [B, T, 1] + # TODO(Hui Zhang): not support bool multiply + masks = masks.astype(x.dtype) + x = x.multiply(masks) + return x, x_len