support noam lr scheduler and optimizer settings

pull/751/head
Hui Zhang 3 years ago
parent 1cd4d4bf83
commit 3912c255ef

@@ -324,6 +324,9 @@ class U2Trainer(Trainer):
             if "warmup_steps" in scheduler_conf else None,
             "gamma":
             scheduler_conf.lr_decay if "lr_decay" in scheduler_conf else None,
+            "d_model":
+            model_conf.encoder_conf.output_size
+            if scheduler_type == "noam" else None,
         }
         lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
                                                     scheduler_args)
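For reference, the "noam" schedule is the one from "Attention Is All You Need": the learning rate scales with d_model**-0.5, ramps up linearly over the warmup steps, then decays as step**-0.5, which is why the encoder output size is passed through as d_model. A minimal standalone sketch (plain Python; the function name and default values are illustrative, not the LRSchedulerFactory internals):

    # Noam schedule: lr = scale * d_model**-0.5 * min(step**-0.5,
    #                                                 step * warmup_steps**-1.5)
    def noam_lr(step: int, d_model: int = 256, warmup_steps: int = 25000,
                scale: float = 1.0) -> float:
        step = max(step, 1)  # guard against step == 0
        return scale * d_model**-0.5 * min(step**-0.5,
                                           step * warmup_steps**-1.5)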
@@ -338,6 +341,12 @@ class U2Trainer(Trainer):
         #     learning_rate=optim_conf.lr,
         #     warmup_steps=scheduler_conf.warmup_steps,
         #     verbose=False)
+        # elif scheduler_type == 'noam':
+        #     lr_scheduler = paddle.optimizer.lr.NoamDecay(
+        #         learning_rate=optim_conf.lr,
+        #         d_model=model_conf.encoder_conf.output_size,
+        #         warmup_steps=scheduler_conf.warmup_steps,
+        #         verbose=False)
         # else:
         #     raise ValueError(f"Unsupported scheduler: {scheduler_type}")
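The commented-out branch mirrors Paddle's built-in scheduler; for context, a standalone sketch of the direct call (the d_model and warmup values here are placeholders, not repo defaults):

    import paddle

    # Built-in NoamDecay, equivalent to the commented-out branch above
    scheduler = paddle.optimizer.lr.NoamDecay(
        d_model=256, warmup_steps=25000, learning_rate=1.0, verbose=False)
    for _ in range(3):
        scheduler.step()          # advance the schedule by one step
        print(scheduler.last_lr)  # current learning rate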
@@ -356,6 +365,9 @@ class U2Trainer(Trainer):
             "learning_rate": lr_scheduler
             if lr_scheduler else optim_conf.lr,
             "parameters": parameters,
+            "epsilon": 1e-9 if optim_type == 'noam' else None,
+            "beta1": 0.9 if optim_type == 'noam' else None,
+            "beta2": 0.98 if optim_type == 'noam' else None,
         }
         optim_args = optimizer_args(config, model.parameters(), lr_scheduler)
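The values selected for the 'noam' branch are the Adam settings from the Transformer paper (beta1=0.9, beta2=0.98, epsilon=1e-9). A minimal sketch of the resulting optimizer, with a stand-in model in place of the real network:

    import paddle

    model = paddle.nn.Linear(256, 256)  # stand-in for the real model
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,  # or a NoamDecay scheduler instance
        beta1=0.9,
        beta2=0.98,
        epsilon=1e-9,
        parameters=model.parameters())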

@@ -20,6 +20,7 @@ from paddle.regularizer import L2Decay
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.dynamic_import import filter_valid_args
 from deepspeech.utils.log import Log

 __all__ = ["OptimizerFactory"]
@@ -78,4 +79,6 @@ class OptimizerFactory():
             f"Optimizer: {module_class.__name__} {args['learning_rate']}")
         args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
+        args = filter_valid_args(args)
         return module_class(**args)
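The new filter_valid_args call is what makes the None placeholders in the args dict safe: without it, entries like "epsilon": None would be forwarded to the optimizer constructor and override its defaults with None. A tiny illustration:

    args = {"learning_rate": 1e-3, "epsilon": None, "beta1": None}
    args = {k: v for k, v in args.items() if v is not None}
    assert args == {"learning_rate": 1e-3}  # constructor defaults now apply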

@@ -20,7 +20,7 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

-__all__ = ["dynamic_import", "instance_class"]
+__all__ = ["dynamic_import", "instance_class", "filter_valid_args"]

 def dynamic_import(import_path, alias=dict()):
@@ -43,8 +43,14 @@ def dynamic_import(import_path, alias=dict()):
     return getattr(m, objname)

-def instance_class(module_class, args: Dict[Text, Any]):
+def filter_valid_args(args: Dict[Text, Any]):
     # filter out `val` which is None
     new_args = {key: val for key, val in args.items() if val is not None}
+    return new_args
+
+
+def instance_class(module_class, args: Dict[Text, Any]):
+    # filter out `val` which is None
+    new_args = filter_valid_args(args)
     logger.info(f"Instance: {module_class.__name__} {new_args}.")
     return module_class(**new_args)
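End to end, instance_class now filters before instantiating, so a None entry falls back to the constructor's default. A standalone usage sketch with a toy class (imports included):

    from typing import Any, Dict, Text

    def filter_valid_args(args: Dict[Text, Any]):
        # drop entries whose value is None
        return {key: val for key, val in args.items() if val is not None}

    class Toy:
        def __init__(self, a=1, b=2):
            self.a, self.b = a, b

    # the None entry is dropped, so Toy's default for b applies
    toy = Toy(**filter_valid_args({"a": 10, "b": None}))
    assert (toy.a, toy.b) == (10, 2)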
