|
|
|
@ -230,9 +230,7 @@ def train_sp(args, config):
|
|
|
|
|
output_dir=output_dir)
|
|
|
|
|
|
|
|
|
|
trainer = Trainer(
|
|
|
|
|
updater,
|
|
|
|
|
stop_trigger=(config.max_epoch, 'epoch'),
|
|
|
|
|
out=output_dir)
|
|
|
|
|
updater, stop_trigger=(config.max_epoch, 'epoch'), out=output_dir)
|
|
|
|
|
|
|
|
|
|
if dist.get_rank() == 0:
|
|
|
|
|
trainer.extend(
|
|
|
|
@ -240,7 +238,7 @@ def train_sp(args, config):
|
|
|
|
|
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
|
|
|
|
|
trainer.extend(
|
|
|
|
|
Snapshot(max_size=config.num_snapshots),
|
|
|
|
|
trigger=(config.save_interval_steps, 'iteration'))
|
|
|
|
|
trigger=(config.save_interval_epochs, 'epoch'))
|
|
|
|
|
|
|
|
|
|
print("Trainer Done!")
|
|
|
|
|
trainer.run()
|
|
|
|
|