update lr scheduler

pull/2117/head
TianYuan 3 years ago
parent 94688264c7
commit 028742b69a

@@ -71,9 +71,10 @@ model:
 ###########################################################
 #                  OPTIMIZER SETTING                      #
 ###########################################################
-optimizer:
-  optim: adam               # optimizer type
-  learning_rate: 0.001      # learning rate
+scheduler_params:
+  d_model: 384
+  warmup_steps: 4000
+grad_clip: 1.0

 ###########################################################
 #                   TRAINING SETTING                      #
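Note: the flat `optimizer` block with its fixed `learning_rate: 0.001` gives way to Noam-schedule parameters plus a global-norm clip value. For reference, a minimal sketch (not part of the commit) of the rate `paddle.optimizer.lr.NoamDecay` computes from these settings, with its default `learning_rate=1.0` multiplier:

    def noam_lr(step, d_model=384, warmup_steps=4000):
        # lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    print(noam_lr(1))      # ~2.0e-7: warmup starts near zero
    print(noam_lr(4000))   # ~8.1e-4: peak at the end of warmup
    print(noam_lr(40000))  # ~2.6e-4: inverse-sqrt decay afterwards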
@@ -84,7 +85,7 @@ num_snapshots: 5
 ###########################################################
 #                     OTHER SETTING                       #
 ###########################################################
-seed: 10086
+seed: 0

 token_list:
 - <blank>

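The fixed seed changes from 10086 to 0; it is consumed by `seed_everything` (imported in train.py below). Assuming the usual implementation of `paddlespeech.t2s.training.seeding.seed_everything`, it seeds all three RNG sources, roughly:

    import random
    import numpy as np
    import paddle

    def seed_everything(seed: int):
        # assumed sketch: seed Python, NumPy and Paddle RNGs for reproducibility
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)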
@@ -23,8 +23,10 @@ import paddle
 import yaml
 from paddle import DataParallel
 from paddle import distributed as dist
+from paddle import nn
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
+from paddle.optimizer import Adam
 from yacs.config import CfgNode

 from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
@@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
 from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
 from paddlespeech.t2s.training.extensions.snapshot import Snapshot
 from paddlespeech.t2s.training.extensions.visualizer import VisualDL
-from paddlespeech.t2s.training.optimizer import build_optimizers
 from paddlespeech.t2s.training.seeding import seed_everything
 from paddlespeech.t2s.training.trainer import Trainer
@@ -118,12 +119,27 @@ def train_sp(args, config):
     odim = config.n_mels
     model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
+    # model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams"
+    # state_dict = paddle.load(model_path)
+    # new_state_dict = {}
+    # for key, value in state_dict.items():
+    #     new_key = "model." + key
+    #     new_state_dict[new_key] = value
+    # model.set_state_dict(new_state_dict)
+
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")

-    optimizer = build_optimizers(model, **config["optimizer"])
+    scheduler = paddle.optimizer.lr.NoamDecay(
+        d_model=config["scheduler_params"]["d_model"],
+        warmup_steps=config["scheduler_params"]["warmup_steps"])
+    grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"])
+    optimizer = Adam(
+        learning_rate=scheduler,
+        grad_clip=grad_clip,
+        parameters=model.parameters())
     print("optimizer done!")

     output_dir = Path(args.output_dir)
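The removed `build_optimizers(model, **config["optimizer"])` call built a fixed-rate Adam from the old `optimizer` config block; the new code wires scheduler, gradient clip and optimizer together explicitly. A self-contained sketch of the same wiring (toy model as a stand-in, real Paddle APIs):

    import paddle
    from paddle import nn
    from paddle.optimizer import Adam

    model = nn.Linear(8, 8)  # stand-in for ErnieSAT
    scheduler = paddle.optimizer.lr.NoamDecay(d_model=384, warmup_steps=4000)
    grad_clip = nn.ClipGradByGlobalNorm(1.0)  # config["grad_clip"]
    optimizer = Adam(
        learning_rate=scheduler,  # a scheduler object, not a float
        grad_clip=grad_clip,
        parameters=model.parameters())
    print(optimizer.get_lr())  # near zero: warmup has just begun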
@@ -136,6 +152,7 @@ def train_sp(args, config):
     updater = ErnieSATUpdater(
         model=model,
         optimizer=optimizer,
+        scheduler=scheduler,
         dataloader=train_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,

@@ -18,6 +18,7 @@ from paddle import distributed as dist
 from paddle.io import DataLoader
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler

 from paddlespeech.t2s.modules.losses import MLMLoss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
@@ -34,12 +35,14 @@ class ErnieSATUpdater(StandardUpdater):
     def __init__(self,
                  model: Layer,
                  optimizer: Optimizer,
+                 scheduler: LRScheduler,
                  dataloader: DataLoader,
                  init_state=None,
                  text_masking: bool=False,
                  odim: int=80,
                  output_dir: Path=None):
         super().__init__(model, optimizer, dataloader, init_state=None)
+        self.scheduler = scheduler

         self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
@@ -75,10 +78,12 @@ class ErnieSATUpdater(StandardUpdater):
         loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss

-        optimizer = self.optimizer
-        optimizer.clear_grad()
+        self.optimizer.clear_grad()
         loss.backward()
-        optimizer.step()
+        self.optimizer.step()
+        self.scheduler.step()
+        scheduler_msg = 'lr: {}'.format(self.scheduler.last_lr)

         report("train/loss", float(loss))
         report("train/mlm_loss", float(mlm_loss))
@@ -90,6 +95,7 @@ class ErnieSATUpdater(StandardUpdater):
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())
+        self.msg += ', ' + scheduler_msg


 class ErnieSATEvaluator(StandardEvaluator):
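Because the optimizer's learning rate is now an `LRScheduler`, it only advances when `scheduler.step()` is called, so the updater steps it right after `optimizer.step()` and logs `scheduler.last_lr`. A minimal loop showing the order the updater now follows (toy model, real Paddle APIs):

    import paddle
    from paddle import nn

    model = nn.Linear(4, 4)  # toy stand-in
    scheduler = paddle.optimizer.lr.NoamDecay(d_model=384, warmup_steps=4000)
    optimizer = paddle.optimizer.Adam(
        learning_rate=scheduler, parameters=model.parameters())

    for step in range(3):
        loss = model(paddle.randn([2, 4])).mean()
        optimizer.clear_grad()    # drop old gradients
        loss.backward()
        optimizer.step()          # update at the current lr
        scheduler.step()          # advance NoamDecay to the next step
        print(scheduler.last_lr)  # what the new 'lr: {}' message reports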
