diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 32127022e..bb0bd43a8 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -55,7 +55,7 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) - if config.training.seed != None: + if config.training.seed is not None: os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index a2bbee5e7..1bd4c722f 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains DeepSpeech2 and DeepSpeech2Online model.""" -import os import random import time from collections import defaultdict @@ -64,7 +63,7 @@ class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - if config.training.seed != None: + if config.training.seed is not None: self.set_seed(config.training.seed) def set_seed(self, seed): diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py index ebd91faa8..9dd0041dd 100644 --- a/deepspeech/exps/u2/bin/train.py +++ b/deepspeech/exps/u2/bin/train.py @@ -52,10 +52,7 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) - if config.training.seed != None: - os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') - main(config, args) # Setting for profiling pr = cProfile.Profile() pr.runcall(main, config, args) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index b248c5a6e..d661f078d 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -55,7 +55,7 @@ class U2Trainer(Trainer): log_interval=100, # steps accum_grad=1, # accum grad by # steps global_grad_clip=5.0, # the global norm clip - seed=1024, )) + )) default.optim = 'adam' default.optim_conf = CfgNode( dict( @@ -75,12 +75,6 @@ class U2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - if config.training.seed != None: - self.set_seed(config.training.seed) - - def set_seed(self, seed): - np.random.seed(seed) - paddle.seed(seed) def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training diff --git a/deepspeech/io/sampler.py b/deepspeech/io/sampler.py index 3b2ef757d..763a3781e 100644 --- a/deepspeech/io/sampler.py +++ b/deepspeech/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py index 1af69e28c..4a6fd5abd 100644 --- a/deepspeech/models/ds2_online/conv.py +++ b/deepspeech/models/ds2_online/conv.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import paddle -from paddle import nn -from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.subsampling import Conv2dSubsampling4 diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index e130968b4..d092b154b 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline'] +__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] class CRNNEncoder(nn.Layer): @@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer): rnn_input_size = i_size else: rnn_input_size = layernorm_size - if use_gru == True: + if use_gru is True: self.rnn.append( nn.GRU( input_size=rnn_input_size, @@ -113,7 +113,7 @@ class CRNNEncoder(nn.Layer): if init_state_h_box is not None: init_state_list = None - if self.use_gru == True: + if self.use_gru is True: init_state_h_list = paddle.split( init_state_h_box, self.num_rnn_layers, axis=0) init_state_list = init_state_h_list @@ -139,7 +139,7 @@ class CRNNEncoder(nn.Layer): x = self.fc_layers_list[i](x) x = F.relu(x) - if self.use_gru == True: + if self.use_gru is True: final_chunk_state_h_box = paddle.concat( final_chunk_state_list, axis=0) final_chunk_state_c_box = init_state_c_box diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml index f78148e3f..fdc3a5365 100644 --- a/examples/aishell/s0/conf/deepspeech2_online.yaml +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -46,7 +46,7 @@ model: training: n_epoch: 50 lr: 2e-3 - lr_decay: 0.9 # 0.83 + lr_decay: 0.91 # 0.83 weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py index 0e31b85f8..6264070be 100644 --- a/tests/deepspeech2_online_model_test.py +++ b/tests/deepspeech2_online_model_test.py @@ -143,10 +143,10 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) decode_max_len = eouts.shape[1] eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] - self.assertEqual(paddle.allclose(eouts_by_chk, eouts, atol=1e-5), True) + self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual( paddle.allclose(final_state_h_box, final_state_h_box_chk), True) - if use_gru == False: + if use_gru is False: self.assertEqual( paddle.allclose(final_state_c_box, final_state_c_box_chk), True) @@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual( paddle.allclose(final_state_h_box, final_state_h_box_chk), True) - if use_gru == False: + if use_gru is False: self.assertEqual( paddle.allclose(final_state_c_box, final_state_c_box_chk), True)