|
|
@ -15,12 +15,13 @@ from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import nn
|
|
|
|
from paddle import nn
|
|
|
|
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
|
|
|
|
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
|
|
|
|
__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def brelu(x, t_min=0.0, t_max=24.0, name=None):
|
|
|
|
def brelu(x, t_min=0.0, t_max=24.0, name=None):
|
|
|
@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
|
|
|
|
return x.maximum(t_min).minimum(t_max)
|
|
|
|
return x.maximum(t_min).minimum(t_max)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GLU(nn.Layer):
|
|
|
|
|
|
|
|
"""Gated Linear Units (GLU) Layer"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim: int=-1):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.dim = dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, xs):
|
|
|
|
|
|
|
|
return F.glu(xs, axis=self.dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LinearGLUBlock(nn.Layer):
|
|
|
|
class LinearGLUBlock(nn.Layer):
|
|
|
|
"""A linear Gated Linear Units (GLU) block."""
|
|
|
|
"""A linear Gated Linear Units (GLU) block."""
|
|
|
|
|
|
|
|
|
|
|
@ -133,13 +145,18 @@ def get_activation(act):
|
|
|
|
"""Return activation function."""
|
|
|
|
"""Return activation function."""
|
|
|
|
# Lazy load to avoid unused import
|
|
|
|
# Lazy load to avoid unused import
|
|
|
|
activation_funcs = {
|
|
|
|
activation_funcs = {
|
|
|
|
|
|
|
|
"hardshrink": paddle.nn.Hardshrink,
|
|
|
|
|
|
|
|
"hardswish": paddle.nn.Hardswish,
|
|
|
|
"hardtanh": paddle.nn.Hardtanh,
|
|
|
|
"hardtanh": paddle.nn.Hardtanh,
|
|
|
|
"tanh": paddle.nn.Tanh,
|
|
|
|
"tanh": paddle.nn.Tanh,
|
|
|
|
"relu": paddle.nn.ReLU,
|
|
|
|
"relu": paddle.nn.ReLU,
|
|
|
|
|
|
|
|
"relu6": paddle.nn.ReLU6,
|
|
|
|
|
|
|
|
"leakyrelu": paddle.nn.LeakyReLU,
|
|
|
|
"selu": paddle.nn.SELU,
|
|
|
|
"selu": paddle.nn.SELU,
|
|
|
|
"swish": paddle.nn.Swish,
|
|
|
|
"swish": paddle.nn.Swish,
|
|
|
|
"gelu": paddle.nn.GELU,
|
|
|
|
"gelu": paddle.nn.GELU,
|
|
|
|
"brelu": brelu,
|
|
|
|
"glu": GLU,
|
|
|
|
|
|
|
|
"elu": paddle.nn.ELU,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return activation_funcs[act]()
|
|
|
|
return activation_funcs[act]()
|
|
|
|