From 3912c255ef39a712fb5b3630c111c08d7eac0149 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 5 Aug 2021 10:33:23 +0000 Subject: [PATCH] support noam lr and opt --- deepspeech/exps/u2/model.py | 12 ++++++++++++ deepspeech/training/optimizer.py | 3 +++ deepspeech/utils/dynamic_import.py | 10 ++++++++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 34145780..aefe73f8 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -324,6 +324,9 @@ class U2Trainer(Trainer): if "warmup_steps" in scheduler_conf else None, "gamma": scheduler_conf.lr_decay if "lr_decay" in scheduler_conf else None, + "d_model": + model_conf.encoder_conf.output_size + if scheduler_type == "noam" else None, } lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, scheduler_args) @@ -338,6 +341,12 @@ class U2Trainer(Trainer): # learning_rate=optim_conf.lr, # warmup_steps=scheduler_conf.warmup_steps, # verbose=False) + # elif scheduler_type == 'noam': + # lr_scheduler = paddle.optimizer.lr.NoamDecay( + # learning_rate=optim_conf.lr, + # d_model=model_conf.encoder_conf.output_size, + # warmup_steps=scheduler_conf.warmup_steps, + # verbose=False) # else: # raise ValueError(f"Not support scheduler: {scheduler_type}") @@ -356,6 +365,9 @@ class U2Trainer(Trainer): "learning_rate": lr_scheduler if lr_scheduler else optim_conf.lr, "parameters": parameters, + "epsilon": 1e-9 if optim_type == 'noam' else None, + "beta1": 0.9 if optim_type == 'noam' else None, + "beta2": 0.98 if optim_type == 'noam' else None, } optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) diff --git a/deepspeech/training/optimizer.py b/deepspeech/training/optimizer.py index adbc97ff..2e62a7ed 100644 --- a/deepspeech/training/optimizer.py +++ b/deepspeech/training/optimizer.py @@ -20,6 +20,7 @@ from paddle.regularizer import L2Decay from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from 
deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.dynamic_import import filter_valid_args from deepspeech.utils.log import Log __all__ = ["OptimizerFactory"] @@ -78,4 +79,6 @@ class OptimizerFactory(): f"Optimizer: {module_class.__name__} {args['learning_rate']}") args.update({"grad_clip": grad_clip, "weight_decay": weight_decay}) + + args = filter_valid_args(args) return module_class(**args) diff --git a/deepspeech/utils/dynamic_import.py b/deepspeech/utils/dynamic_import.py index 81586e3e..41978bc9 100644 --- a/deepspeech/utils/dynamic_import.py +++ b/deepspeech/utils/dynamic_import.py @@ -20,7 +20,7 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["dynamic_import", "instance_class"] +__all__ = ["dynamic_import", "instance_class", "filter_valid_args"] def dynamic_import(import_path, alias=dict()): @@ -43,8 +43,14 @@ def dynamic_import(import_path, alias=dict()): return getattr(m, objname) -def instance_class(module_class, args: Dict[Text, Any]): +def filter_valid_args(args: Dict[Text, Any]): # filter out `val` which is None new_args = {key: val for key, val in args.items() if val is not None} + return new_args + + +def instance_class(module_class, args: Dict[Text, Any]): + # filter out `val` which is None + new_args = filter_valid_args(args) logger.info(f"Instance: {module_class.__name__} {new_args}.") return module_class(**new_args)