fix multigpu train

pull/2324/head
tianhao zhang 3 years ago
parent e6b23ae0c5
commit 80fc0ef71a

@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
import paddle
from paddle import nn
-import math
"""
To align the initializer between paddle and torch,
the APIs below set a default initializer with higher priority than the global initializer.
@@ -81,10 +82,18 @@ class Linear(nn.Linear):
name=None):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
-weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+weight_attr = paddle.ParamAttr(
+initializer=nn.initializer.KaimingUniform(
+fan_in=None,
+negative_slope=math.sqrt(5),
+nonlinearity='leaky_relu'))
if bias_attr is None:
if global_init_type == "kaiming_uniform":
-bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+bias_attr = paddle.ParamAttr(
+initializer=nn.initializer.KaimingUniform(
+fan_in=None,
+negative_slope=math.sqrt(5),
+nonlinearity='leaky_relu'))
super(Linear, self).__init__(in_features, out_features, weight_attr,
bias_attr, name)
@@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D):
data_format='NCL'):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
-weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+weight_attr = paddle.ParamAttr(
+initializer=nn.initializer.KaimingUniform(
+fan_in=None,
+negative_slope=math.sqrt(5),
+nonlinearity='leaky_relu'))
if bias_attr is None:
if global_init_type == "kaiming_uniform":
-bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+bias_attr = paddle.ParamAttr(
+initializer=nn.initializer.KaimingUniform(
+fan_in=None,
+negative_slope=math.sqrt(5),
+nonlinearity='leaky_relu'))
super(Conv1D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)
@@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D):
data_format='NCHW'):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
-weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+weight_attr = paddle.ParamAttr(
+initializer=nn.initializer.KaimingUniform(
+fan_in=None,
+negative_slope=math.sqrt(5),
+nonlinearity='leaky_relu'))
if bias_attr is None:
if global_init_type == "kaiming_uniform":
-bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+bias_attr = paddle.ParamAttr(
+initializer=nn.initializer.KaimingUniform(
+fan_in=None,
+negative_slope=math.sqrt(5),
+nonlinearity='leaky_relu'))
super(Conv2D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)
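Note (not part of the diff): the three wrappers above only install a default initializer when the caller passes no ParamAttr and global_init_type == "kaiming_uniform". A minimal sketch of the resulting default, assuming Paddle 2.x and an arbitrary 256 -> 512 layer size:

```python
import math

import paddle
from paddle import nn

# Default installed by the wrappers when global_init_type == "kaiming_uniform":
# fan_in is inferred from the weight shape; negative_slope=sqrt(5) with
# 'leaky_relu' mirrors torch's kaiming_uniform_(a=math.sqrt(5)) weight default.
weight_attr = paddle.ParamAttr(
    initializer=nn.initializer.KaimingUniform(
        fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
bias_attr = paddle.ParamAttr(
    initializer=nn.initializer.KaimingUniform(
        fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))

linear = nn.Linear(256, 512, weight_attr=weight_attr, bias_attr=bias_attr)
print(linear.weight.shape, linear.bias.shape)  # [256, 512] [512]
```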

@@ -83,11 +83,11 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v
-def forward_attention(self,
-value: paddle.Tensor,
+def forward_attention(
+self,
+value: paddle.Tensor,
scores: paddle.Tensor,
-mask: paddle.Tensor,
-) -> paddle.Tensor:
+mask: paddle.Tensor, ) -> paddle.Tensor:
"""Compute attention context vector.
Args:
value (paddle.Tensor): Transformed value, size
@@ -108,7 +108,7 @@ class MultiHeadedAttention(nn.Layer):
# When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
-if paddle.shape(mask)[2] > 0: # time2 > 0
+if paddle.shape(mask)[2] > 0:  # time2 > 0
mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2)
# for last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :paddle.shape(scores)[-1]]
@@ -133,8 +133,7 @@ class MultiHeadedAttention(nn.Layer):
value: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
-cache: paddle.Tensor
-) -> Tuple[paddle.Tensor, paddle.Tensor]:
+cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
@@ -249,8 +248,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
value: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
-cache: paddle.Tensor
-) -> Tuple[paddle.Tensor, paddle.Tensor]:
+cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).

@@ -109,8 +109,7 @@ class ConvolutionModule(nn.Layer):
def forward(self,
x: paddle.Tensor,
mask_pad: paddle.Tensor,
-cache: paddle.Tensor
-) -> Tuple[paddle.Tensor, paddle.Tensor]:
+cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
@@ -127,11 +126,11 @@ class ConvolutionModule(nn.Layer):
x = x.transpose([0, 2, 1]) # [B, C, T]
# mask batch padding
-if paddle.shape(mask_pad)[2] > 0: # time > 0
+if paddle.shape(mask_pad)[2] > 0:  # time > 0
x = x.masked_fill(mask_pad, 0.0)
if self.lorder > 0:
-if paddle.shape(cache)[2] == 0: # cache_t == 0
+if paddle.shape(cache)[2] == 0:  # cache_t == 0
x = nn.functional.pad(
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else:
@@ -161,7 +160,7 @@ class ConvolutionModule(nn.Layer):
x = self.pointwise_conv2(x)
# mask batch padding
-if paddle.shape(mask_pad)[2] > 0: # time > 0
+if paddle.shape(mask_pad)[2] > 0:  # time > 0
x = x.masked_fill(mask_pad, 0.0)
x = x.transpose([0, 2, 1]) # [B, T, C]

@@ -122,12 +122,15 @@ class DecoderLayer(nn.Layer):
if self.concat_after:
tgt_concat = paddle.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
-paddle.empty([0]), paddle.zeros([0,0,0,0]))[0]), dim=-1)
+paddle.empty([0]),
+paddle.zeros([0, 0, 0, 0]))[0]),
+dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
-paddle.empty([0]), paddle.zeros([0,0,0,0]))[0])
+paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
+0])
if not self.normalize_before:
x = self.norm1(x)
@@ -137,12 +140,14 @@ class DecoderLayer(nn.Layer):
if self.concat_after:
x_concat = paddle.cat(
(x, self.src_attn(x, memory, memory, memory_mask,
-paddle.empty([0]), paddle.zeros([0,0,0,0]))[0]), dim=-1)
+paddle.empty([0]),
+paddle.zeros([0, 0, 0, 0]))[0]),
+dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask,
-paddle.empty([0]), paddle.zeros([0,0,0,0]))[0])
+paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
if not self.normalize_before:
x = self.norm2(x)

@@ -177,8 +177,9 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
-xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
-paddle.zeros([0, 0, 0, 0]), paddle.zeros([0, 0, 0, 0]))
+xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
+paddle.zeros([0, 0, 0, 0]),
+paddle.zeros([0, 0, 0, 0]))
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
@@ -228,7 +229,7 @@ class BaseEncoder(nn.Layer):
xs = self.global_cmvn(xs)
# before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
-xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim)
elayers = paddle.shape(att_cache)[0]
@@ -253,14 +254,16 @@ class BaseEncoder(nn.Layer):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
-xs, att_mask, pos_emb,
-att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
-cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
-)
+xs,
+att_mask,
+pos_emb,
+att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+cnn_cache=cnn_cache[i:i + 1]
+if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
-r_att_cache.append(new_att_cache[:,:, next_cache_start:, :])
-r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
+r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+r_cnn_cache.append(new_cnn_cache.unsqueeze(0))  # add elayer dim
if self.normalize_before:
xs = self.after_norm(xs)
@@ -271,7 +274,6 @@ class BaseEncoder(nn.Layer):
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk(
self,
xs: paddle.Tensor,
@@ -316,8 +318,8 @@ class BaseEncoder(nn.Layer):
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
-cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
outputs = []
offset = 0
@@ -327,7 +329,7 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk(
-chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
+chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
outputs.append(y)
offset += y.shape[1]

@@ -105,7 +105,8 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before:
x = self.norm1(x)
-x_att, new_att_cache = self.self_attn(x, x, x, mask, paddle.empty([0]), cache=att_cache)
+x_att, new_att_cache = self.self_attn(
+x, x, x, mask, paddle.empty([0]), cache=att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)

@@ -13,6 +13,7 @@
# limitations under the License.
+import numpy as np
class DefaultInitializerContext(object):
"""
egs:
