add LayerDict

pull/916/head
Hui Zhang 3 years ago
parent 666b42d18b
commit 0f59459a66

@ -355,6 +355,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
setattr(paddle.Tensor, 'tolist', tolist) setattr(paddle.Tensor, 'tolist', tolist)
########### hcak paddle.nn.functional #############
# hack loss # hack loss
def ctc_loss(logits, def ctc_loss(logits,
labels, labels,
@ -381,3 +383,152 @@ logger.debug(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!" "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
) )
F.ctc_loss = ctc_loss F.ctc_loss = ctc_loss
########### hcak paddle.nn #############
from paddle.nn import Layer
from typing import Optional
from typing import Mapping
from typing import Iterable
from typing import Tuple
from typing import Iterator
from collections import OrderedDict, abc as container_abcs
class LayerDict(paddle.nn.Layer):
r"""Holds submodules in a dictionary.
:class:`~paddle.nn.LayerDict` can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all
:class:`~paddle.nn.Layer` methods.
:class:`~paddle.nn.LayerDict` is an **ordered** dictionary that respects
* the order of insertion, and
* in :meth:`~paddle.nn.LayerDict.update`, the order of the merged
``OrderedDict``, ``dict`` (started from Python 3.6) or another
:class:`~paddle.nn.LayerDict` (the argument to
:meth:`~paddle.nn.LayerDict.update`).
Note that :meth:`~paddle.nn.LayerDict.update` with other unordered mapping
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
preserve the order of the merged mapping.
Args:
modules (iterable, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module)
Example::
class MyModule(nn.Layer):
def __init__(self):
super(MyModule, self).__init__()
self.choices = nn.LayerDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.LayerDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
"""
def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None:
super(LayerDict, self).__init__()
if modules is not None:
self.update(modules)
def __getitem__(self, key: str) -> Layer:
return self._modules[key]
def __setitem__(self, key: str, module: Layer) -> None:
self.add_module(key, module)
def __delitem__(self, key: str) -> None:
del self._modules[key]
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[str]:
return iter(self._modules)
def __contains__(self, key: str) -> bool:
return key in self._modules
def clear(self) -> None:
"""Remove all items from the LayerDict.
"""
self._modules.clear()
def pop(self, key: str) -> Layer:
r"""Remove key from the LayerDict and return its module.
Args:
key (string): key to pop from the LayerDict
"""
v = self[key]
del self[key]
return v
def keys(self) -> Iterable[str]:
r"""Return an iterable of the LayerDict keys.
"""
return self._modules.keys()
def items(self) -> Iterable[Tuple[str, Layer]]:
r"""Return an iterable of the LayerDict key/value pairs.
"""
return self._modules.items()
def values(self) -> Iterable[Layer]:
r"""Return an iterable of the LayerDict values.
"""
return self._modules.values()
def update(self, modules: Mapping[str, Layer]) -> None:
r"""Update the :class:`~paddle.nn.LayerDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`~paddle.nn.LayerDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`~paddle.nn.Layer`,
or an iterable of key-value pairs of type (string, :class:`~paddle.nn.Layer`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("LayerDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(modules).__name__)
if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)):
for key, module in modules.items():
self[key] = module
else:
# modules here can be a list with two items
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError("LayerDict update sequence element "
"#" + str(j) + " should be Iterable; is" +
type(m).__name__)
if not len(m) == 2:
raise ValueError("LayerDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) +
"; 2 is required")
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
# remove forward alltogether to fallback on Module's _forward_unimplemented
if not hasattr(paddle.nn, 'LayerDict'):
logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'LayerDict', LayerDict)

@ -42,7 +42,7 @@ def all_version():
"paddle_commit": paddle.version.commit, "paddle_commit": paddle.version.commit,
"soundfile": soundfile.__version__, "soundfile": soundfile.__version__,
} }
logger.info(f"Deps Module Version:{pformat(vers.items())}") logger.info(f"Deps Module Version:{pformat(list(vers.items()))}")
@contextmanager @contextmanager

Loading…
Cancel
Save