@@ -315,18 +315,11 @@ class U2Trainer(Trainer):
 
         scheduler_conf = train_config.scheduler_conf
         scheduler_args = {
-            "learning_rate":
-            optim_conf.lr,
-            "verbose":
-            False,
-            "warmup_steps":
-            scheduler_conf.warmup_steps
-            if "warmup_steps" in scheduler_conf else None,
-            "gamma":
-            scheduler_conf.lr_decay if "lr_decay" in scheduler_conf else None,
-            "d_model":
-            model_conf.encoder_conf.output_size
-            if scheduler_type == "noam" else None,
+            "learning_rate": optim_conf.lr,
+            "verbose": False,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
         }
         lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
                                                     scheduler_args)
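For context on why `scheduler_args` carries both `warmup_steps` and `d_model`: when `scheduler_type` is `"noam"`, those two values (together with the base `learning_rate`) determine the warmup-then-decay curve from "Attention Is All You Need". Below is a minimal sketch of that schedule; it illustrates the formula only, not the internals of `LRSchedulerFactory`, and the function name `noam_lr` is hypothetical.

```python
# Minimal sketch of the Noam schedule (Vaswani et al., 2017) that
# "learning_rate", "warmup_steps", and "d_model" feed into when
# scheduler_type == "noam". Illustration only -- not PaddleSpeech's
# LRSchedulerFactory; the name `noam_lr` is made up for this example.
def noam_lr(step: int, learning_rate: float, d_model: int,
            warmup_steps: int) -> float:
    """Learning rate at a 1-based optimizer step."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the first call
    scale = d_model ** -0.5 * min(step ** -0.5,
                                  step * warmup_steps ** -1.5)
    return learning_rate * scale


# The curve rises linearly to its peak at step == warmup_steps,
# then decays proportionally to step ** -0.5.
for step in (1, 5000, 25000, 100000):
    print(step, noam_lr(step, learning_rate=0.002,
                        d_model=256, warmup_steps=25000))
```

Schedulers other than `"noam"` presumably ignore the keys they do not use, which is what lets the compact dict on the new side pass `d_model` unconditionally instead of guarding each entry with a `None` fallback.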