|
|
|
@ -355,6 +355,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
|
|
|
|
|
setattr(paddle.Tensor, 'tolist', tolist)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########### hcak paddle.nn.functional #############
|
|
|
|
|
# hack loss
|
|
|
|
|
def ctc_loss(logits,
|
|
|
|
|
labels,
|
|
|
|
@ -381,3 +383,152 @@ logger.debug(
|
|
|
|
|
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
|
|
|
|
|
)
|
|
|
|
|
F.ctc_loss = ctc_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########### hcak paddle.nn #############
|
|
|
|
|
from paddle.nn import Layer
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from typing import Mapping
|
|
|
|
|
from typing import Iterable
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
from typing import Iterator
|
|
|
|
|
from collections import OrderedDict, abc as container_abcs
|
|
|
|
|
|
|
|
|
|
class LayerDict(paddle.nn.Layer):
|
|
|
|
|
r"""Holds submodules in a dictionary.
|
|
|
|
|
|
|
|
|
|
:class:`~paddle.nn.LayerDict` can be indexed like a regular Python dictionary,
|
|
|
|
|
but modules it contains are properly registered, and will be visible by all
|
|
|
|
|
:class:`~paddle.nn.Layer` methods.
|
|
|
|
|
|
|
|
|
|
:class:`~paddle.nn.LayerDict` is an **ordered** dictionary that respects
|
|
|
|
|
|
|
|
|
|
* the order of insertion, and
|
|
|
|
|
|
|
|
|
|
* in :meth:`~paddle.nn.LayerDict.update`, the order of the merged
|
|
|
|
|
``OrderedDict``, ``dict`` (started from Python 3.6) or another
|
|
|
|
|
:class:`~paddle.nn.LayerDict` (the argument to
|
|
|
|
|
:meth:`~paddle.nn.LayerDict.update`).
|
|
|
|
|
|
|
|
|
|
Note that :meth:`~paddle.nn.LayerDict.update` with other unordered mapping
|
|
|
|
|
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
|
|
|
|
|
preserve the order of the merged mapping.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
modules (iterable, optional): a mapping (dictionary) of (string: module)
|
|
|
|
|
or an iterable of key-value pairs of type (string, module)
|
|
|
|
|
|
|
|
|
|
Example::
|
|
|
|
|
|
|
|
|
|
class MyModule(nn.Layer):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(MyModule, self).__init__()
|
|
|
|
|
self.choices = nn.LayerDict({
|
|
|
|
|
'conv': nn.Conv2d(10, 10, 3),
|
|
|
|
|
'pool': nn.MaxPool2d(3)
|
|
|
|
|
})
|
|
|
|
|
self.activations = nn.LayerDict([
|
|
|
|
|
['lrelu', nn.LeakyReLU()],
|
|
|
|
|
['prelu', nn.PReLU()]
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def forward(self, x, choice, act):
|
|
|
|
|
x = self.choices[choice](x)
|
|
|
|
|
x = self.activations[act](x)
|
|
|
|
|
return x
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None:
|
|
|
|
|
super(LayerDict, self).__init__()
|
|
|
|
|
if modules is not None:
|
|
|
|
|
self.update(modules)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, key: str) -> Layer:
|
|
|
|
|
return self._modules[key]
|
|
|
|
|
|
|
|
|
|
def __setitem__(self, key: str, module: Layer) -> None:
|
|
|
|
|
self.add_module(key, module)
|
|
|
|
|
|
|
|
|
|
def __delitem__(self, key: str) -> None:
|
|
|
|
|
del self._modules[key]
|
|
|
|
|
|
|
|
|
|
def __len__(self) -> int:
|
|
|
|
|
return len(self._modules)
|
|
|
|
|
|
|
|
|
|
def __iter__(self) -> Iterator[str]:
|
|
|
|
|
return iter(self._modules)
|
|
|
|
|
|
|
|
|
|
def __contains__(self, key: str) -> bool:
|
|
|
|
|
return key in self._modules
|
|
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
|
"""Remove all items from the LayerDict.
|
|
|
|
|
"""
|
|
|
|
|
self._modules.clear()
|
|
|
|
|
|
|
|
|
|
def pop(self, key: str) -> Layer:
|
|
|
|
|
r"""Remove key from the LayerDict and return its module.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
key (string): key to pop from the LayerDict
|
|
|
|
|
"""
|
|
|
|
|
v = self[key]
|
|
|
|
|
del self[key]
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
def keys(self) -> Iterable[str]:
|
|
|
|
|
r"""Return an iterable of the LayerDict keys.
|
|
|
|
|
"""
|
|
|
|
|
return self._modules.keys()
|
|
|
|
|
|
|
|
|
|
def items(self) -> Iterable[Tuple[str, Layer]]:
|
|
|
|
|
r"""Return an iterable of the LayerDict key/value pairs.
|
|
|
|
|
"""
|
|
|
|
|
return self._modules.items()
|
|
|
|
|
|
|
|
|
|
def values(self) -> Iterable[Layer]:
|
|
|
|
|
r"""Return an iterable of the LayerDict values.
|
|
|
|
|
"""
|
|
|
|
|
return self._modules.values()
|
|
|
|
|
|
|
|
|
|
def update(self, modules: Mapping[str, Layer]) -> None:
|
|
|
|
|
r"""Update the :class:`~paddle.nn.LayerDict` with the key-value pairs from a
|
|
|
|
|
mapping or an iterable, overwriting existing keys.
|
|
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
|
If :attr:`modules` is an ``OrderedDict``, a :class:`~paddle.nn.LayerDict`, or
|
|
|
|
|
an iterable of key-value pairs, the order of new elements in it is preserved.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
modules (iterable): a mapping (dictionary) from string to :class:`~paddle.nn.Layer`,
|
|
|
|
|
or an iterable of key-value pairs of type (string, :class:`~paddle.nn.Layer`)
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(modules, container_abcs.Iterable):
|
|
|
|
|
raise TypeError("LayerDict.update should be called with an "
|
|
|
|
|
"iterable of key/value pairs, but got " +
|
|
|
|
|
type(modules).__name__)
|
|
|
|
|
|
|
|
|
|
if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)):
|
|
|
|
|
for key, module in modules.items():
|
|
|
|
|
self[key] = module
|
|
|
|
|
else:
|
|
|
|
|
# modules here can be a list with two items
|
|
|
|
|
for j, m in enumerate(modules):
|
|
|
|
|
if not isinstance(m, container_abcs.Iterable):
|
|
|
|
|
raise TypeError("LayerDict update sequence element "
|
|
|
|
|
"#" + str(j) + " should be Iterable; is" +
|
|
|
|
|
type(m).__name__)
|
|
|
|
|
if not len(m) == 2:
|
|
|
|
|
raise ValueError("LayerDict update sequence element "
|
|
|
|
|
"#" + str(j) + " has length " + str(len(m)) +
|
|
|
|
|
"; 2 is required")
|
|
|
|
|
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
|
|
|
|
|
# that's too cumbersome to type correctly with overloads, so we add an ignore here
|
|
|
|
|
self[m[0]] = m[1] # type: ignore[assignment]
|
|
|
|
|
|
|
|
|
|
# remove forward alltogether to fallback on Module's _forward_unimplemented
|
|
|
|
|
|
|
|
|
|
if not hasattr(paddle.nn, 'LayerDict'):
|
|
|
|
|
logger.debug(
|
|
|
|
|
"register user LayerDict to paddle.nn, remove this when fixed!")
|
|
|
|
|
setattr(paddle.nn, 'LayerDict', LayerDict)
|
|
|
|
|