diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index 0e74bf631..8e166beb7 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -231,7 +231,7 @@ def train_sp(args, config): trainer = Trainer( updater, - stop_trigger=(config.train_max_steps, "iteration"), + stop_trigger=(config.max_epoch, 'epoch'), out=output_dir) if dist.get_rank() == 0: diff --git a/paddlespeech/t2s/models/vits/vits_updater.py b/paddlespeech/t2s/models/vits/vits_updater.py index 7926bb6a5..e61e617cc 100644 --- a/paddlespeech/t2s/models/vits/vits_updater.py +++ b/paddlespeech/t2s/models/vits/vits_updater.py @@ -166,6 +166,7 @@ class VITSUpdater(StandardUpdater): gen_loss.backward() self.optimizer_g.step() + # learning rate updates on each epoch. if self.state.iteration % self.updates_per_epoch == 0: self.scheduler_g.step() @@ -203,6 +204,7 @@ class VITSUpdater(StandardUpdater): dis_loss.backward() self.optimizer_d.step() + # learning rate updates on each epoch. if self.state.iteration % self.updates_per_epoch == 0: self.scheduler_d.step()