Change optimizer for vits, test=tts (#2791)

pull/2800/head
HuangLiangJie 3 years ago committed by GitHub
parent 96d76c83ad
commit 964211a81b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,7 +24,7 @@ import yaml
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.optimizer import Adam from paddle.optimizer import AdamW
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
@ -164,14 +164,14 @@ def train_sp(args, config):
lr_schedule_g = scheduler_classes[config["generator_scheduler"]]( lr_schedule_g = scheduler_classes[config["generator_scheduler"]](
**config["generator_scheduler_params"]) **config["generator_scheduler_params"])
optimizer_g = Adam( optimizer_g = AdamW(
learning_rate=lr_schedule_g, learning_rate=lr_schedule_g,
parameters=gen_parameters, parameters=gen_parameters,
**config["generator_optimizer_params"]) **config["generator_optimizer_params"])
lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]]( lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]](
**config["discriminator_scheduler_params"]) **config["discriminator_scheduler_params"])
optimizer_d = Adam( optimizer_d = AdamW(
learning_rate=lr_schedule_d, learning_rate=lr_schedule_d,
parameters=dis_parameters, parameters=dis_parameters,
**config["discriminator_optimizer_params"]) **config["discriminator_optimizer_params"])

Loading…
Cancel
Save