diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index d85a3dde..3ac68a3b 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -351,20 +351,3 @@ if not hasattr(paddle.Tensor, 'tolist'):
     logger.warn(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
-
-
-########### hcak paddle.nn #############
-class GLU(nn.Layer):
-    """Gated Linear Units (GLU) Layer"""
-
-    def __init__(self, dim: int=-1):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, xs):
-        return F.glu(xs, axis=self.dim)
-
-
-if not hasattr(paddle.nn, 'GLU'):
-    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'GLU', GLU)
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 30132775..3cb8729e 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -15,12 +15,13 @@ from collections import OrderedDict
 import paddle
 from paddle import nn
+from paddle.nn import functional as F
 
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
+__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)
 
 
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, dim: int=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, xs):
+        return F.glu(xs, axis=self.dim)
+
+
 class LinearGLUBlock(nn.Layer):
     """A linear Gated Linear Units (GLU) block."""
 
@@ -133,13 +145,18 @@ def get_activation(act):
     """Return activation function."""
     # Lazy load to avoid unused import
     activation_funcs = {
+        "hardshrink": paddle.nn.Hardshrink,
+        "hardswish": paddle.nn.Hardswish,
         "hardtanh": paddle.nn.Hardtanh,
         "tanh": paddle.nn.Tanh,
         "relu": paddle.nn.ReLU,
+        "relu6": paddle.nn.ReLU6,
+        "leakyrelu": paddle.nn.LeakyReLU,
         "selu": paddle.nn.SELU,
         "swish": paddle.nn.Swish,
         "gelu": paddle.nn.GELU,
-        "brelu": brelu,
+        "glu": GLU,
+        "elu": paddle.nn.ELU,
     }
     return activation_funcs[act]()
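For reference, a minimal usage sketch of the relocated GLU layer and the extended get_activation registry. This is a sketch only: it assumes a working PaddlePaddle install and the module path shown in the diff; the input tensor and its shape are illustrative.

import paddle

from deepspeech.modules.activation import GLU, get_activation

# F.glu splits the chosen axis in half and gates one half with the
# sigmoid of the other, so the last axis shrinks from 8 to 4 here.
x = paddle.randn([4, 8])  # illustrative shape
glu = GLU(dim=-1)
print(glu(x).shape)  # [4, 4]

# The same layer is now reachable by name through the registry.
act = get_activation("glu")
print(act(x).shape)  # [4, 4]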