fix multigpu train

pull/2324/head
tianhao zhang 3 years ago
parent e6b23ae0c5
commit 80fc0ef71a

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
+
 import paddle
 from paddle import nn
-import math
 """
 To align the initializer between paddle and torch,
 the API below are set defalut initializer with priority higger than global initializer.
@@ -81,10 +82,18 @@ class Linear(nn.Linear):
                  name=None):
         if weight_attr is None:
             if global_init_type == "kaiming_uniform":
-                weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+                weight_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        fan_in=None,
+                        negative_slope=math.sqrt(5),
+                        nonlinearity='leaky_relu'))
         if bias_attr is None:
             if global_init_type == "kaiming_uniform":
-                bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+                bias_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        fan_in=None,
+                        negative_slope=math.sqrt(5),
+                        nonlinearity='leaky_relu'))
         super(Linear, self).__init__(in_features, out_features, weight_attr,
                                      bias_attr, name)
@@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D):
                  data_format='NCL'):
         if weight_attr is None:
             if global_init_type == "kaiming_uniform":
-                weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+                weight_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        fan_in=None,
+                        negative_slope=math.sqrt(5),
+                        nonlinearity='leaky_relu'))
         if bias_attr is None:
             if global_init_type == "kaiming_uniform":
-                bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+                bias_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        fan_in=None,
+                        negative_slope=math.sqrt(5),
+                        nonlinearity='leaky_relu'))
         super(Conv1D, self).__init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             groups, padding_mode, weight_attr, bias_attr, data_format)
@@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D):
                  data_format='NCHW'):
         if weight_attr is None:
             if global_init_type == "kaiming_uniform":
-                weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+                weight_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        fan_in=None,
+                        negative_slope=math.sqrt(5),
+                        nonlinearity='leaky_relu'))
         if bias_attr is None:
             if global_init_type == "kaiming_uniform":
-                bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+                bias_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        fan_in=None,
+                        negative_slope=math.sqrt(5),
+                        nonlinearity='leaky_relu'))
         super(Conv2D, self).__init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             groups, padding_mode, weight_attr, bias_attr, data_format)
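
For readers skimming the hunks above: the wrapped Linear/Conv1D/Conv2D classes attach a Kaiming-uniform initializer at the ParamAttr level so that, whatever global initializer is configured, these layers reproduce torch's default kaiming_uniform_(a=sqrt(5)) weight init. A minimal standalone sketch of the same pattern on a plain paddle.nn.Linear (the layer sizes are arbitrary, not from this commit):

import math

import paddle
from paddle import nn

# Kaiming-uniform init matching torch.nn.Linear's default weight init
# (kaiming_uniform_ with a=sqrt(5), leaky_relu, fan-in mode).
kaiming = nn.initializer.KaimingUniform(
    fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')

# A per-parameter ParamAttr initializer takes priority over the global one;
# the patched classes apply it to both weight and bias.
linear = nn.Linear(
    256,
    256,
    weight_attr=paddle.ParamAttr(initializer=kaiming),
    bias_attr=paddle.ParamAttr(initializer=kaiming))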

@@ -83,11 +83,11 @@ class MultiHeadedAttention(nn.Layer):
         return q, k, v
 
-    def forward_attention(self,
-                          value: paddle.Tensor,
-                          scores: paddle.Tensor,
-                          mask: paddle.Tensor,
-                          ) -> paddle.Tensor:
+    def forward_attention(
+            self,
+            value: paddle.Tensor,
+            scores: paddle.Tensor,
+            mask: paddle.Tensor, ) -> paddle.Tensor:
         """Compute attention context vector.
         Args:
             value (paddle.Tensor): Transformed value, size
@ -133,8 +133,7 @@ class MultiHeadedAttention(nn.Layer):
value: paddle.Tensor, value: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
cache: paddle.Tensor cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention. """Compute scaled dot product attention.
Args: Args:
query (paddle.Tensor): Query tensor (#batch, time1, size). query (paddle.Tensor): Query tensor (#batch, time1, size).
@@ -249,8 +248,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 value: paddle.Tensor,
                 mask: paddle.Tensor,
                 pos_emb: paddle.Tensor,
-                cache: paddle.Tensor
-                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+                cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).

@@ -109,8 +109,7 @@ class ConvolutionModule(nn.Layer):
     def forward(self,
                 x: paddle.Tensor,
                 mask_pad: paddle.Tensor,
-                cache: paddle.Tensor
-                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+                cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute convolution module.
         Args:
             x (paddle.Tensor): Input tensor (#batch, time, channels).

@@ -122,12 +122,15 @@ class DecoderLayer(nn.Layer):
         if self.concat_after:
             tgt_concat = paddle.cat(
                 (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
-                                       paddle.empty([0]), paddle.zeros([0,0,0,0]))[0]), dim=-1)
+                                       paddle.empty([0]),
+                                       paddle.zeros([0, 0, 0, 0]))[0]),
+                dim=-1)
             x = residual + self.concat_linear1(tgt_concat)
         else:
             x = residual + self.dropout(
                 self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
-                               paddle.empty([0]), paddle.zeros([0,0,0,0]))[0])
+                               paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
+                                   0])
 
         if not self.normalize_before:
             x = self.norm1(x)
@@ -137,12 +140,14 @@ class DecoderLayer(nn.Layer):
         if self.concat_after:
             x_concat = paddle.cat(
                 (x, self.src_attn(x, memory, memory, memory_mask,
-                                  paddle.empty([0]), paddle.zeros([0,0,0,0]))[0]), dim=-1)
+                                  paddle.empty([0]),
+                                  paddle.zeros([0, 0, 0, 0]))[0]),
+                dim=-1)
             x = residual + self.concat_linear2(x_concat)
         else:
             x = residual + self.dropout(
                 self.src_attn(x, memory, memory, memory_mask,
-                              paddle.empty([0]), paddle.zeros([0,0,0,0]))[0])
+                              paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
 
         if not self.normalize_before:
             x = self.norm2(x)
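
A note on the placeholder arguments in the DecoderLayer hunks above: the attention modules are called with paddle.empty([0]) where no positional embedding is used and paddle.zeros([0, 0, 0, 0]) where no cache exists, and callees appear to detect the empty cache by its leading dimension (see the paddle.shape(cnn_cache)[0] > 0 check in the BaseEncoder hunk below). A tiny standalone sketch of that convention:

import paddle

# Zero-element stand-ins: "no positional embedding" / "no cache yet".
pos_emb = paddle.empty([0])
att_cache = paddle.zeros([0, 0, 0, 0])

# An empty cache is recognized by its zero-sized leading dimension;
# here this is a boolean tensor holding False.
has_cache = paddle.shape(att_cache)[0] > 0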

@@ -178,7 +178,8 @@ class BaseEncoder(nn.Layer):
                                               num_decoding_left_chunks)
         for layer in self.encoders:
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
-                                          paddle.zeros([0, 0, 0, 0]), paddle.zeros([0, 0, 0, 0]))
+                                          paddle.zeros([0, 0, 0, 0]),
+                                          paddle.zeros([0, 0, 0, 0]))
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
@@ -253,13 +254,15 @@ class BaseEncoder(nn.Layer):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
             # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
             xs, _, new_att_cache, new_cnn_cache = layer(
-                xs, att_mask, pos_emb,
-                att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
-                cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
-            )
+                xs,
+                att_mask,
+                pos_emb,
+                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+                cnn_cache=cnn_cache[i:i + 1]
+                if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
-            r_att_cache.append(new_att_cache[:,:, next_cache_start:, :])
+            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
             r_cnn_cache.append(new_cnn_cache.unsqueeze(0))  # add elayer dim
 
         if self.normalize_before:
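
The shape comments in this hunk also explain the att_cache[i:i + 1] slicing: slicing with i:i + 1, rather than indexing with i, keeps a leading layer dimension of size 1, which is the per-layer cache shape each encoder layer expects. A small sketch with illustrative sizes (elayers=2, head=4, cache_t1=16, d_k=64 are made up for the example):

import paddle

elayers, head, cache_t1, d_k = 2, 4, 16, 64  # illustrative sizes only

# Stacked attention cache across layers, per the shape comments above.
att_cache = paddle.zeros([elayers, head, cache_t1, d_k * 2])

# att_cache[i:i + 1] keeps the leading dim: shape [1, head, cache_t1, d_k*2].
layer_cache = att_cache[0:1]
print(layer_cache.shape)  # [1, 4, 16, 128]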
@@ -271,7 +274,6 @@ class BaseEncoder(nn.Layer):
         r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
         return xs, r_att_cache, r_cnn_cache
 
-
     def forward_chunk_by_chunk(
             self,
             xs: paddle.Tensor,
@@ -316,8 +318,8 @@ class BaseEncoder(nn.Layer):
         num_frames = xs.shape[1]
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
-        cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+        att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+        cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
         outputs = []
         offset = 0

@@ -105,7 +105,8 @@ class TransformerEncoderLayer(nn.Layer):
         if self.normalize_before:
             x = self.norm1(x)
 
-        x_att, new_att_cache = self.self_attn(x, x, x, mask, paddle.empty([0]), cache=att_cache)
+        x_att, new_att_cache = self.self_attn(
+            x, x, x, mask, paddle.empty([0]), cache=att_cache)
 
         if self.concat_after:
             x_concat = paddle.concat((x, x_att), axis=-1)

@@ -13,6 +13,7 @@
 # limitations under the License.
 import numpy as np
 
+
 class DefaultInitializerContext(object):
     """
     egs: