fix activation

pull/814/head
Hui Zhang 3 years ago
parent 7e136d0893
commit 244132c1c4

@@ -351,20 +351,3 @@ if not hasattr(paddle.Tensor, 'tolist'):
     logger.warn(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
-
-########### hcak paddle.nn #############
-
-class GLU(nn.Layer):
-    """Gated Linear Units (GLU) Layer"""
-
-    def __init__(self, dim: int=-1):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, xs):
-        return F.glu(xs, axis=self.dim)
-
-
-if not hasattr(paddle.nn, 'GLU'):
-    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'GLU', GLU)
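The hunk above deletes the temporary monkey-patch that registered a GLU layer onto paddle.nn; the hunks below move the same class into the activation module proper. As a reminder of what the layer computes, here is a minimal sketch of paddle's F.glu gating (the example tensors and shapes are illustrative, not part of the commit):

    import paddle
    import paddle.nn.functional as F

    # F.glu splits the input in half along `axis` and gates one half
    # with the sigmoid of the other: glu(x) = a * sigmoid(b).
    x = paddle.randn([4, 8])             # gated axis must have even size
    out = F.glu(x, axis=-1)              # shape: [4, 4]

    # Equivalent manual computation:
    a, b = paddle.split(x, 2, axis=-1)
    ref = a * F.sigmoid(b)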

@@ -15,12 +15,13 @@ from collections import OrderedDict
 import paddle
 from paddle import nn
+from paddle.nn import functional as F
 
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
+__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)
 
+
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, dim: int=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, xs):
+        return F.glu(xs, axis=self.dim)
+
 class LinearGLUBlock(nn.Layer):
     """A linear Gated Linear Units (GLU) block."""
@@ -133,13 +145,18 @@ def get_activation(act):
     """Return activation function."""
     # Lazy load to avoid unused import
     activation_funcs = {
+        "hardshrink": paddle.nn.Hardshrink,
+        "hardswish": paddle.nn.Hardswish,
         "hardtanh": paddle.nn.Hardtanh,
         "tanh": paddle.nn.Tanh,
         "relu": paddle.nn.ReLU,
+        "relu6": paddle.nn.ReLU6,
+        "leakyrelu": paddle.nn.LeakyReLU,
         "selu": paddle.nn.SELU,
         "swish": paddle.nn.Swish,
         "gelu": paddle.nn.GELU,
-        "brelu": brelu,
+        "glu": GLU,
+        "elu": paddle.nn.ELU,
     }
 
     return activation_funcs[act]()
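After this hunk the factory resolves the newly registered names as well; since it instantiates with no arguments, get_activation("glu") returns a fresh GLU() with the default dim=-1. A brief usage sketch, again assuming the module path:

    from deepspeech.modules.activation import get_activation  # path assumed

    act = get_activation("glu")          # -> GLU(dim=-1)
    act = get_activation("hardswish")    # -> paddle.nn.Hardswish()

Note that "brelu" is dropped from the lookup table even though the brelu function itself stays defined and exported, so get_activation("brelu") now raises a KeyError.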
