diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 69ff043a0..bb0bd43a8 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Trainer for DeepSpeech2 model.""" +import os + from paddle import distributed as dist from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -53,5 +55,7 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) + if config.training.seed is not None: + os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 65c905a13..1bd4c722f 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains DeepSpeech2 and DeepSpeech2Online model.""" +import random import time from collections import defaultdict from pathlib import Path @@ -53,6 +54,7 @@ class DeepSpeech2Trainer(Trainer): weight_decay=1e-6, # the coeff of weight decay global_grad_clip=5.0, # the global norm clip n_epoch=50, # train epochs + seed=1024, #train seed )) if config is not None: @@ -61,6 +63,13 @@ class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) + if config.training.seed is not None: + self.set_seed(config.training.seed) + + def set_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def train_batch(self, batch_index, batch_data, msg): start = time.time() diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py index 1af69e28c..4a6fd5abd 100644 --- a/deepspeech/models/ds2_online/conv.py +++ b/deepspeech/models/ds2_online/conv.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import paddle -from paddle import nn -from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.subsampling import Conv2dSubsampling4 diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 3083e4b2a..d092b154b 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline'] +__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] class CRNNEncoder(nn.Layer): @@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer): rnn_input_size = i_size else: rnn_input_size = layernorm_size - if use_gru == True: + if use_gru is True: self.rnn.append( nn.GRU( input_size=rnn_input_size, @@ -102,18 +102,18 @@ class CRNNEncoder(nn.Layer): Args: x (Tensor): [B, feature_size, D] x_lens (Tensor): [B] - init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - Returns: + init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + Return: x (Tensor): encoder outputs, [B, size, D] x_lens (Tensor): encoder length, [B] - final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] """ if init_state_h_box is not None: init_state_list = None - if self.use_gru == True: + if self.use_gru is True: init_state_h_list = paddle.split( init_state_h_box, self.num_rnn_layers, axis=0) init_state_list = init_state_h_list @@ -139,10 +139,10 @@ class CRNNEncoder(nn.Layer): x = self.fc_layers_list[i](x) x = F.relu(x) - if self.use_gru == True: + if self.use_gru is True: final_chunk_state_h_box = paddle.concat( final_chunk_state_list, axis=0) - final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box) + final_chunk_state_c_box = init_state_c_box else: final_chunk_state_h_list = [ final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) @@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer): x_lens (Tensor): [B] decoder_chunk_size: The chunk size of decoder Returns: - eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks - eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks - final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] """ subsampling_rate = self.conv.subsampling_rate receptive_field_length = self.conv.receptive_field_length @@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer): class DeepSpeech2ModelOnline(nn.Layer): """The DeepSpeech2 network structure for online. - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable + :param audio: Audio spectrogram data layer. + :type audio: Variable + :param text: Transcription text data layer. + :type text: Variable :param audio_len: Valid sequence length data layer. :type audio_len: Variable + :param feat_size: feature size for audio. + :type feat_size: int :param dict_size: Dictionary size for tokenized transcription. :type dict_size: int :param num_conv_layers: Number of stacking convolution layers. diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py index 87f048870..6264070be 100644 --- a/tests/deepspeech2_online_model_test.py +++ b/tests/deepspeech2_online_model_test.py @@ -146,7 +146,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual( paddle.allclose(final_state_h_box, final_state_h_box_chk), True) - if use_gru == False: + if use_gru is False: self.assertEqual( paddle.allclose(final_state_c_box, final_state_c_box_chk), True) @@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual( paddle.allclose(final_state_h_box, final_state_h_box_chk), True) - if use_gru == False: + if use_gru is False: self.assertEqual( paddle.allclose(final_state_c_box, final_state_c_box_chk), True)