|
|
|
@ -29,13 +29,13 @@ __all__ = ['RNNStack']
|
|
|
|
|
|
|
|
|
|
class RNNCell(nn.RNNCellBase):
|
|
|
|
|
r"""
|
|
|
|
|
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
|
|
|
|
|
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
|
|
|
|
|
computes the outputs and updates states.
|
|
|
|
|
The formula used is as follows:
|
|
|
|
|
.. math::
|
|
|
|
|
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
|
|
|
|
|
y_{t} & = h_{t}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
where :math:`act` is for :attr:`activation`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase):
|
|
|
|
|
|
|
|
|
|
class GRUCell(nn.RNNCellBase):
|
|
|
|
|
r"""
|
|
|
|
|
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
|
|
|
|
|
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
|
|
|
|
|
it computes the outputs and updates states.
|
|
|
|
|
The formula for GRU used is as follows:
|
|
|
|
|
.. math::
|
|
|
|
@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase):
|
|
|
|
|
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
|
|
|
|
|
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
|
|
|
|
|
y_{t} & = h_{t}
|
|
|
|
|
|
|
|
|
|
where :math:`\sigma` is the sigmoid function, and * is the elementwise
|
|
|
|
|
|
|
|
|
|
where :math:`\sigma` is the sigmoid function, and * is the elementwise
|
|
|
|
|
multiplication operator.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer):
|
|
|
|
|
self.fw_rnn = nn.RNN(
|
|
|
|
|
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
|
|
|
|
|
self.bw_rnn = nn.RNN(
|
|
|
|
|
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
|
|
|
|
|
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
|
|
|
|
|
# x, shape [B, T, D]
|
|
|
|
@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer):
|
|
|
|
|
self.fw_rnn = nn.RNN(
|
|
|
|
|
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
|
|
|
|
|
self.bw_rnn = nn.RNN(
|
|
|
|
|
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
|
|
|
|
|
def forward(self, x, x_len):
|
|
|
|
|
# x, shape [B, T, D]
|
|
|
|
|