with all args for scheduler

pull/751/head
Hui Zhang 3 years ago
parent c4da9a7f3a
commit 820b4db287

@ -315,18 +315,11 @@ class U2Trainer(Trainer):
scheduler_conf = train_config.scheduler_conf scheduler_conf = train_config.scheduler_conf
scheduler_args = { scheduler_args = {
"learning_rate": "learning_rate": optim_conf.lr,
optim_conf.lr, "verbose": False,
"verbose": "warmup_steps": scheduler_conf.warmup_steps,
False, "gamma": scheduler_conf.lr_decay,
"warmup_steps": "d_model": model_conf.encoder_conf.output_size,
scheduler_conf.warmup_steps
if "warmup_steps" in scheduler_conf else None,
"gamma":
scheduler_conf.lr_decay if "lr_decay" in scheduler_conf else None,
"d_model":
model_conf.encoder_conf.output_size
if scheduler_type == "noam" else None,
} }
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args) scheduler_args)

Loading…
Cancel
Save