parent
c7071dd2b4
commit
269eecb3be
@ -0,0 +1,7 @@
|
||||
from .deepspeech2 import DeepSpeech2Model
|
||||
from .deepspeech2 import DeepSpeech2InferModel
|
||||
|
||||
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
|
||||
|
||||
|
||||
|
@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
from deepspeech.modules.activation import brelu
|
||||
from deepspeech.modules.mask import make_non_pad_mask
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['ConvStack', "conv_output_size"]
|
||||
|
||||
|
||||
def conv_output_size(I, F, P, S):
|
||||
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
|
||||
# Output size after Conv:
|
||||
# By noting I the length of the input volume size,
|
||||
# F the length of the filter,
|
||||
# P the amount of zero padding,
|
||||
# S the stride,
|
||||
# then the output size O of the feature map along that dimension is given by:
|
||||
# O = (I - F + Pstart + Pend) // S + 1
|
||||
# When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
|
||||
# When Pstart == Pend == 0
|
||||
# O = (I - F - S) // S
|
||||
# https://iq.opengenus.org/output-size-of-convolution/
|
||||
# Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
|
||||
# Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1
|
||||
return (I - F + 2 * P - S) // S
|
||||
|
||||
|
||||
# receptive field calculator
|
||||
# https://fomoro.com/research/article/receptive-field-calculator
|
||||
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
|
||||
# https://distill.pub/2019/computing-receptive-fields/
|
||||
# Rl-1 = Sl * Rl + (Kl - Sl)
|
||||
|
||||
|
||||
class ConvBn(nn.Layer):
|
||||
"""Convolution layer with batch normalization.
|
||||
|
||||
:param kernel_size: The x dimension of a filter kernel. Or input a tuple for
|
||||
two image dimension.
|
||||
:type kernel_size: int|tuple|list
|
||||
:param num_channels_in: Number of input channels.
|
||||
:type num_channels_in: int
|
||||
:param num_channels_out: Number of output channels.
|
||||
:type num_channels_out: int
|
||||
:param stride: The x dimension of the stride. Or input a tuple for two
|
||||
image dimension.
|
||||
:type stride: int|tuple|list
|
||||
:param padding: The x dimension of the padding. Or input a tuple for two
|
||||
image dimension.
|
||||
:type padding: int|tuple|list
|
||||
:param act: Activation type, relu|brelu
|
||||
:type act: string
|
||||
:return: Batch norm layer after convolution layer.
|
||||
:rtype: Variable
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
|
||||
padding, act):
|
||||
|
||||
super().__init__()
|
||||
assert len(kernel_size) == 2
|
||||
assert len(stride) == 2
|
||||
assert len(padding) == 2
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
self.conv = nn.Conv2D(
|
||||
num_channels_in,
|
||||
num_channels_out,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
weight_attr=None,
|
||||
bias_attr=False,
|
||||
data_format='NCHW')
|
||||
|
||||
self.bn = nn.BatchNorm2D(
|
||||
num_channels_out,
|
||||
weight_attr=None,
|
||||
bias_attr=None,
|
||||
data_format='NCHW')
|
||||
self.act = F.relu if act == 'relu' else brelu
|
||||
|
||||
def forward(self, x, x_len):
|
||||
"""
|
||||
x(Tensor): audio, shape [B, C, D, T]
|
||||
"""
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.act(x)
|
||||
|
||||
x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
|
||||
) // self.stride[1] + 1
|
||||
|
||||
# reset padding part to 0
|
||||
masks = make_non_pad_mask(x_len) #[B, T]
|
||||
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
|
||||
# TODO(Hui Zhang): not support bool multiply
|
||||
# masks = masks.type_as(x)
|
||||
masks = masks.astype(x.dtype)
|
||||
x = x.multiply(masks)
|
||||
|
||||
return x, x_len
|
||||
|
||||
|
||||
class ConvStack(nn.Layer):
|
||||
"""Convolution group with stacked convolution layers.
|
||||
|
||||
:param feat_size: audio feature dim.
|
||||
:type feat_size: int
|
||||
:param num_stacks: Number of stacked convolution layers.
|
||||
:type num_stacks: int
|
||||
"""
|
||||
|
||||
def __init__(self, feat_size, num_stacks):
|
||||
super().__init__()
|
||||
self.feat_size = feat_size # D
|
||||
self.num_stacks = num_stacks
|
||||
|
||||
self.conv_in = ConvBn(
|
||||
num_channels_in=1,
|
||||
num_channels_out=32,
|
||||
kernel_size=(41, 11), #[D, T]
|
||||
stride=(2, 3),
|
||||
padding=(20, 5),
|
||||
act='brelu')
|
||||
|
||||
out_channel = 32
|
||||
convs = [
|
||||
ConvBn(
|
||||
num_channels_in=32,
|
||||
num_channels_out=out_channel,
|
||||
kernel_size=(21, 11),
|
||||
stride=(2, 1),
|
||||
padding=(10, 5),
|
||||
act='brelu') for i in range(num_stacks - 1)
|
||||
]
|
||||
self.conv_stack = nn.LayerList(convs)
|
||||
|
||||
# conv output feat_dim
|
||||
output_height = (feat_size - 1) // 2 + 1
|
||||
for i in range(self.num_stacks - 1):
|
||||
output_height = (output_height - 1) // 2 + 1
|
||||
self.output_height = out_channel * output_height
|
||||
|
||||
def forward(self, x, x_len):
|
||||
"""
|
||||
x: shape [B, C, D, T]
|
||||
x_len : shape [B]
|
||||
"""
|
||||
x, x_len = self.conv_in(x, x_len)
|
||||
for i, conv in enumerate(self.conv_stack):
|
||||
x, x_len = conv(x, x_len)
|
||||
return x, x_len
|
@ -0,0 +1,312 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Deepspeech2 ASR Model"""
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.models.ds2.conv import ConvStack
|
||||
from deepspeech.modules.ctc import CTCDecoder
|
||||
from deepspeech.models.ds2.rnn import RNNStack
|
||||
from deepspeech.utils import layer_tools
|
||||
from deepspeech.utils.checkpoint import Checkpoint
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
from paddle.nn import LSTM, GRU
|
||||
from paddle.nn import LayerNorm
|
||||
from paddle.nn import LayerList
|
||||
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode']
|
||||
|
||||
|
||||
class CRNNEncoder(nn.Layer):
|
||||
def __init__(self,
|
||||
feat_size,
|
||||
dict_size,
|
||||
num_conv_layers=2,
|
||||
num_rnn_layers=3,
|
||||
rnn_size=1024,
|
||||
use_gru=False,
|
||||
share_rnn_weights=True,
|
||||
apply_online=True):
|
||||
super().__init__()
|
||||
self.rnn_size = rnn_size
|
||||
self.feat_size = feat_size # 161 for linear
|
||||
self.dict_size = dict_size
|
||||
self.num_rnn_layers = num_rnn_layers
|
||||
self.apply_online = apply_online
|
||||
self.conv = ConvStack(feat_size, num_conv_layers)
|
||||
|
||||
i_size = self.conv.output_height # H after conv stack
|
||||
|
||||
|
||||
self.rnn = LayerList()
|
||||
self.layernorm_list = LayerList()
|
||||
|
||||
if (apply_online == True):
|
||||
rnn_direction = 'forward'
|
||||
layernorm_size = rnn_size
|
||||
else:
|
||||
rnn_direction = 'bidirect'
|
||||
layernorm_size = 2 * rnn_size
|
||||
|
||||
if use_gru == True:
|
||||
self.rnn.append(GRU(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
|
||||
self.layernorm_list.append(LayerNorm(layernorm_size))
|
||||
for i in range(1, num_rnn_layers):
|
||||
self.rnn.append(GRU(input_size=layernorm_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
|
||||
self.layernorm_list.append(LayerNorm(layernorm_size))
|
||||
else:
|
||||
self.rnn.append(LSTM(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
|
||||
self.layernorm_list.append(LayerNorm(layernorm_size))
|
||||
for i in range(1, num_rnn_layers):
|
||||
self.rnn.append(LSTM(input_size=layernorm_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
|
||||
self.layernorm_list.append(LayerNorm(layernorm_size))
|
||||
"""
|
||||
self.rnn = RNNStack(
|
||||
i_size=i_size,
|
||||
h_size=rnn_size,
|
||||
num_stacks=num_rnn_layers,
|
||||
use_gru=use_gru,
|
||||
share_rnn_weights=share_rnn_weights)
|
||||
"""
|
||||
@property
|
||||
def output_size(self):
|
||||
if (self.apply_online == True):
|
||||
return self.rnn_size
|
||||
else:
|
||||
return 2 * self.rnn_size
|
||||
|
||||
def forward(self, audio, audio_len):
|
||||
"""Compute Encoder outputs
|
||||
|
||||
Args:
|
||||
audio (Tensor): [B, Tmax, D]
|
||||
text (Tensor): [B, Umax]
|
||||
audio_len (Tensor): [B]
|
||||
text_len (Tensor): [B]
|
||||
Returns:
|
||||
x (Tensor): encoder outputs, [B, T, D]
|
||||
x_lens (Tensor): encoder length, [B]
|
||||
"""
|
||||
# [B, T, D] -> [B, D, T]
|
||||
audio = audio.transpose([0, 2, 1])
|
||||
# [B, D, T] -> [B, C=1, D, T]
|
||||
x = audio.unsqueeze(1)
|
||||
x_lens = audio_len
|
||||
|
||||
# convolution group
|
||||
x, x_lens = self.conv(x, x_lens)
|
||||
|
||||
# convert data from convolution feature map to sequence of vectors
|
||||
#B, C, D, T = paddle.shape(x) # not work under jit
|
||||
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
|
||||
#x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit
|
||||
x = x.reshape([0, 0, -1]) #[B, T, C*D]
|
||||
|
||||
# remove padding part
|
||||
x, output_state = self.rnn[0](x, None, x_lens)
|
||||
x = self.layernorm_list[0](x)
|
||||
for i in range(1, self.num_rnn_layers):
|
||||
x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D]
|
||||
x = self.layernorm_list[i](x)
|
||||
"""
|
||||
x, x_lens = self.rnn(x, x_lens)
|
||||
"""
|
||||
return x, x_lens
|
||||
|
||||
|
||||
class DeepSpeech2Model(nn.Layer):
|
||||
"""The DeepSpeech2 network structure.
|
||||
|
||||
:param audio_data: Audio spectrogram data layer.
|
||||
:type audio_data: Variable
|
||||
:param text_data: Transcription text data layer.
|
||||
:type text_data: Variable
|
||||
:param audio_len: Valid sequence length data layer.
|
||||
:type audio_len: Variable
|
||||
:param masks: Masks data layer to reset padding.
|
||||
:type masks: Variable
|
||||
:param dict_size: Dictionary size for tokenized transcription.
|
||||
:type dict_size: int
|
||||
:param num_conv_layers: Number of stacking convolution layers.
|
||||
:type num_conv_layers: int
|
||||
:param num_rnn_layers: Number of stacking RNN layers.
|
||||
:type num_rnn_layers: int
|
||||
:param rnn_size: RNN layer size (dimension of RNN cells).
|
||||
:type rnn_size: int
|
||||
:param use_gru: Use gru if set True. Use simple rnn if set False.
|
||||
:type use_gru: bool
|
||||
:param share_rnn_weights: Whether to share input-hidden weights between
|
||||
forward and backward direction RNNs.
|
||||
It is only available when use_gru=False.
|
||||
:type share_weights: bool
|
||||
:return: A tuple of an output unnormalized log probability layer (
|
||||
before softmax) and a ctc cost layer.
|
||||
:rtype: tuple of LayerOutput
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
|
||||
default = CfgNode(
|
||||
dict(
|
||||
num_conv_layers=2, #Number of stacking convolution layers.
|
||||
num_rnn_layers=3, #Number of stacking RNN layers.
|
||||
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
|
||||
use_gru=True, #Use gru if set True. Use simple rnn if set False.
|
||||
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
|
||||
))
|
||||
if config is not None:
|
||||
config.merge_from_other_cfg(default)
|
||||
return default
|
||||
|
||||
def __init__(self,
|
||||
feat_size,
|
||||
dict_size,
|
||||
num_conv_layers=2,
|
||||
num_rnn_layers=3,
|
||||
rnn_size=1024,
|
||||
use_gru=False,
|
||||
share_rnn_weights=True,
|
||||
apply_online = True):
|
||||
super().__init__()
|
||||
self.encoder = CRNNEncoder(
|
||||
feat_size=feat_size,
|
||||
dict_size=dict_size,
|
||||
num_conv_layers=num_conv_layers,
|
||||
num_rnn_layers=num_rnn_layers,
|
||||
rnn_size=rnn_size,
|
||||
use_gru=use_gru,
|
||||
share_rnn_weights=share_rnn_weights,
|
||||
apply_online=apply_online)
|
||||
if (apply_online == True):
|
||||
assert (self.encoder.output_size == rnn_size)
|
||||
else:
|
||||
assert (self.encoder.output_size == 2 * rnn_size)
|
||||
|
||||
self.decoder = CTCDecoder(
|
||||
odim=dict_size, # <blank> is in vocab
|
||||
enc_n_units=self.encoder.output_size,
|
||||
blank_id=0, # first token is <blank>
|
||||
dropout_rate=0.0,
|
||||
reduction=True, # sum
|
||||
batch_average=True) # sum / batch_size
|
||||
|
||||
def forward(self, audio, audio_len, text, text_len):
|
||||
"""Compute Model loss
|
||||
|
||||
Args:
|
||||
audio (Tenosr): [B, T, D]
|
||||
audio_len (Tensor): [B]
|
||||
text (Tensor): [B, U]
|
||||
text_len (Tensor): [B]
|
||||
|
||||
Returns:
|
||||
loss (Tenosr): [1]
|
||||
"""
|
||||
eouts, eouts_len = self.encoder(audio, audio_len)
|
||||
loss = self.decoder(eouts, eouts_len, text, text_len)
|
||||
return loss
|
||||
|
||||
@paddle.no_grad()
|
||||
def decode(self, audio, audio_len, vocab_list, decoding_method,
|
||||
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
|
||||
cutoff_top_n, num_processes):
|
||||
# init once
|
||||
# decoders only accept string encoded in utf-8
|
||||
self.decoder.init_decode(
|
||||
beam_alpha=beam_alpha,
|
||||
beam_beta=beam_beta,
|
||||
lang_model_path=lang_model_path,
|
||||
vocab_list=vocab_list,
|
||||
decoding_method=decoding_method)
|
||||
|
||||
eouts, eouts_len = self.encoder(audio, audio_len)
|
||||
probs = self.decoder.softmax(eouts)
|
||||
return self.decoder.decode_probs(
|
||||
probs.numpy(), eouts_len, vocab_list, decoding_method,
|
||||
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
|
||||
cutoff_top_n, num_processes)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, dataloader, config, checkpoint_path):
|
||||
"""Build a DeepSpeech2Model model from a pretrained model.
|
||||
Parameters
|
||||
----------
|
||||
dataloader: paddle.io.DataLoader
|
||||
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
|
||||
checkpoint_path: Path or str
|
||||
the path of pretrained model checkpoint, without extension name
|
||||
|
||||
Returns
|
||||
-------
|
||||
DeepSpeech2Model
|
||||
The model built from pretrained result.
|
||||
"""
|
||||
model = cls(feat_size=dataloader.collate_fn.feature_size,
|
||||
dict_size=dataloader.collate_fn.vocab_size,
|
||||
num_conv_layers=config.model.num_conv_layers,
|
||||
num_rnn_layers=config.model.num_rnn_layers,
|
||||
rnn_size=config.model.rnn_layer_size,
|
||||
use_gru=config.model.use_gru,
|
||||
share_rnn_weights=config.model.share_rnn_weights,
|
||||
apply_online=config.model.apply_online)
|
||||
infos = Checkpoint().load_parameters(
|
||||
model, checkpoint_path=checkpoint_path)
|
||||
logger.info(f"checkpoint info: {infos}")
|
||||
layer_tools.summary(model)
|
||||
return model
|
||||
|
||||
|
||||
class DeepSpeech2InferModel(DeepSpeech2Model):
|
||||
def __init__(self,
|
||||
feat_size,
|
||||
dict_size,
|
||||
num_conv_layers=2,
|
||||
num_rnn_layers=3,
|
||||
rnn_size=1024,
|
||||
use_gru=False,
|
||||
share_rnn_weights=True,
|
||||
apply_online = True):
|
||||
super().__init__(
|
||||
feat_size=feat_size,
|
||||
dict_size=dict_size,
|
||||
num_conv_layers=num_conv_layers,
|
||||
num_rnn_layers=num_rnn_layers,
|
||||
rnn_size=rnn_size,
|
||||
use_gru=use_gru,
|
||||
share_rnn_weights=share_rnn_weights,
|
||||
apply_online=apply_online)
|
||||
|
||||
def forward(self, audio, audio_len):
|
||||
"""export model function
|
||||
|
||||
Args:
|
||||
audio (Tensor): [B, T, D]
|
||||
audio_len (Tensor): [B]
|
||||
|
||||
Returns:
|
||||
probs: probs after softmax
|
||||
"""
|
||||
eouts, eouts_len = self.encoder(audio, audio_len)
|
||||
probs = self.decoder.softmax(eouts)
|
||||
return probs
|
@ -0,0 +1,314 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle.nn import initializer as I
|
||||
|
||||
from deepspeech.modules.activation import brelu
|
||||
from deepspeech.modules.mask import make_non_pad_mask
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['RNNStack']
|
||||
|
||||
|
||||
class RNNCell(nn.RNNCellBase):
|
||||
r"""
|
||||
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
|
||||
computes the outputs and updates states.
|
||||
The formula used is as follows:
|
||||
.. math::
|
||||
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
|
||||
y_{t} & = h_{t}
|
||||
|
||||
where :math:`act` is for :attr:`activation`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
activation="tanh",
|
||||
weight_ih_attr=None,
|
||||
weight_hh_attr=None,
|
||||
bias_ih_attr=None,
|
||||
bias_hh_attr=None,
|
||||
name=None):
|
||||
super().__init__()
|
||||
std = 1.0 / math.sqrt(hidden_size)
|
||||
self.weight_hh = self.create_parameter(
|
||||
(hidden_size, hidden_size),
|
||||
weight_hh_attr,
|
||||
default_initializer=I.Uniform(-std, std))
|
||||
self.bias_ih = None
|
||||
self.bias_hh = self.create_parameter(
|
||||
(hidden_size, ),
|
||||
bias_hh_attr,
|
||||
is_bias=True,
|
||||
default_initializer=I.Uniform(-std, std))
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
if activation not in ["tanh", "relu", "brelu"]:
|
||||
raise ValueError(
|
||||
"activation for SimpleRNNCell should be tanh or relu, "
|
||||
"but get {}".format(activation))
|
||||
self.activation = activation
|
||||
self._activation_fn = paddle.tanh \
|
||||
if activation == "tanh" \
|
||||
else F.relu
|
||||
if activation == 'brelu':
|
||||
self._activation_fn = brelu
|
||||
|
||||
def forward(self, inputs, states=None):
|
||||
if states is None:
|
||||
states = self.get_initial_states(inputs, self.state_shape)
|
||||
pre_h = states
|
||||
i2h = inputs
|
||||
if self.bias_ih is not None:
|
||||
i2h += self.bias_ih
|
||||
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
|
||||
if self.bias_hh is not None:
|
||||
h2h += self.bias_hh
|
||||
h = self._activation_fn(i2h + h2h)
|
||||
return h, h
|
||||
|
||||
@property
|
||||
def state_shape(self):
|
||||
return (self.hidden_size, )
|
||||
|
||||
|
||||
class GRUCell(nn.RNNCellBase):
|
||||
r"""
|
||||
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
|
||||
it computes the outputs and updates states.
|
||||
The formula for GRU used is as follows:
|
||||
.. math::
|
||||
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
|
||||
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
|
||||
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
|
||||
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
|
||||
y_{t} & = h_{t}
|
||||
|
||||
where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
|
||||
multiplication operator.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
weight_ih_attr=None,
|
||||
weight_hh_attr=None,
|
||||
bias_ih_attr=None,
|
||||
bias_hh_attr=None,
|
||||
name=None):
|
||||
super().__init__()
|
||||
std = 1.0 / math.sqrt(hidden_size)
|
||||
self.weight_hh = self.create_parameter(
|
||||
(3 * hidden_size, hidden_size),
|
||||
weight_hh_attr,
|
||||
default_initializer=I.Uniform(-std, std))
|
||||
self.bias_ih = None
|
||||
self.bias_hh = self.create_parameter(
|
||||
(3 * hidden_size, ),
|
||||
bias_hh_attr,
|
||||
is_bias=True,
|
||||
default_initializer=I.Uniform(-std, std))
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.input_size = input_size
|
||||
self._gate_activation = F.sigmoid
|
||||
self._activation = paddle.tanh
|
||||
|
||||
def forward(self, inputs, states=None):
|
||||
if states is None:
|
||||
states = self.get_initial_states(inputs, self.state_shape)
|
||||
|
||||
pre_hidden = states
|
||||
x_gates = inputs
|
||||
if self.bias_ih is not None:
|
||||
x_gates = x_gates + self.bias_ih
|
||||
h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
|
||||
if self.bias_hh is not None:
|
||||
h_gates = h_gates + self.bias_hh
|
||||
|
||||
x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
|
||||
h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
|
||||
|
||||
r = self._gate_activation(x_r + h_r)
|
||||
z = self._gate_activation(x_z + h_z)
|
||||
c = self._activation(x_c + r * h_c) # apply reset gate after mm
|
||||
h = (pre_hidden - c) * z + c
|
||||
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
|
||||
|
||||
return h, h
|
||||
|
||||
@property
|
||||
def state_shape(self):
|
||||
r"""
|
||||
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
|
||||
size would be automatically inserted into shape). The shape corresponds
|
||||
to the shape of :math:`h_{t-1}`.
|
||||
"""
|
||||
return (self.hidden_size, )
|
||||
|
||||
|
||||
class BiRNNWithBN(nn.Layer):
|
||||
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
|
||||
The batch normalization is only performed on input-state weights.
|
||||
|
||||
:param size: Dimension of RNN cells.
|
||||
:type size: int
|
||||
:param share_weights: Whether to share input-hidden weights between
|
||||
forward and backward directional RNNs.
|
||||
:type share_weights: bool
|
||||
:return: Bidirectional simple rnn layer.
|
||||
:rtype: Variable
|
||||
"""
|
||||
|
||||
def __init__(self, i_size: int, h_size: int, share_weights: bool):
|
||||
super().__init__()
|
||||
self.share_weights = share_weights
|
||||
if self.share_weights:
|
||||
#input-hidden weights shared between bi-directional rnn.
|
||||
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
|
||||
# batch norm is only performed on input-state projection
|
||||
self.fw_bn = nn.BatchNorm1D(
|
||||
h_size, bias_attr=None, data_format='NLC')
|
||||
self.bw_fc = self.fw_fc
|
||||
self.bw_bn = self.fw_bn
|
||||
else:
|
||||
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
|
||||
self.fw_bn = nn.BatchNorm1D(
|
||||
h_size, bias_attr=None, data_format='NLC')
|
||||
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
|
||||
self.bw_bn = nn.BatchNorm1D(
|
||||
h_size, bias_attr=None, data_format='NLC')
|
||||
|
||||
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
|
||||
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
|
||||
self.fw_rnn = nn.RNN(
|
||||
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
|
||||
self.bw_rnn = nn.RNN(
|
||||
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
||||
|
||||
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
|
||||
# x, shape [B, T, D]
|
||||
fw_x = self.fw_bn(self.fw_fc(x))
|
||||
bw_x = self.bw_bn(self.bw_fc(x))
|
||||
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
|
||||
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
|
||||
x = paddle.concat([fw_x, bw_x], axis=-1)
|
||||
return x, x_len
|
||||
|
||||
|
||||
class BiGRUWithBN(nn.Layer):
|
||||
"""Bidirectonal gru layer with sequence-wise batch normalization.
|
||||
The batch normalization is only performed on input-state weights.
|
||||
|
||||
:param name: Name of the layer.
|
||||
:type name: string
|
||||
:param input: Input layer.
|
||||
:type input: Variable
|
||||
:param size: Dimension of GRU cells.
|
||||
:type size: int
|
||||
:param act: Activation type.
|
||||
:type act: string
|
||||
:return: Bidirectional GRU layer.
|
||||
:rtype: Variable
|
||||
"""
|
||||
|
||||
def __init__(self, i_size: int, h_size: int):
|
||||
super().__init__()
|
||||
hidden_size = h_size * 3
|
||||
|
||||
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
|
||||
self.fw_bn = nn.BatchNorm1D(
|
||||
hidden_size, bias_attr=None, data_format='NLC')
|
||||
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
|
||||
self.bw_bn = nn.BatchNorm1D(
|
||||
hidden_size, bias_attr=None, data_format='NLC')
|
||||
|
||||
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
|
||||
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
|
||||
self.fw_rnn = nn.RNN(
|
||||
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
|
||||
self.bw_rnn = nn.RNN(
|
||||
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
||||
|
||||
def forward(self, x, x_len):
|
||||
# x, shape [B, T, D]
|
||||
fw_x = self.fw_bn(self.fw_fc(x))
|
||||
bw_x = self.bw_bn(self.bw_fc(x))
|
||||
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
|
||||
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
|
||||
x = paddle.concat([fw_x, bw_x], axis=-1)
|
||||
return x, x_len
|
||||
|
||||
|
||||
class RNNStack(nn.Layer):
|
||||
"""RNN group with stacked bidirectional simple RNN or GRU layers.
|
||||
|
||||
:param input: Input layer.
|
||||
:type input: Variable
|
||||
:param size: Dimension of RNN cells in each layer.
|
||||
:type size: int
|
||||
:param num_stacks: Number of stacked rnn layers.
|
||||
:type num_stacks: int
|
||||
:param use_gru: Use gru if set True. Use simple rnn if set False.
|
||||
:type use_gru: bool
|
||||
:param share_rnn_weights: Whether to share input-hidden weights between
|
||||
forward and backward directional RNNs.
|
||||
It is only available when use_gru=False.
|
||||
:type share_weights: bool
|
||||
:return: Output layer of the RNN group.
|
||||
:rtype: Variable
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
i_size: int,
|
||||
h_size: int,
|
||||
num_stacks: int,
|
||||
use_gru: bool,
|
||||
share_rnn_weights: bool):
|
||||
super().__init__()
|
||||
rnn_stacks = []
|
||||
for i in range(num_stacks):
|
||||
if use_gru:
|
||||
#default:GRU using tanh
|
||||
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
|
||||
else:
|
||||
rnn_stacks.append(
|
||||
BiRNNWithBN(
|
||||
i_size=i_size,
|
||||
h_size=h_size,
|
||||
share_weights=share_rnn_weights))
|
||||
i_size = h_size * 2
|
||||
|
||||
self.rnn_stacks = nn.ModuleList(rnn_stacks)
|
||||
|
||||
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
|
||||
"""
|
||||
x: shape [B, T, D]
|
||||
x_len: shpae [B]
|
||||
"""
|
||||
for i, rnn in enumerate(self.rnn_stacks):
|
||||
x, x_len = rnn(x, x_len)
|
||||
masks = make_non_pad_mask(x_len) #[B, T]
|
||||
masks = masks.unsqueeze(-1) # [B, T, 1]
|
||||
# TODO(Hui Zhang): not support bool multiply
|
||||
masks = masks.astype(x.dtype)
|
||||
x = x.multiply(masks)
|
||||
return x, x_len
|
Loading…
Reference in new issue