|
|
|
@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer):
|
|
|
|
|
self.fw_rnn = nn.RNN(
|
|
|
|
|
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
|
|
|
|
|
self.bw_rnn = nn.RNN(
|
|
|
|
|
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
|
|
|
|
|
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
|
|
|
|
|
# x, shape [B, T, D]
|
|
|
|
@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer):
|
|
|
|
|
self.fw_rnn = nn.RNN(
|
|
|
|
|
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
|
|
|
|
|
self.bw_rnn = nn.RNN(
|
|
|
|
|
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
|
|
|
|
|
|
|
|
|
|
def forward(self, x, x_len):
|
|
|
|
|
# x, shape [B, T, D]
|
|
|
|
|