filter key by class signature, no print tensor

pull/751/head
Hui Zhang 3 years ago
parent 3912c255ef
commit c4da9a7f3a

@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
def __init__(self, clip_norm): def __init__(self, clip_norm):
super().__init__(clip_norm) super().__init__(clip_norm)
def __repr__(self):
return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"
@imperative_base.no_grad @imperative_base.no_grad
def _dygraph_clip(self, params_grads): def _dygraph_clip(self, params_grads):
params_and_grads = [] params_and_grads = []

@ -20,7 +20,7 @@ from paddle.regularizer import L2Decay
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.dynamic_import import filter_valid_args from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["OptimizerFactory"] __all__ = ["OptimizerFactory"]
@ -80,5 +80,4 @@ class OptimizerFactory():
args.update({"grad_clip": grad_clip, "weight_decay": weight_decay}) args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
args = filter_valid_args(args) return instance_class(module_class, args)
return module_class(**args)

@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib import importlib
import inspect
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List
from typing import Text from typing import Text
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import has_tensor
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["dynamic_import", "instance_class", "filter_valid_args"] __all__ = ["dynamic_import", "instance_class"]
def dynamic_import(import_path, alias=dict()): def dynamic_import(import_path, alias=dict()):
@ -43,14 +46,22 @@ def dynamic_import(import_path, alias=dict()):
return getattr(m, objname) return getattr(m, objname)
def filter_valid_args(args: Dict[Text, Any]): def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
# filter out `val` which is None # filter by `valid_keys` and filter `val` is not None
new_args = {key: val for key, val in args.items() if val is not None} new_args = {
key: val
for key, val in args.items() if (key in valid_keys and val is not None)
}
return new_args return new_args
def filter_out_tenosr(args: Dict[Text, Any]):
return {key: val for key, val in args.items() if not has_tensor(val)}
def instance_class(module_class, args: Dict[Text, Any]): def instance_class(module_class, args: Dict[Text, Any]):
# filter out `val` which is None valid_keys = inspect.signature(module_class).parameters.keys()
new_args = filter_valid_args(args) new_args = filter_valid_args(args, valid_keys)
logger.info(f"Instance: {module_class.__name__} {new_args}.") logger.info(
f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.")
return module_class(**new_args) return module_class(**new_args)

@ -19,11 +19,25 @@ import paddle
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"] __all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
def pad_sequence(sequences: List[paddle.Tensor], def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False, batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor: padding_value: float=0.0) -> paddle.Tensor:

Loading…
Cancel
Save