# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Text

import paddle
from paddle.optimizer import Optimizer
from paddle.regularizer import L2Decay

from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log

__all__ = ["OptimizerFactory"]

logger = Log(__name__).getlog()

OPTIMIZER_DICT = {
    "sgd": "paddle.optimizer:SGD",
    "momentum": "paddle.optimizer:Momentum",
    "adadelta": "paddle.optimizer:Adadelta",
    "adam": "paddle.optimizer:Adam",
    "adamw": "paddle.optimizer:AdamW",
}


def register_optimizer(cls):
    """Register an optimizer class under its lowercased class name."""
    alias = cls.__name__.lower()
    OPTIMIZER_DICT[alias] = cls.__module__ + ":" + cls.__name__
    return cls


@register_optimizer
class Noam(paddle.optimizer.Adam):
    """Adam optimizer with Noam-style default hyperparameters.

    Similar to espnet/nets/pytorch_backend/transformer/optimizer.py
    """

    def __init__(self,
                 learning_rate=0,
                 beta1=0.9,
                 beta2=0.98,
                 epsilon=1e-9,
                 parameters=None,
                 weight_decay=None,
                 grad_clip=None,
                 lazy_mode=False,
                 multi_precision=False,
                 name=None):
        super().__init__(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            lazy_mode=lazy_mode,
            multi_precision=multi_precision,
            name=name)

    def __repr__(self):
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"learning_rate: {self._learning_rate}, "
        echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
        echo += f"epsilon: {self._epsilon}"
        return echo


def dynamic_import_optimizer(module):
    """Import Optimizer class dynamically.

    Args:
        module (str): module_name:class_name or alias in `OPTIMIZER_DICT`

    Returns:
        type: Optimizer class
    """
    module_class = dynamic_import(module, OPTIMIZER_DICT)
    assert issubclass(module_class,
                      Optimizer), f"{module} does not implement Optimizer"
    return module_class


class OptimizerFactory():
    @classmethod
    def from_args(cls, name: str, args: Dict[Text, Any]):
        assert "parameters" in args, "parameters not in args."
        assert "learning_rate" in args, "learning_rate not in args."

        # Wrap optional clipping/decay settings into their Paddle objects.
        grad_clip = ClipGradByGlobalNormWithLog(
            args['grad_clip']) if "grad_clip" in args else None
        weight_decay = L2Decay(
            args['weight_decay']) if "weight_decay" in args else None
        if weight_decay:
            logger.info(f"WeightDecay: {weight_decay}")
        if grad_clip:
            logger.info(f"GradClip: {grad_clip}")

        module_class = dynamic_import_optimizer(name.lower())
        args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
        opt = instance_class(module_class, args)

        # Prefer the optimizer's own __repr__ when the class defines one
        # (e.g. Noam above); otherwise log the class name and learning rate.
        if "__repr__" in vars(type(opt)):
            logger.info(f"{opt}")
        else:
            logger.info(
                f"{module_class.__name__} LR: {args['learning_rate']}")
        return opt
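

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of building an optimizer through OptimizerFactory.from_args.
# The toy model and hyper-parameter values below are assumptions for
# demonstration only; in training they come from the experiment config.
if __name__ == "__main__":
    import paddle.nn as nn

    model = nn.Linear(16, 8)  # hypothetical toy model
    optimizer = OptimizerFactory.from_args(
        "adam",  # alias resolved through OPTIMIZER_DICT
        {
            "learning_rate": 1e-3,  # assumed value
            "parameters": model.parameters(),
            "weight_decay": 1e-6,  # wrapped into L2Decay by the factory
            "grad_clip": 5.0,  # wrapped into ClipGradByGlobalNormWithLog
        })
    # `optimizer` is a paddle.optimizer.Adam instance ready for
    # optimizer.step() / optimizer.clear_grad() in the training loop.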