From 2d3b2aed05c36cc173dd8370dab986b1f0cf6513 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 23 Aug 2021 14:06:10 +0000 Subject: [PATCH] add seed in argparse --- deepspeech/exps/deepspeech2/bin/train.py | 4 --- deepspeech/exps/deepspeech2/model.py | 9 ------- deepspeech/training/cli.py | 33 +++++++++++++----------- deepspeech/training/trainer.py | 9 +++++++ examples/aishell/s0/local/train.sh | 12 ++++++++- examples/aishell/s1/local/train.sh | 12 ++++++++- examples/callcenter/s1/local/train.sh | 12 ++++++++- examples/librispeech/s0/local/train.sh | 12 ++++++++- examples/librispeech/s1/local/train.sh | 12 ++++++++- examples/librispeech/s2/local/train.sh | 12 ++++++++- examples/ted_en_zh/t0/local/train.sh | 12 ++++++++- examples/timit/s1/local/train.sh | 12 ++++++++- examples/tiny/s0/local/train.sh | 12 ++++++++- examples/tiny/s1/local/train.sh | 12 ++++++++- 14 files changed, 137 insertions(+), 38 deletions(-) diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index bb0bd43a..69ff043a 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Trainer for DeepSpeech2 model.""" -import os - from paddle import distributed as dist from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -55,7 +53,5 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) - if config.training.seed is not None: - os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 1bd4c722..65c905a1 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Contains DeepSpeech2 and DeepSpeech2Online model.""" -import random import time from collections import defaultdict from pathlib import Path @@ -54,7 +53,6 @@ class DeepSpeech2Trainer(Trainer): weight_decay=1e-6, # the coeff of weight decay global_grad_clip=5.0, # the global norm clip n_epoch=50, # train epochs - seed=1024, #train seed )) if config is not None: @@ -63,13 +61,6 @@ class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - if config.training.seed is not None: - self.set_seed(config.training.seed) - - def set_seed(self, seed): - np.random.seed(seed) - random.seed(seed) - paddle.seed(seed) def train_batch(self, batch_index, batch_data, msg): start = time.time() diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index b83d989d..d3b85355 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -16,23 +16,23 @@ import argparse def default_argument_parser(): r"""A simple yet genral argument parser for experiments with parakeet. - - This is used in examples with parakeet. And it is intended to be used by - other experiments with parakeet. It requires a minimal set of command line + + This is used in examples with parakeet. And it is intended to be used by + other experiments with parakeet. It requires a minimal set of command line arguments to start a training script. - - The ``--config`` and ``--opts`` are used for overwrite the deault + + The ``--config`` and ``--opts`` are used for overwrite the deault configuration. - - The ``--data`` and ``--output`` specifies the data path and output path. - Resuming training from existing progress at the output directory is the + + The ``--data`` and ``--output`` specifies the data path and output path. + Resuming training from existing progress at the output directory is the intended default behavior. - + The ``--checkpoint_path`` specifies the checkpoint to load from. 
- + The ``--device`` and ``--nprocs`` specifies how to run the training. - - + + See Also -------- parakeet.training.experiment @@ -53,10 +53,10 @@ def default_argument_parser(): # load from saved checkpoint parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") - # save jit model to + # save jit model to parser.add_argument("--export_path", type=str, help="path of the jit model to save") - # save asr result to + # save asr result to parser.add_argument("--result_file", type=str, help="path of save the asr result") # running @@ -65,10 +65,13 @@ def default_argument_parser(): parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") # overwrite extra config and default config - # parser.add_argument("--opts", nargs=argparse.REMAINDER, + # parser.add_argument("--opts", nargs=argparse.REMAINDER, # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") parser.add_argument("--opts", type=str, default=[], nargs='+', help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + + parser.add_argument("--seed", type=int, default=None, + help="seed to use for paddle, np and random. The default value is None") # yapd: enable return parser diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 209e2240..2ab7eac0 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import random import time from pathlib import Path +import numpy as np import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter @@ -93,6 +95,13 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 + if args.seed is not None: + self.set_seed(args.seed) + + def set_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def setup(self): """Setup the experiment. diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index c6a63180..d42e51fa 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -19,12 +19,22 @@ fi mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} +--model_type ${model_type} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/aishell/s1/local/train.sh +++ b/examples/aishell/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/callcenter/s1/local/train.sh b/examples/callcenter/s1/local/train.sh index f750568a..928c6492 100755 --- a/examples/callcenter/s1/local/train.sh +++ b/examples/callcenter/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." 
mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index 039b9cea..dcd21df3 100755 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -20,12 +20,22 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} +--model_type ${model_type} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s1/local/train.sh b/examples/librispeech/s1/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/librispeech/s1/local/train.sh +++ b/examples/librispeech/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s2/local/train.sh b/examples/librispeech/s2/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/librispeech/s2/local/train.sh +++ b/examples/librispeech/s2/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." 
mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/ted_en_zh/t0/local/train.sh b/examples/ted_en_zh/t0/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/ted_en_zh/t0/local/train.sh +++ b/examples/ted_en_zh/t0/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/timit/s1/local/train.sh b/examples/timit/s1/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/timit/s1/local/train.sh +++ b/examples/timit/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" 
diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index c6a63180..d42e51fa 100755 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -19,12 +19,22 @@ fi mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} +--model_type ${model_type} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh index f6bd2c98..2fb3a95a 100755 --- a/examples/tiny/s1/local/train.sh +++ b/examples/tiny/s1/local/train.sh @@ -18,11 +18,21 @@ fi mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +status=$? # save the trainer exit code before the unset below resets $? +if [ ${seed} ]; then unset FLAGS_cudnn_deterministic; fi +(exit ${status}) if [ $? -ne 0 ]; then echo "Failed in training!"