change default initializer to kaiming_uniform, test=asr

pull/1577/head
huangyuxin 2 years ago
parent f55fb384ee
commit ab16d8ce3c

@ -37,6 +37,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform'
###########################################
# Data #

@ -21,6 +21,7 @@ from paddle import nn
from paddle.fluid import core
from paddle.nn import functional as F
from paddlespeech.s2t.modules import initializer
from paddlespeech.s2t.utils.log import Log
#TODO(Hui Zhang): remove fluid import
@ -505,3 +506,8 @@ if not hasattr(paddle.nn, 'LayerDict'):
logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'LayerDict', LayerDict)
"""
hack KaiminigUniform: change limit from np.sqrt(6.0 / float(fan_in)) to np.sqrt(1.0 / float(fan_in))
"""
paddle.nn.initializer.KaimingUniform = initializer.KaimingUniform

@ -41,6 +41,7 @@ from paddlespeech.s2t.modules.mask import make_pad_mask
from paddlespeech.s2t.modules.mask import mask_finished_preds
from paddlespeech.s2t.modules.mask import mask_finished_scores
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.modules.nets_utils import initialize
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
@ -72,6 +73,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
nn.Layer.__init__(self)
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
@ -780,9 +782,14 @@ class U2DecodeModel(U2BaseModel):
class U2Model(U2DecodeModel):
def __init__(self, configs: dict):
model_conf = configs.get('model_conf', dict())
init_type = model_conf.get("init_type", None)
if init_type is not None:
logger.info(f"Use {init_type} initializer as default initializer")
initialize(self, init_type)
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
nn.initializer.set_global_initializer(None)
model_conf = configs.get('model_conf', dict())
super().__init__(
vocab_size=vocab_size,
encoder=encoder,

@ -95,7 +95,7 @@ class MultiHeadedAttention(nn.Layer):
mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
paddle.Tensor: Transformed value weighted
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
"""
n_batch = value.shape[0]

@ -60,8 +60,8 @@ class ConvolutionModule(nn.Layer):
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# `self.lorder` frames on the left in forward (causal conv impl).
# else: it's a symmetrical convolution
if causal:
@ -87,10 +87,20 @@ class ConvolutionModule(nn.Layer):
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1D(channels)
self.norm = nn.BatchNorm1D(
channels,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.norm = nn.LayerNorm(
channels,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.pointwise_conv2 = nn.Conv1D(
channels,

@ -76,19 +76,30 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
concat_after: bool=False, ):
assert check_argument_types()
nn.Layer.__init__(self)
self.selfattention_layer_type = 'selfattn'
attention_dim = encoder_output_size
if input_layer == "embed":
self.embed = nn.Sequential(
nn.Embedding(vocab_size, attention_dim),
nn.Embedding(
vocab_size,
attention_dim,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal())),
PositionalEncoding(attention_dim, positional_dropout_rate), )
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")
self.normalize_before = normalize_before
self.after_norm = nn.LayerNorm(attention_dim, epsilon=1e-12)
self.after_norm = nn.LayerNorm(
attention_dim,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.use_output_layer = use_output_layer
self.output_layer = nn.Linear(attention_dim, vocab_size)

@ -62,9 +62,27 @@ class DecoderLayer(nn.Layer):
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
self.norm3 = nn.LayerNorm(size, epsilon=1e-12)
self.norm1 = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.norm2 = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.norm3 = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after

@ -129,7 +129,13 @@ class BaseEncoder(nn.Layer):
d_model=output_size, dropout_rate=positional_dropout_rate), )
self.normalize_before = normalize_before
self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)
self.after_norm = nn.LayerNorm(
output_size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
@ -457,6 +463,7 @@ class ConformerEncoder(BaseEncoder):
cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']
"""
assert check_argument_types()
super().__init__(input_size, output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
attention_dropout_rate, input_layer,

@ -39,7 +39,7 @@ class TransformerEncoderLayer(nn.Layer):
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
@ -174,18 +174,46 @@ class ConformerEncoderLayer(nn.Layer):
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, epsilon=1e-12) # for the FNN module
self.norm_mha = nn.LayerNorm(size, epsilon=1e-12) # for the MHA module
self.norm_ff = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0))) # for the FNN module
self.norm_mha = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0))) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
self.norm_ff_macaron = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(
size, epsilon=1e-12) # for the CNN module
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(
0.0))) # for the CNN module
self.norm_final = nn.LayerNorm(
size, epsilon=1e-12) # for the final output of the block
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(
0.0))) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before

