reconstruct the rnn state, from list to tensor

pull/735/head
huangyuxin 4 years ago
parent 8f062cad6b
commit 722c55e4c5

@ -376,16 +376,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
dtype='float32'), # audio, [B,T,D] dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None], paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B] dtype='int64'), # audio_length, [B]
[
(
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, None, None], dtype='float32' shape=[None, None, None], dtype='float32'),
), #num_rnn_layers * num_dirctions, rnn_size
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, None, None], dtype='float32' shape=[None, None, None], dtype='float32')
) #num_rnn_layers * num_dirctions, rnn_size
) for i in range(self.config.model.num_rnn_layers)
]
]) ])
else: else:
raise Exception("wrong model type") raise Exception("wrong model type")

@ -48,6 +48,7 @@ class CRNNEncoder(nn.Layer):
self.num_fc_layers = num_fc_layers self.num_fc_layers = num_fc_layers
self.rnn_direction = rnn_direction self.rnn_direction = rnn_direction
self.fc_layers_size_list = fc_layers_size_list self.fc_layers_size_list = fc_layers_size_list
self.use_gru = use_gru
self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
i_size = self.conv.output_dim i_size = self.conv.output_dim
@ -96,7 +97,8 @@ class CRNNEncoder(nn.Layer):
Returns: Returns:
x (Tensor): encoder outputs, [B, T_output, D] x (Tensor): encoder outputs, [B, T_output, D]
x_lens (Tensor): encoder length, [B] x_lens (Tensor): encoder length, [B]
final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
""" """
# [B, T, D] # [B, T, D]
# convolution group # convolution group
@ -118,32 +120,79 @@ class CRNNEncoder(nn.Layer):
for i in range(self.num_fc_layers): for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x) x = self.fc_layers_list[i](x)
x = F.relu(x) x = F.relu(x)
return x, x_lens, final_state_list
def forward_chunk(self, x, x_lens, init_state_list): if self.use_gru == True:
final_state_h_box = paddle.concat(final_state_list, axis=0)
final_state_c_box = paddle.zeros_like(final_state_h_box)
else:
final_state_h_list = [
final_state_list[i][0] for i in range(self.num_rnn_layers)
]
final_state_c_list = [
final_state_list[i][1] for i in range(self.num_rnn_layers)
]
final_state_h_box = paddle.concat(final_state_h_list, axis=0)
final_state_c_box = paddle.concat(final_state_c_list, axis=0)
return x, x_lens, final_state_h_box, final_state_c_box
def forward_chunk(self, x, x_lens, init_state_h_box, init_state_c_box):
"""Compute Encoder outputs """Compute Encoder outputs
Args: Args:
x (Tensor): [B, feature_chunk_size, D] x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B] x_lens (Tensor): [B]
init_state_list (list of Tensors): [ num_directions, batch_size, hidden_size] * num_rnn_layers init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
Returns: Returns:
x (Tensor): encoder outputs, [B, chunk_size, D] x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B] x_lens (Tensor): encoder length, [B]
chunk_final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
""" """
if init_state_h_box is not None:
init_state_list = None
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_c_list = paddle.split(
init_state_c_box, self.num_rnn_layers, axis=0)
if self.use_gru == True:
init_state_list = init_state_h_list
else:
init_state_list = [(init_state_h_list[i], init_state_c_list[i])
for i in range(self.num_rnn_layers)]
else:
init_state_list = [None] * self.num_rnn_layers
x, x_lens = self.conv(x, x_lens) x, x_lens = self.conv(x, x_lens)
chunk_final_state_list = [] final_chunk_state_list = []
for i in range(0, self.num_rnn_layers): for i in range(0, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state_list[i], x, final_state = self.rnn[i](x, init_state_list[i],
x_lens) #[B, T, D] x_lens) #[B, T, D]
chunk_final_state_list.append(final_state) final_chunk_state_list.append(final_state)
x = self.layernorm_list[i](x) x = self.layernorm_list[i](x)
for i in range(self.num_fc_layers): for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x) x = self.fc_layers_list[i](x)
x = F.relu(x) x = F.relu(x)
return x, x_lens, chunk_final_state_list
if self.use_gru == True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = paddle.zeros_like(final_chunk_state_h_box)
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
]
final_chunk_state_c_list = [
final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
]
final_chunk_state_h_box = paddle.concat(
final_chunk_state_h_list, axis=0)
final_chunk_state_c_box = paddle.concat(
final_chunk_state_c_list, axis=0)
return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box
def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
"""Compute Encoder outputs """Compute Encoder outputs
@ -153,9 +202,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B] x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder decoder_chunk_size: The chunk size of decoder
Returns: Returns:
eouts_chunk_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_chunk_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_chunk_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
""" """
subsampling_rate = self.conv.subsampling_rate subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length receptive_field_length = self.conv.receptive_field_length
@ -173,8 +223,10 @@ class CRNNEncoder(nn.Layer):
padded_x = paddle.concat([x, padding], axis=1) padded_x = paddle.concat([x, padding], axis=1)
num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
chunk_state_list = [None] * self.num_rnn_layers chunk_state_h_box = None
final_chunk_state_list = None chunk_state_c_box = None
final_state_h_box = None
final_state_c_box = None
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_stride start = i * chunk_stride
end = start + chunk_size end = start + chunk_size
@ -190,13 +242,14 @@ class CRNNEncoder(nn.Layer):
x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
x_len_left, x_chunk_len_tmp) x_len_left, x_chunk_len_tmp)
eouts_chunk, eouts_chunk_lens, chunk_state_list = self.forward_chunk( eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward_chunk(
x_chunk, x_chunk_lens, chunk_state_list) x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)
eouts_chunk_list.append(eouts_chunk) eouts_chunk_list.append(eouts_chunk)
eouts_chunk_lens_list.append(eouts_chunk_lens) eouts_chunk_lens_list.append(eouts_chunk_lens)
final_chunk_state_list = chunk_state_list final_state_h_box = chunk_state_h_box
return eouts_chunk_list, eouts_chunk_lens_list, final_chunk_state_list final_state_c_box = chunk_state_c_box
return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box
class DeepSpeech2ModelOnline(nn.Layer): class DeepSpeech2ModelOnline(nn.Layer):
@ -283,7 +336,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
Returns: Returns:
loss (Tenosr): [1] loss (Tenosr): [1]
""" """
eouts, eouts_len, final_state_list = self.encoder(audio, audio_len) eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len)
loss = self.decoder(eouts, eouts_len, text, text_len) loss = self.decoder(eouts, eouts_len, text, text_len)
return loss return loss
@ -300,7 +354,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
vocab_list=vocab_list, vocab_list=vocab_list,
decoding_method=decoding_method) decoding_method=decoding_method)
eouts, eouts_len, final_state_list = self.encoder(audio, audio_len) eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len)
probs = self.decoder.softmax(eouts) probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs( return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method, probs.numpy(), eouts_len, vocab_list, decoding_method,
@ -363,8 +418,9 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
fc_layers_size_list=fc_layers_size_list, fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru) use_gru=use_gru)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_list): def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk( chunk_state_c_box):
audio_chunk, audio_chunk_lens, chunk_state_list) eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder.forward_chunk(
audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
probs_chunk = self.decoder.softmax(eouts_chunk) probs_chunk = self.decoder.softmax(eouts_chunk)
return probs_chunk, eouts_chunk_lens, final_state_list return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box

@ -119,9 +119,9 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
paddle.device.set_device("cpu") paddle.device.set_device("cpu")
de_ch_size = 9 de_ch_size = 9
eouts, eouts_lens, final_state_list = model.encoder(self.audio, eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
self.audio_len) self.audio, self.audio_len)
eouts_by_chk_list, eouts_lens_by_chk_list, final_state_list_by_chk = model.encoder.forward_chunk_by_chunk( eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size) self.audio, self.audio_len, de_ch_size)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
@ -134,6 +134,10 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual( self.assertEqual(
paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0) paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0)
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
""" """
print ("conv_x", conv_x) print ("conv_x", conv_x)
print ("conv_x_by_chk", conv_x_by_chk) print ("conv_x_by_chk", conv_x_by_chk)

Loading…
Cancel
Save