diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index cdfd30034..0e74bf631 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -230,15 +230,17 @@ def train_sp(args, config): output_dir=output_dir) trainer = Trainer( - updater, stop_trigger=(config.max_epoch, 'epoch'), out=output_dir) + updater, + stop_trigger=(config.train_max_steps, "iteration"), + out=output_dir) if dist.get_rank() == 0: trainer.extend( - evaluator, trigger=(config.eval_interval_epochs, 'epoch')) + evaluator, trigger=(config.eval_interval_steps, 'iteration')) trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) trainer.extend( Snapshot(max_size=config.num_snapshots), - trigger=(config.save_interval_epochs, 'epoch')) + trigger=(config.save_interval_steps, 'iteration')) print("Trainer Done!") trainer.run()