diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 4acfad86b..51ef1de47 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -376,16 +376,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): dtype='float32'), # audio, [B,T,D] paddle.static.InputSpec(shape=[None], dtype='int64'), # audio_length, [B] - [ - ( - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32' - ), #num_rnn_layers * num_dirctions, rnn_size - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32' - ) #num_rnn_layers * num_dirctions, rnn_size - ) for i in range(self.config.model.num_rnn_layers) - ] + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32') ]) else: raise Exception("wrong model type") diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index bed9c41d3..b42ac8ec1 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -48,6 +48,7 @@ class CRNNEncoder(nn.Layer): self.num_fc_layers = num_fc_layers self.rnn_direction = rnn_direction self.fc_layers_size_list = fc_layers_size_list + self.use_gru = use_gru self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) i_size = self.conv.output_dim @@ -96,7 +97,8 @@ class CRNNEncoder(nn.Layer): Returns: x (Tensor): encoder outputs, [B, T_output, D] x_lens (Tensor): encoder length, [B] - final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers + final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size """ # [B, T, D] # convolution group @@ -118,32 +120,79 @@ class CRNNEncoder(nn.Layer): for i in range(self.num_fc_layers): x = self.fc_layers_list[i](x) x = F.relu(x) - return x, x_lens, final_state_list - def forward_chunk(self, x, x_lens, init_state_list): + if self.use_gru == True: + final_state_h_box = paddle.concat(final_state_list, axis=0) + final_state_c_box = paddle.zeros_like(final_state_h_box) + else: + final_state_h_list = [ + final_state_list[i][0] for i in range(self.num_rnn_layers) + ] + final_state_c_list = [ + final_state_list[i][1] for i in range(self.num_rnn_layers) + ] + final_state_h_box = paddle.concat(final_state_h_list, axis=0) + final_state_c_box = paddle.concat(final_state_c_list, axis=0) + + return x, x_lens, final_state_h_box, final_state_c_box + + def forward_chunk(self, x, x_lens, init_state_h_box, init_state_c_box): """Compute Encoder outputs Args: - x (Tensor): [B, feature_chunk_size, D] + x (Tensor): [B, feature_size, D] x_lens (Tensor): [B] - init_state_list (list of Tensors): [ num_directions, batch_size, hidden_size] * num_rnn_layers + init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size Returns: - x (Tensor): encoder outputs, [B, chunk_size, D] + x (Tensor): encoder outputs, [B, size, D] x_lens (Tensor): encoder length, [B] - chunk_final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers + final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size """ + if init_state_h_box is not None: + init_state_list = None + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_c_list = paddle.split( + init_state_c_box, self.num_rnn_layers, axis=0) + if self.use_gru == True: + init_state_list = init_state_h_list + else: + init_state_list = [(init_state_h_list[i], init_state_c_list[i]) + for i in range(self.num_rnn_layers)] + else: + init_state_list = [None] * self.num_rnn_layers + x, x_lens = self.conv(x, x_lens) - chunk_final_state_list = [] + final_chunk_state_list = [] for i in range(0, self.num_rnn_layers): x, final_state = self.rnn[i](x, init_state_list[i], x_lens) #[B, T, D] - chunk_final_state_list.append(final_state) + final_chunk_state_list.append(final_state) x = self.layernorm_list[i](x) for i in range(self.num_fc_layers): x = self.fc_layers_list[i](x) x = F.relu(x) - return x, x_lens, chunk_final_state_list + + if self.use_gru == True: + final_chunk_state_h_box = paddle.concat( + final_chunk_state_list, axis=0) + final_chunk_state_c_box = paddle.zeros_like(final_chunk_state_h_box) + else: + final_chunk_state_h_list = [ + final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) + ] + final_chunk_state_c_list = [ + final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) + ] + final_chunk_state_h_box = paddle.concat( + final_chunk_state_h_list, axis=0) + final_chunk_state_c_box = paddle.concat( + final_chunk_state_c_list, axis=0) + + return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): """Compute Encoder outputs @@ -153,9 +202,10 @@ class CRNNEncoder(nn.Layer): x_lens (Tensor): [B] decoder_chunk_size: The chunk size of decoder Returns: - eouts_chunk_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks - eouts_chunk_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks - final_chunk_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers + eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size """ subsampling_rate = self.conv.subsampling_rate receptive_field_length = self.conv.receptive_field_length @@ -173,8 +223,10 @@ class CRNNEncoder(nn.Layer): padded_x = paddle.concat([x, padding], axis=1) num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 num_chunk = int(num_chunk) - chunk_state_list = [None] * self.num_rnn_layers - final_chunk_state_list = None + chunk_state_h_box = None + chunk_state_c_box = None + final_state_h_box = None + final_state_c_box = None for i in range(0, num_chunk): start = i * chunk_stride end = start + chunk_size @@ -190,13 +242,14 @@ class CRNNEncoder(nn.Layer): x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, x_len_left, x_chunk_len_tmp) - eouts_chunk, eouts_chunk_lens, chunk_state_list = self.forward_chunk( - x_chunk, x_chunk_lens, chunk_state_list) + eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward_chunk( + x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) eouts_chunk_list.append(eouts_chunk) eouts_chunk_lens_list.append(eouts_chunk_lens) - final_chunk_state_list = chunk_state_list - return eouts_chunk_list, eouts_chunk_lens_list, final_chunk_state_list + final_state_h_box = chunk_state_h_box + final_state_c_box = chunk_state_c_box + return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box class DeepSpeech2ModelOnline(nn.Layer): @@ -283,7 +336,8 @@ class DeepSpeech2ModelOnline(nn.Layer): Returns: loss (Tenosr): [1] """ - eouts, eouts_len, final_state_list = self.encoder(audio, audio_len) + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len) loss = self.decoder(eouts, eouts_len, text, text_len) return loss @@ -300,7 +354,8 @@ class DeepSpeech2ModelOnline(nn.Layer): vocab_list=vocab_list, decoding_method=decoding_method) - eouts, eouts_len, final_state_list = self.encoder(audio, audio_len) + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len) probs = self.decoder.softmax(eouts) return self.decoder.decode_probs( probs.numpy(), eouts_len, vocab_list, decoding_method, @@ -363,8 +418,9 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): fc_layers_size_list=fc_layers_size_list, use_gru=use_gru) - def forward(self, audio_chunk, audio_chunk_lens, chunk_state_list): - eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk( - audio_chunk, audio_chunk_lens, chunk_state_list) + def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, + chunk_state_c_box): + eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder.forward_chunk( + audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) probs_chunk = self.decoder.softmax(eouts_chunk) - return probs_chunk, eouts_chunk_lens, final_state_list + return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py index ce235cd6d..fd1dfc4b5 100644 --- a/tests/deepspeech2_online_model_test.py +++ b/tests/deepspeech2_online_model_test.py @@ -119,9 +119,9 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): paddle.device.set_device("cpu") de_ch_size = 9 - eouts, eouts_lens, final_state_list = model.encoder(self.audio, - self.audio_len) - eouts_by_chk_list, eouts_lens_by_chk_list, final_state_list_by_chk = model.encoder.forward_chunk_by_chunk( + eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( + self.audio, self.audio_len) + eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk( self.audio, self.audio_len, de_ch_size) eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) @@ -134,6 +134,10 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): self.assertEqual( paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual( + paddle.allclose(final_state_h_box, final_state_h_box_chk), True) + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) """ print ("conv_x", conv_x) print ("conv_x_by_chk", conv_x_by_chk)