Change optimizer for vits, test=tts

pull/2791/head
WongLaw 3 years ago
parent 96d76c83ad
commit 792eec9222

@ -24,7 +24,7 @@ import yaml
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.optimizer import Adam from paddle.optimizer import AdamW
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
@ -164,14 +164,14 @@ def train_sp(args, config):
lr_schedule_g = scheduler_classes[config["generator_scheduler"]]( lr_schedule_g = scheduler_classes[config["generator_scheduler"]](
**config["generator_scheduler_params"]) **config["generator_scheduler_params"])
optimizer_g = Adam( optimizer_g = AdamW(
learning_rate=lr_schedule_g, learning_rate=lr_schedule_g,
parameters=gen_parameters, parameters=gen_parameters,
**config["generator_optimizer_params"]) **config["generator_optimizer_params"])
lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]]( lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]](
**config["discriminator_scheduler_params"]) **config["discriminator_scheduler_params"])
optimizer_d = Adam( optimizer_d = AdamW(
learning_rate=lr_schedule_d, learning_rate=lr_schedule_d,
parameters=dis_parameters, parameters=dis_parameters,
**config["discriminator_optimizer_params"]) **config["discriminator_optimizer_params"])

Loading…
Cancel
Save