|
|
@ -13,8 +13,7 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import nn
|
|
|
|
from paddle import nn
|
|
|
|
|
|
|
|
import math
|
|
|
|
from paddlespeech.s2t.modules.initializer import KaimingUniform
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
To align the initializer between paddle and torch,
|
|
|
|
To align the initializer between paddle and torch,
|
|
|
|
the API below are set defalut initializer with priority higger than global initializer.
|
|
|
|
the API below are set defalut initializer with priority higger than global initializer.
|
|
|
@ -82,10 +81,10 @@ class Linear(nn.Linear):
|
|
|
|
name=None):
|
|
|
|
name=None):
|
|
|
|
if weight_attr is None:
|
|
|
|
if weight_attr is None:
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
|
|
|
weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
|
|
|
|
if bias_attr is None:
|
|
|
|
if bias_attr is None:
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
|
|
|
bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
|
|
|
|
super(Linear, self).__init__(in_features, out_features, weight_attr,
|
|
|
|
super(Linear, self).__init__(in_features, out_features, weight_attr,
|
|
|
|
bias_attr, name)
|
|
|
|
bias_attr, name)
|
|
|
|
|
|
|
|
|
|
|
@ -105,10 +104,10 @@ class Conv1D(nn.Conv1D):
|
|
|
|
data_format='NCL'):
|
|
|
|
data_format='NCL'):
|
|
|
|
if weight_attr is None:
|
|
|
|
if weight_attr is None:
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
|
|
|
weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
|
|
|
|
if bias_attr is None:
|
|
|
|
if bias_attr is None:
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
|
|
|
bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
|
|
|
|
super(Conv1D, self).__init__(
|
|
|
|
super(Conv1D, self).__init__(
|
|
|
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
|
|
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
|
|
|
groups, padding_mode, weight_attr, bias_attr, data_format)
|
|
|
|
groups, padding_mode, weight_attr, bias_attr, data_format)
|
|
|
@ -129,10 +128,10 @@ class Conv2D(nn.Conv2D):
|
|
|
|
data_format='NCHW'):
|
|
|
|
data_format='NCHW'):
|
|
|
|
if weight_attr is None:
|
|
|
|
if weight_attr is None:
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
|
|
|
weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
|
|
|
|
if bias_attr is None:
|
|
|
|
if bias_attr is None:
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
if global_init_type == "kaiming_uniform":
|
|
|
|
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
|
|
|
bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
|
|
|
|
super(Conv2D, self).__init__(
|
|
|
|
super(Conv2D, self).__init__(
|
|
|
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
|
|
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
|
|
|
groups, padding_mode, weight_attr, bias_attr, data_format)
|
|
|
|
groups, padding_mode, weight_attr, bias_attr, data_format)
|
|
|
|