@@ -324,6 +324,9 @@ class U2Trainer(Trainer):
             if "warmup_steps" in scheduler_conf else None,
             "gamma":
             scheduler_conf.lr_decay if "lr_decay" in scheduler_conf else None,
+            "d_model":
+            model_conf.encoder_conf.output_size
+            if scheduler_type == "noam" else None,
         }
         lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
                                                     scheduler_args)
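The new `d_model` entry only matters when `scheduler_type == "noam"`: the Noam schedule scales the learning rate by the model width. A minimal sketch of the formula that `paddle.optimizer.lr.NoamDecay` implements (`noam_lr` is an illustrative name, not part of the repo):

```python
# Minimal sketch of the Noam schedule; `noam_lr` is illustrative.
def noam_lr(step: int, d_model: int, warmup_steps: int,
            base_lr: float = 1.0) -> float:
    # LR ramps up linearly for `warmup_steps` steps, then decays
    # ~ step**-0.5, with the whole curve scaled by d_model**-0.5.
    step = max(step, 1)  # guard against step 0
    return base_lr * d_model**-0.5 * min(step**-0.5,
                                         step * warmup_steps**-1.5)
```

With `d_model=256` and `warmup_steps=25000`, for example, the learning rate peaks at step 25000 and decays from there.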
@@ -338,6 +341,12 @@ class U2Trainer(Trainer):
         #     learning_rate=optim_conf.lr,
         #     warmup_steps=scheduler_conf.warmup_steps,
         #     verbose=False)
+        # elif scheduler_type == 'noam':
+        #     lr_scheduler = paddle.optimizer.lr.NoamDecay(
+        #         learning_rate=optim_conf.lr,
+        #         d_model=model_conf.encoder_conf.output_size,
+        #         warmup_steps=scheduler_conf.warmup_steps,
+        #         verbose=False)
         # else:
         #     raise ValueError(f"Not support scheduler: {scheduler_type}")
 
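The commented-out branch above is the old hand-rolled dispatch, kept for reference now that `LRSchedulerFactory.from_args` takes over. A hypothetical sketch of that factory, covering two of the branches; the real implementation lives elsewhere in the repo and may differ in detail:

```python
# Hypothetical sketch of LRSchedulerFactory; the real one lives
# elsewhere in the repo and may differ.
import paddle

class LRSchedulerFactory:
    @classmethod
    def from_args(cls, scheduler_type: str, args: dict):
        # Drop None placeholders so each scheduler only receives
        # the keyword arguments it understands.
        args = {k: v for k, v in args.items() if v is not None}
        if scheduler_type == 'expdecaylr':
            return paddle.optimizer.lr.ExponentialDecay(
                learning_rate=args['learning_rate'], gamma=args['gamma'])
        if scheduler_type == 'noam':
            return paddle.optimizer.lr.NoamDecay(
                d_model=args['d_model'],
                warmup_steps=args['warmup_steps'],
                learning_rate=args['learning_rate'])
        raise ValueError(f"Not support scheduler: {scheduler_type}")
```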
@@ -356,6 +365,9 @@ class U2Trainer(Trainer):
             "learning_rate": lr_scheduler
             if lr_scheduler else optim_conf.lr,
             "parameters": parameters,
+            "epsilon": 1e-9 if optim_type == 'noam' else None,
+            "beta1": 0.9 if optim_type == 'noam' else None,
+            "beta2": 0.98 if optim_type == 'noam' else None,
         }
 
         optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
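The three new entries pin the Adam hyperparameters that the Transformer paper pairs with the Noam schedule (beta1=0.9, beta2=0.98, epsilon=1e-9). A hypothetical sketch of how the returned dict, with its `None` placeholders, might be turned into an optimizer; the real construction lives elsewhere in the repo:

```python
# Hypothetical sketch of the optimizer construction that consumes
# the args dict above; not the repo's actual factory.
import paddle

def build_optimizer(optim_type: str, args: dict):
    # None entries mean "fall back to the framework default".
    kwargs = {k: v for k, v in args.items() if v is not None}
    if optim_type in ('adam', 'noam'):
        # 'noam' names the LR schedule; the optimizer itself is Adam.
        return paddle.optimizer.Adam(**kwargs)
    raise ValueError(f"Not support optim: {optim_type}")
```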