|
|
|
@ -230,17 +230,15 @@ def train_sp(args, config):
|
|
|
|
|
output_dir=output_dir)
|
|
|
|
|
|
|
|
|
|
trainer = Trainer(
|
|
|
|
|
updater,
|
|
|
|
|
stop_trigger=(config.train_max_steps, "iteration"),
|
|
|
|
|
out=output_dir)
|
|
|
|
|
updater, stop_trigger=(config.max_epoch, 'epoch'), out=output_dir)
|
|
|
|
|
|
|
|
|
|
if dist.get_rank() == 0:
|
|
|
|
|
trainer.extend(
|
|
|
|
|
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
|
|
|
|
|
evaluator, trigger=(config.eval_interval_epochs, 'epoch'))
|
|
|
|
|
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
|
|
|
|
|
trainer.extend(
|
|
|
|
|
Snapshot(max_size=config.num_snapshots),
|
|
|
|
|
trigger=(config.save_interval_steps, 'iteration'))
|
|
|
|
|
trigger=(config.save_interval_epochs, 'epoch'))
|
|
|
|
|
|
|
|
|
|
print("Trainer Done!")
|
|
|
|
|
trainer.run()
|
|
|
|
|