From 347af638e28301cfefe1e0608ba28ebf02bbb5ad Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Wed, 11 May 2022 09:43:40 +0800
Subject: [PATCH] changet vector train.py local_rank to rank, test=doc

---
 paddlespeech/vector/exps/ecapa_tdnn/train.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
index aad148a98..bf014045d 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -54,7 +54,7 @@ def main(args, config):
     # stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
-    local_rank = paddle.distributed.get_rank()
+    rank = paddle.distributed.get_rank()
 
     # set the random seed, it is the necessary measures for multiprocess training
     seed_everything(config.seed)
@@ -112,10 +112,10 @@ def main(args, config):
             state_dict = paddle.load(
                 os.path.join(args.load_checkpoint, 'model.pdopt'))
             optimizer.set_state_dict(state_dict)
-            if local_rank == 0:
+            if rank == 0:
                 logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
         except FileExistsError:
-            if local_rank == 0:
+            if rank == 0:
                 logger.info('Train from scratch.')
 
         try:
@@ -219,7 +219,7 @@ def main(args, config):
             timer.count()  # step plus one in timer
 
             # stage 9-10: print the log information only on 0-rank per log-freq batchs
-            if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
+            if (batch_idx + 1) % config.log_interval == 0 and rank == 0:
                 lr = optimizer.get_lr()
                 avg_loss /= config.log_interval
                 avg_acc = num_corrects / num_samples
@@ -250,7 +250,7 @@
 
             # stage 9-11: save the model parameters only on 0-rank per save-freq batchs
             if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
-                if local_rank != 0:
+                if rank != 0:
                     paddle.distributed.barrier(
                     )  # Wait for valid step in main process
                     continue  # Resume trainning on other process
@@ -317,7 +317,7 @@ def main(args, config):
                    paddle.distributed.barrier()  # Main process
 
    # stage 10: create the final trained model.pdparams with soft link
-    if local_rank == 0:
+    if rank == 0:
        final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
        logger.info(f"we will create the final model: {final_model}")
        if os.path.islink(final_model):