@ -0,0 +1,272 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import framework
from paddle.fluid.framework import in_dygraph_mode, default_main_program
import numpy as np
from paddle.fluid.core import VarDesc
from paddle.fluid import unique_name
__all__ = [
'MSRAInitializer'
]
class Initializer(object):
"""Base class for variable initializers
Defines the common interface of variable initializers.
They add operations to the init program that are used
to initialize variables. Users should not use this class
directly, but need to use one of its implementations.
"""
def __init__(self):
pass
def __call__(self, param, block=None):
"""Add corresponding initialization operations to the network
"""
raise NotImplementedError()
def _check_block(self, block):
if block is None:
block = default_main_program().global_block()
return block
def _compute_fans(self, var):
"""Compute the fan_in and the fan_out for layers
This method computes the fan_in and the fan_out
for neural network layers, if not specified. It is
not possible to perfectly estimate fan_in and fan_out.
This method will estimate it correctly for matrix multiply and
convolutions.
Args:
var: variable for which fan_in and fan_out have to be computed
Returns:
tuple of two integers (fan_in, fan_out)
"""
shape = var.shape
if not shape or len(shape) == 0:
fan_in = fan_out = 1
elif len(shape) == 1:
fan_in = fan_out = shape[0]
elif len(shape) == 2:
# This is the case for simple matrix multiply
fan_in = shape[0]
fan_out = shape[1]
else:
# Assume this to be a convolutional kernel
# In PaddlePaddle, the shape of the kernel is like:
# [num_filters, num_filter_channels, ...] where the remaining
# dimensions are the filter_size
receptive_field_size = np.prod(shape[2:])
fan_in = shape[1] * receptive_field_size
fan_out = shape[0] * receptive_field_size
return (fan_in, fan_out)
class MSRAInitializer(Initializer):
r"""Implements the MSRA initializer a.k.a. Kaiming Initializer
This class implements the weight initialization from the paper
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
robust initialization method that particularly considers the rectifier
nonlinearities. In case of Uniform distribution, the range is [-x, x], where
.. math::
x = \sqrt{\\frac{6.0}{fan\_in}}
In case of Normal distribution, the mean is 0 and the standard deviation
is
.. math::
\sqrt{\\frac{2.0}{fan\_in}}
Args:
uniform (bool): whether to use uniform or normal distribution
fan_in (float32|None): fan_in for MSRAInitializer. If None, it is\
inferred from the variable. default is None.
seed (int32): random seed
Note:
It is recommended to set fan_in to None for most cases.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
x = fluid.data(name="data", shape=[8, 32, 32], dtype="float32")
fc = fluid.layers.fc(input=x, size=10,
param_attr=fluid.initializer.MSRA(uniform=False))
"""
def __init__(self, uniform=True, fan_in=None, seed=0):
"""Constructor for MSRAInitializer
"""
assert uniform is not None
assert seed is not None
super(MSRAInitializer, self).__init__()
self._uniform = uniform
self._fan_in = fan_in
self._seed = seed
def __call__(self, var, block=None):
"""Initialize the input tensor with MSRA initialization.
Args:
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
The initialization op
"""
block = self._check_block(block)
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
f_in, f_out = self._compute_fans(var)
# If fan_in is passed, use it
fan_in = f_in if self._fan_in is None else self._fan_in
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['masra_init', var.name, 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
if self._uniform:
limit = np.sqrt(1.0 / float(fan_in))
op = block.append_op(
type="uniform_random",
inputs={},
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
"dtype": int(out_dtype),
"min": -limit,
"max": limit,
"seed": self._seed
},
stop_gradient=True)
else:
std = np.sqrt(2.0 / float(fan_in))
op = block.append_op(
type="gaussian_random",
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
"dtype": int(out_dtype),
"mean": 0.0,
"std": std,
"seed": self._seed
},
stop_gradient=True)
if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
if not framework.in_dygraph_mode():
var.op = op
return op
class KaimingUniform(MSRAInitializer):
r"""Implements the Kaiming Uniform initializer
This class implements the weight initialization from the paper
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
robust initialization method that particularly considers the rectifier
nonlinearities.
In case of Uniform distribution, the range is [-x, x], where
.. math::
x = \sqrt{\frac{6.0}{fan\_in}}
Args:
fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\
inferred from the variable. default is None.
Note:
It is recommended to set fan_in to None for most cases.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
linear = nn.Linear(2,
4,
weight_attr=nn.initializer.KaimingUniform())
data = paddle.rand([30, 10, 2], dtype='float32')
res = linear(data)
"""
def __init__(self, fan_in=None):
super(KaimingUniform, self).__init__(
uniform=True, fan_in=fan_in, seed=0)
# We short the class name, since users will use the initializer with the package
# name. The sample code:
#
# import paddle.fluid as fluid
#
# hidden = fluid.layers.fc(...,
# param_attr=ParamAttr(fluid.initializer.Xavier()))
#
# It is no need to add an `Initializer` as the class suffix
MSRA = MSRAInitializer

@ -0,0 +1,44 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from paddle import nn
from typeguard import check_argument_types
def initialize(model: nn.Layer, init: str):
"""Initialize weights of a neural network module.
Parameters are initialized using the given method or distribution.
Custom initialization routines can be implemented into submodules
Args:
model (nn.Layer): Target.
init (str): Method of initialization.
"""
assert check_argument_types()
if init == "xavier_uniform":
nn.initializer.set_global_initializer(nn.initializer.XavierUniform(),
nn.initializer.Constant())
elif init == "xavier_normal":
nn.initializer.set_global_initializer(nn.initializer.XavierNormal(),
nn.initializer.Constant())
elif init == "kaiming_uniform":
nn.initializer.set_global_initializer(nn.initializer.KaimingUniform(),
nn.initializer.KaimingUniform())
elif init == "kaiming_normal":
nn.initializer.set_global_initializer(nn.initializer.KaimingNormal(),
nn.initializer.Constant())
else:
raise ValueError("Unknown initialization: " + init)
Loading…
Cancel
Save