diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index c0238a98..07301db5 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -24,7 +24,7 @@ import yaml from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader -from paddle.optimizer import Adam +from paddle.optimizer import AdamW from yacs.config import CfgNode from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn @@ -164,14 +164,14 @@ def train_sp(args, config): lr_schedule_g = scheduler_classes[config["generator_scheduler"]]( **config["generator_scheduler_params"]) - optimizer_g = Adam( + optimizer_g = AdamW( learning_rate=lr_schedule_g, parameters=gen_parameters, **config["generator_optimizer_params"]) lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]]( **config["discriminator_scheduler_params"]) - optimizer_d = Adam( + optimizer_d = AdamW( learning_rate=lr_schedule_d, parameters=dis_parameters, **config["discriminator_optimizer_params"])