Merge pull request #3186 from PaddlePaddle/vits_pr

[TTS]update lr schedulers from per iter to per epoch for VITS
pull/3201/head
Hui Zhang 2 years ago committed by GitHub
commit e3dcfa8815
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -179,7 +179,7 @@ generator_first: False # whether to start updating generator first
 ##########################################################
 # OTHER TRAINING SETTING                                  #
 ##########################################################
 num_snapshots: 10          # max number of snapshots to keep while training
-train_max_steps: 350000    # Number of training steps. == total_iters / ngpus, total_iters = 1000000
-save_interval_steps: 1000  # Interval steps to save checkpoint.
-eval_interval_steps: 250   # Interval steps to evaluate the network.
+max_epoch: 1000            # Number of training epochs.
+save_interval_epochs: 1    # Interval epochs to save checkpoint.
+eval_interval_epochs: 1    # Interval epochs to evaluate the network.
 seed: 777                  # random seed number

@@ -230,17 +230,15 @@ def train_sp(args, config):
         output_dir=output_dir)
     trainer = Trainer(
-        updater,
-        stop_trigger=(config.train_max_steps, "iteration"),
-        out=output_dir)
+        updater, stop_trigger=(config.max_epoch, 'epoch'), out=output_dir)
     if dist.get_rank() == 0:
         trainer.extend(
-            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
+            evaluator, trigger=(config.eval_interval_epochs, 'epoch'))
         trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
         trainer.extend(
             Snapshot(max_size=config.num_snapshots),
-            trigger=(config.save_interval_steps, 'iteration'))
+            trigger=(config.save_interval_epochs, 'epoch'))
     print("Trainer Done!")
     trainer.run()

@@ -166,6 +166,8 @@ class VITSUpdater(StandardUpdater):
             gen_loss.backward()
             self.optimizer_g.step()
-            self.scheduler_g.step()
+            # learning rate updates on each epoch.
+            if self.state.iteration % self.updates_per_epoch == 0:
+                self.scheduler_g.step()

             # reset cache

@@ -202,6 +204,8 @@ class VITSUpdater(StandardUpdater):
             dis_loss.backward()
             self.optimizer_d.step()
-            self.scheduler_d.step()
+            # learning rate updates on each epoch.
+            if self.state.iteration % self.updates_per_epoch == 0:
+                self.scheduler_d.step()

             # reset cache

Loading…
Cancel
Save