From 028742b69ae4c305a5d483d55fb018c92c684f9a Mon Sep 17 00:00:00 2001 From: TianYuan Date: Sat, 9 Jul 2022 01:11:44 +0000 Subject: [PATCH] update lr scheduler --- examples/vctk/ernie_sat/conf/default.yaml | 9 ++++---- paddlespeech/t2s/exps/ernie_sat/train.py | 21 +++++++++++++++++-- .../t2s/models/ernie_sat/ernie_sat_updater.py | 12 ++++++++--- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/examples/vctk/ernie_sat/conf/default.yaml b/examples/vctk/ernie_sat/conf/default.yaml index 99247659b..74c847a5f 100644 --- a/examples/vctk/ernie_sat/conf/default.yaml +++ b/examples/vctk/ernie_sat/conf/default.yaml @@ -71,9 +71,10 @@ model: ########################################################### # OPTIMIZER SETTING # ########################################################### -optimizer: - optim: adam # optimizer type - learning_rate: 0.001 # learning rate +scheduler_params: + d_model: 384 + warmup_steps: 4000 +grad_clip: 1.0 ########################################################### # TRAINING SETTING # @@ -84,7 +85,7 @@ num_snapshots: 5 ########################################################### # OTHER SETTING # ########################################################### -seed: 10086 +seed: 0 token_list: - diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py index fc02a8417..733544476 100644 --- a/paddlespeech/t2s/exps/ernie_sat/train.py +++ b/paddlespeech/t2s/exps/ernie_sat/train.py @@ -23,8 +23,10 @@ import paddle import yaml from paddle import DataParallel from paddle import distributed as dist +from paddle import nn from paddle.io import DataLoader from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam from yacs.config import CfgNode from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn @@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater from paddlespeech.t2s.training.extensions.snapshot import Snapshot from paddlespeech.t2s.training.extensions.visualizer import VisualDL -from paddlespeech.t2s.training.optimizer import build_optimizers from paddlespeech.t2s.training.seeding import seed_everything from paddlespeech.t2s.training.trainer import Trainer @@ -118,12 +119,27 @@ def train_sp(args, config): odim = config.n_mels model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"]) + # model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams" + # state_dict = paddle.load(model_path) + # new_state_dict = {} + # for key, value in state_dict.items(): + # new_key = "model." + key + # new_state_dict[new_key] = value + # model.set_state_dict(new_state_dict) if world_size > 1: model = DataParallel(model) print("model done!") - optimizer = build_optimizers(model, **config["optimizer"]) + scheduler = paddle.optimizer.lr.NoamDecay( + d_model=config["scheduler_params"]["d_model"], + warmup_steps=config["scheduler_params"]["warmup_steps"]) + grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"]) + optimizer = Adam( + learning_rate=scheduler, + grad_clip=grad_clip, + parameters=model.parameters()) + print("optimizer done!") output_dir = Path(args.output_dir) @@ -136,6 +152,7 @@ def train_sp(args, config): updater = ErnieSATUpdater( model=model, optimizer=optimizer, + scheduler=scheduler, dataloader=train_dataloader, text_masking=config["model"]["text_masking"], odim=odim, diff --git a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py index 399922338..17cfaae96 100644 --- a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py +++ b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py @@ -18,6 +18,7 @@ from paddle import distributed as dist from paddle.io import DataLoader from paddle.nn import Layer from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler from paddlespeech.t2s.modules.losses import MLMLoss from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator @@ -34,12 +35,14 @@ class ErnieSATUpdater(StandardUpdater): def __init__(self, model: Layer, optimizer: Optimizer, + scheduler: LRScheduler, dataloader: DataLoader, init_state=None, text_masking: bool=False, odim: int=80, output_dir: Path=None): super().__init__(model, optimizer, dataloader, init_state=None) + self.scheduler = scheduler self.criterion = MLMLoss(text_masking=text_masking, odim=odim) @@ -75,10 +78,12 @@ class ErnieSATUpdater(StandardUpdater): loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss - optimizer = self.optimizer - optimizer.clear_grad() + self.optimizer.clear_grad() + loss.backward() - optimizer.step() + self.optimizer.step() + self.scheduler.step() + scheduler_msg = 'lr: {}'.format(self.scheduler.last_lr) report("train/loss", float(loss)) report("train/mlm_loss", float(mlm_loss)) @@ -90,6 +95,7 @@ class ErnieSATUpdater(StandardUpdater): losses_dict["loss"] = float(loss) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items()) + self.msg += ', ' + scheduler_msg class ErnieSATEvaluator(StandardEvaluator):