@@ -54,7 +54,7 @@ def main(args, config):
     # stage 1: we must call the paddle.distributed.init_parallel_env() api at the beginning
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
-    local_rank = paddle.distributed.get_rank()
+    rank = paddle.distributed.get_rank()
     # set the random seed, a necessary measure for multiprocess training
     seed_everything(config.seed)
 
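
Why the rename matters: paddle.distributed.get_rank() returns the trainer's global index across all nodes, whereas "local rank" conventionally means a trainer's position within a single node; the two coincide only in single-node runs, so rank is the accurate name for this value. A minimal sketch of the convention, assuming nothing beyond the public paddle.distributed calls used above:

    import paddle

    def sketch():
        # must run before any collective op or rank query in dynamic-graph mode
        paddle.distributed.init_parallel_env()
        rank = paddle.distributed.get_rank()          # global trainer index, 0..nranks-1
        nranks = paddle.distributed.get_world_size()  # total trainers across all nodes
        if rank == 0:
            # gate side effects (logging, checkpointing) on the global rank-0 process
            print(f"running with {nranks} trainers")
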
@@ -112,10 +112,10 @@ def main(args, config):
         state_dict = paddle.load(
             os.path.join(args.load_checkpoint, 'model.pdopt'))
         optimizer.set_state_dict(state_dict)
-        if local_rank == 0:
+        if rank == 0:
             logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
     except FileExistsError:
-        if local_rank == 0:
+        if rank == 0:
             logger.info('Train from scratch.')
 
     try:
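
For context, this hunk sits in the script's warm-start block: the lines just above it load the model weights (model.pdparams), the lines shown load the optimizer state (model.pdopt), and the except branch falls back to training from scratch when no checkpoint is present. A stripped-down sketch of that idiom, with placeholder names (try_restore, model, optimizer) that are illustrative, not the script's own:

    import os
    import paddle

    def try_restore(model, optimizer, checkpoint_dir):
        """Warm-start if the directory holds both state files, else signal a cold start."""
        params_path = os.path.join(checkpoint_dir, 'model.pdparams')
        opt_path = os.path.join(checkpoint_dir, 'model.pdopt')
        if not (os.path.isfile(params_path) and os.path.isfile(opt_path)):
            return False  # caller trains from scratch
        model.set_state_dict(paddle.load(params_path))    # network weights
        optimizer.set_state_dict(paddle.load(opt_path))   # momentum / LR-schedule state
        return True
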
@@ -219,7 +219,7 @@ def main(args, config):
             timer.count()  # step plus one in timer
 
             # stage 9-10: print the log information only on rank 0, once per log-freq batches
-            if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
+            if (batch_idx + 1) % config.log_interval == 0 and rank == 0:
                 lr = optimizer.get_lr()
                 avg_loss /= config.log_interval
                 avg_acc = num_corrects / num_samples
@@ -250,7 +250,7 @@ def main(args, config):
 
             # stage 9-11: save the model parameters only on rank 0, once per save-freq epochs
             if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
-                if local_rank != 0:
+                if rank != 0:
                     paddle.distributed.barrier(
                     )  # Wait for valid step in main process
                     continue  # Resume training on the other processes
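
The save block is a two-sided rendezvous: every non-zero rank calls paddle.distributed.barrier() right away and continues training, while rank 0 does the slow single-process work (validation and saving) and only then reaches its own barrier (the "# Main process" call in the next hunk), releasing the waiting workers. A reduced sketch of the pattern under the same rank-0-saves convention; checkpoint_rendezvous is an illustrative name, not the script's:

    import paddle

    def checkpoint_rendezvous(rank, model, path):
        if rank != 0:
            # workers park here until rank 0 finishes evaluating and saving
            paddle.distributed.barrier()
            return
        # rank 0 alone writes the checkpoint ...
        paddle.save(model.state_dict(), path)
        # ... then joins the same barrier, releasing every worker at once
        paddle.distributed.barrier()
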
@@ -317,7 +317,7 @@ def main(args, config):
             paddle.distributed.barrier()  # Main process
 
     # stage 10: create the final trained model.pdparams with a soft link
-    if local_rank == 0:
+    if rank == 0:
         final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
         logger.info(f"we will create the final model: {final_model}")
         if os.path.islink(final_model):
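
The diff is cut off mid-block, but the stage-10 comment states the intent: model.pdparams in the checkpoint directory is a soft link to the latest saved weights, so an existing link must be removed before it can be re-pointed. A plausible completion of that pattern, offered only as a sketch (link_final_model and last_epoch_params are assumptions, not taken from the diff):

    import os

    def link_final_model(checkpoint_dir, last_epoch_params):
        """Point checkpoint_dir/model.pdparams at the most recent epoch's weights."""
        final_model = os.path.join(checkpoint_dir, "model.pdparams")
        if os.path.islink(final_model):
            os.unlink(final_model)  # drop the stale link before re-creating it
        os.symlink(last_epoch_params, final_model)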