update lr scheduler

pull/2117/head
TianYuan 3 years ago
parent 94688264c7
commit 028742b69a

@@ -71,9 +71,10 @@ model:
 ###########################################################
 #                     OPTIMIZER SETTING                   #
 ###########################################################
-optimizer:
-  optim: adam            # optimizer type
-  learning_rate: 0.001   # learning rate
+scheduler_params:
+  d_model: 384
+  warmup_steps: 4000
+grad_clip: 1.0
 ###########################################################
 #                     TRAINING SETTING                    #
 ###########################################################

@@ -84,7 +85,7 @@ num_snapshots: 5
 ###########################################################
 #                      OTHER SETTING                      #
 ###########################################################
-seed: 10086
+seed: 0
 token_list:
 - <blank>
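The config swap above replaces the fixed `learning_rate: 0.001` with the parameters of a Noam schedule: linear warmup for `warmup_steps` batches, then inverse-square-root decay, scaled by `d_model^-0.5`. As a minimal sketch in plain Python (values taken from the YAML above, and assuming the default base `learning_rate` of 1.0 that Paddle's `NoamDecay` uses), the rate this produces is:

def noam_lr(step, d_model=384, warmup_steps=4000):
    # Linear warmup for warmup_steps batches, then 1/sqrt(step) decay.
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000):
    print("step {:>6}: lr = {:.6f}".format(step, noam_lr(step)))
# The rate peaks at step == warmup_steps (about 8.1e-4 for these values)
# and decays as 1/sqrt(step) afterwards.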

@@ -23,8 +23,10 @@ import paddle
 import yaml
 from paddle import DataParallel
 from paddle import distributed as dist
+from paddle import nn
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
+from paddle.optimizer import Adam
 from yacs.config import CfgNode

 from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn

@@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
 from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
 from paddlespeech.t2s.training.extensions.snapshot import Snapshot
 from paddlespeech.t2s.training.extensions.visualizer import VisualDL
-from paddlespeech.t2s.training.optimizer import build_optimizers
 from paddlespeech.t2s.training.seeding import seed_everything
 from paddlespeech.t2s.training.trainer import Trainer
@@ -118,12 +119,27 @@ def train_sp(args, config):
     odim = config.n_mels
     model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
+    # model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams"
+    # state_dict = paddle.load(model_path)
+    # new_state_dict = {}
+    # for key, value in state_dict.items():
+    #     new_key = "model." + key
+    #     new_state_dict[new_key] = value
+    # model.set_state_dict(new_state_dict)
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")

-    optimizer = build_optimizers(model, **config["optimizer"])
+    scheduler = paddle.optimizer.lr.NoamDecay(
+        d_model=config["scheduler_params"]["d_model"],
+        warmup_steps=config["scheduler_params"]["warmup_steps"])
+    grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"])
+    optimizer = Adam(
+        learning_rate=scheduler,
+        grad_clip=grad_clip,
+        parameters=model.parameters())
     print("optimizer done!")

     output_dir = Path(args.output_dir)

@@ -136,6 +152,7 @@ def train_sp(args, config):
     updater = ErnieSATUpdater(
         model=model,
         optimizer=optimizer,
+        scheduler=scheduler,
         dataloader=train_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
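In `train_sp` the single `build_optimizers` call becomes three explicit objects: a `NoamDecay` schedule, a global-norm gradient clip, and an `Adam` optimizer that consumes both, with the schedule also handed to the updater. A runnable sketch of the same wiring, using a stand-in `nn.Linear` instead of `ErnieSAT` (assumes `paddle` is installed):

import paddle
from paddle import nn
from paddle.optimizer import Adam

model = nn.Linear(384, 80)  # stand-in for ErnieSAT
scheduler = paddle.optimizer.lr.NoamDecay(d_model=384, warmup_steps=4000)
grad_clip = nn.ClipGradByGlobalNorm(1.0)  # grad_clip: 1.0 from the YAML
optimizer = Adam(
    learning_rate=scheduler,  # Adam reads the current rate from the scheduler
    grad_clip=grad_clip,      # clips the global grad norm before each step
    parameters=model.parameters())
print(optimizer.get_lr())  # warmup starts near zero

Passing the scheduler object as `learning_rate` means the optimizer always sees the schedule's current value, but advancing it is left to whoever calls `scheduler.step()` — here, the updater.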

@@ -18,6 +18,7 @@ from paddle import distributed as dist
 from paddle.io import DataLoader
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler

 from paddlespeech.t2s.modules.losses import MLMLoss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator

@@ -34,12 +35,14 @@ class ErnieSATUpdater(StandardUpdater):
     def __init__(self,
                  model: Layer,
                  optimizer: Optimizer,
+                 scheduler: LRScheduler,
                  dataloader: DataLoader,
                  init_state=None,
                  text_masking: bool=False,
                  odim: int=80,
                  output_dir: Path=None):
         super().__init__(model, optimizer, dataloader, init_state=None)
+        self.scheduler = scheduler

         self.criterion = MLMLoss(text_masking=text_masking, odim=odim)

@@ -75,10 +78,12 @@ class ErnieSATUpdater(StandardUpdater):
         loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss

-        optimizer = self.optimizer
-        optimizer.clear_grad()
+        self.optimizer.clear_grad()
         loss.backward()
-        optimizer.step()
+        self.optimizer.step()
+        self.scheduler.step()
+        scheduler_msg = 'lr: {}'.format(self.scheduler.last_lr)

         report("train/loss", float(loss))
         report("train/mlm_loss", float(mlm_loss))

@@ -90,6 +95,7 @@ class ErnieSATUpdater(StandardUpdater):
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())
+        self.msg += ', ' + scheduler_msg

 class ErnieSATEvaluator(StandardEvaluator):
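Because `NoamDecay` is a per-batch schedule rather than a per-epoch one, the updater now calls `scheduler.step()` once after every `optimizer.step()` and appends the current `last_lr` to the training log line. A minimal sketch of that per-iteration order on a toy model (assumes `paddle` is installed):

import paddle
from paddle import nn

model = nn.Linear(8, 1)
scheduler = paddle.optimizer.lr.NoamDecay(d_model=384, warmup_steps=4000)
optimizer = paddle.optimizer.Adam(
    learning_rate=scheduler, parameters=model.parameters())

for step in range(3):
    loss = model(paddle.randn([4, 8])).mean()
    optimizer.clear_grad()  # same order as ErnieSATUpdater.update_core
    loss.backward()
    optimizer.step()
    scheduler.step()        # advance the schedule once per batch
    print('step {}, lr: {}'.format(step, scheduler.last_lr))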
