change vector train.py local_rank to rank, test=doc

pull/1884/head
xiongxinlei 2 years ago
parent 597d601dec
commit 347af638e2

@@ -54,7 +54,7 @@ def main(args, config):
     # stage1: we must call the paddle.distributed.init_parallel_env() api at the beginning
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
-    local_rank = paddle.distributed.get_rank()
+    rank = paddle.distributed.get_rank()
     # set the random seed, a necessary measure for multiprocess training
     seed_everything(config.seed)
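The rename matches what the API actually returns: paddle.distributed.get_rank() gives the global rank across all processes, not a per-node local rank, so `rank` is the accurate name. A minimal sketch of the setup pattern this hunk touches (standalone; everything other than the paddle.distributed calls is illustrative):

import paddle
import paddle.distributed as dist

def main():
    # Must run first in every spawned process, before any collective op.
    dist.init_parallel_env()
    nranks = dist.get_world_size()  # total number of training processes
    rank = dist.get_rank()          # global rank in [0, nranks)
    if rank == 0:                   # gate side effects to a single process
        print(f"training with {nranks} processes")

if __name__ == '__main__':
    dist.spawn(main)  # one process per visible device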
@@ -112,10 +112,10 @@ def main(args, config):
         state_dict = paddle.load(
             os.path.join(args.load_checkpoint, 'model.pdopt'))
         optimizer.set_state_dict(state_dict)
-        if local_rank == 0:
+        if rank == 0:
             logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
     except FileExistsError:
-        if local_rank == 0:
+        if rank == 0:
             logger.info('Train from scratch.')
     try:
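For context, the resume logic being edited follows this shape. A hedged sketch, with the try/except swapped for an explicit existence check and with argument names that are illustrative, not taken from the file:

import os
import paddle

def try_resume(checkpoint_dir, model, optimizer, rank, logger):
    params_path = os.path.join(checkpoint_dir, 'model.pdparams')
    opt_path = os.path.join(checkpoint_dir, 'model.pdopt')
    if os.path.isfile(params_path) and os.path.isfile(opt_path):
        # Restore both model weights and optimizer state from a prior run.
        model.set_state_dict(paddle.load(params_path))
        optimizer.set_state_dict(paddle.load(opt_path))
        if rank == 0:  # log once, from the global rank-0 process only
            logger.info(f'Checkpoint loaded from {checkpoint_dir}')
    elif rank == 0:
        logger.info('Train from scratch.')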
@@ -219,7 +219,7 @@ def main(args, config):
             timer.count()  # step plus one in timer
             # stage 9-10: print the log information only on rank 0 every log-freq batches
-            if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
+            if (batch_idx + 1) % config.log_interval == 0 and rank == 0:
                 lr = optimizer.get_lr()
                 avg_loss /= config.log_interval
                 avg_acc = num_corrects / num_samples
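The gating condition keeps per-interval logging on a single process while the window stats are accumulated on every rank. A sketch of the bookkeeping around it (variable names mirror the diff; the log line and the reset at the end are illustrative):

if (batch_idx + 1) % config.log_interval == 0 and rank == 0:
    lr = optimizer.get_lr()
    avg_loss /= config.log_interval       # mean loss over the window
    avg_acc = num_corrects / num_samples  # accuracy over the window
    logger.info(f'lr={lr:.8f} loss={avg_loss:.4f} acc={avg_acc:.4f}')
    avg_loss, num_corrects, num_samples = 0.0, 0, 0  # reset the window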
@@ -250,7 +250,7 @@ def main(args, config):
             # stage 9-11: save the model parameters only on rank 0 every save-freq batches
             if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
-                if local_rank != 0:
+                if rank != 0:
                     paddle.distributed.barrier()  # wait for the valid step in the main process
                     continue  # resume training on the other processes
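This barrier pairs with the one on the main process in the next hunk: non-zero ranks park here while rank 0 runs validation and checkpointing, then both sides release together. A condensed sketch of the pattern, with the rank-0 work elided:

import paddle.distributed as dist

if rank != 0:
    dist.barrier()  # block until rank 0 finishes its valid/save step
    # worker ranks then continue to the next batch
else:
    # ... rank-0-only validation and paddle.save(...) here ...
    dist.barrier()  # release the waiting worker ranks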
@@ -317,7 +317,7 @@ def main(args, config):
             paddle.distributed.barrier()  # Main process
     # stage 10: create the final trained model.pdparams with a soft link
-    if local_rank == 0:
+    if rank == 0:
         final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
         logger.info(f"we will create the final model: {final_model}")
         if os.path.islink(final_model):
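The final step points a stable model.pdparams symlink at the last saved checkpoint. A sketch of that idiom; the epoch_{epoch} target name is an assumption, not taken from the diff:

import os

final_model = os.path.join(args.checkpoint_dir, 'model.pdparams')
if os.path.islink(final_model):
    os.unlink(final_model)  # drop a stale link from an earlier run
# target is relative to checkpoint_dir; the directory layout is assumed
os.symlink(f'epoch_{epoch}/model.pdparams', final_model)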
