fixed multi-gpu training

pull/2324/head
tianhao zhang 3 years ago
parent ed16f96a9c
commit e6b23ae0c5
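
The hunks below remove tensor-valued default arguments (paddle.zeros(...), paddle.ones(...), paddle.empty(...)) from the streaming signatures and instead pass the placeholder tensors explicitly at every call site. A plausible rationale, consistent with the commit title: Python evaluates a default argument once, when the def statement runs, so a tensor default is created once per process at import time, on whatever device is current at that moment rather than the device each training rank later selects. A minimal standalone sketch of that semantics (not this repo's code):

import paddle

# A default argument is evaluated once, at function-definition time.
def f(cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])):
    return cache

# Every call returns the very same tensor object, created at import time
# on whatever device was active then -- not a fresh per-call placeholder.
print(f() is f())  # True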

@@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            att_cache: paddle.Tensor,
+            cnn_cache: paddle.Tensor,
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
         output from time 0 to current chunk.
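
With the defaults gone, callers of this exported interface build the empty caches themselves before the first chunk. A hedged sketch of that pattern (the method name model.forward_encoder_chunk and the surrounding variables are illustrative assumptions, not confirmed by this diff):

import paddle

att_cache = paddle.zeros([0, 0, 0, 0])  # no attention cache before chunk 0
cnn_cache = paddle.zeros([0, 0, 0, 0])  # no convolution cache before chunk 0
# ys, att_cache, cnn_cache = model.forward_encoder_chunk(
#     xs, offset, required_cache_size, att_cache, cnn_cache)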

@@ -86,7 +86,7 @@ class MultiHeadedAttention(nn.Layer):
     def forward_attention(self,
                           value: paddle.Tensor,
                           scores: paddle.Tensor,
-                          mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                          mask: paddle.Tensor,
                           ) -> paddle.Tensor:
         """Compute attention context vector.
         Args:
@@ -131,9 +131,9 @@ class MultiHeadedAttention(nn.Layer):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
-                pos_emb: paddle.Tensor = paddle.empty([0]),
-                cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+                mask: paddle.Tensor,
+                pos_emb: paddle.Tensor,
+                cache: paddle.Tensor
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute scaled dot product attention.
         Args:
@@ -247,9 +247,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 query: paddle.Tensor,
                 key: paddle.Tensor,
                 value: paddle.Tensor,
-                mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
-                pos_emb: paddle.Tensor = paddle.empty([0]),
-                cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+                mask: paddle.Tensor,
+                pos_emb: paddle.Tensor,
+                cache: paddle.Tensor
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:

@@ -108,8 +108,8 @@ class ConvolutionModule(nn.Layer):
     def forward(self,
                 x: paddle.Tensor,
-                mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
-                cache: paddle.Tensor= paddle.zeros([0,0,0]),
+                mask_pad: paddle.Tensor,
+                cache: paddle.Tensor
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute convolution module.
         Args:

@@ -121,11 +121,13 @@ class DecoderLayer(nn.Layer):
         if self.concat_after:
             tgt_concat = paddle.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
+                                       paddle.empty([0]), paddle.zeros([0,0,0,0]))[0]), dim=-1)
             x = residual + self.concat_linear1(tgt_concat)
         else:
             x = residual + self.dropout(
-                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
+                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
+                               paddle.empty([0]), paddle.zeros([0,0,0,0]))[0])
         if not self.normalize_before:
             x = self.norm1(x)
@@ -134,11 +136,13 @@ class DecoderLayer(nn.Layer):
             x = self.norm2(x)
         if self.concat_after:
             x_concat = paddle.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
+                (x, self.src_attn(x, memory, memory, memory_mask,
+                                  paddle.empty([0]), paddle.zeros([0,0,0,0]))[0]), dim=-1)
             x = residual + self.concat_linear2(x_concat)
         else:
             x = residual + self.dropout(
-                self.src_attn(x, memory, memory, memory_mask)[0])
+                self.src_attn(x, memory, memory, memory_mask,
+                              paddle.empty([0]), paddle.zeros([0,0,0,0]))[0])
         if not self.normalize_before:
             x = self.norm2(x)
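
The decoder call sites above now spell out the two trailing arguments the old defaults used to supply: a zero-size paddle.empty([0]) standing in for "no positional encoding" and a zero-size paddle.zeros([0, 0, 0, 0]) for "no cache", passed positionally. A sketch of the convention in isolation (names are illustrative):

import paddle

pos_emb = paddle.empty([0])          # placeholder: no relative positions
cache = paddle.zeros([0, 0, 0, 0])   # placeholder: no cached key/value
# out = attn(query, key, value, mask, pos_emb, cache)[0]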

@@ -177,7 +177,8 @@ class BaseEncoder(nn.Layer):
             decoding_chunk_size, self.static_chunk_size,
             num_decoding_left_chunks)
         for layer in self.encoders:
-            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
+                                          paddle.zeros([0, 0, 0, 0]), paddle.zeros([0, 0, 0, 0]))
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
@@ -190,9 +191,9 @@ class BaseEncoder(nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
-            cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
-            att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
+            att_cache: paddle.Tensor,
+            cnn_cache: paddle.Tensor,
+            att_mask: paddle.Tensor,
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Forward just one chunk
         Args:
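
forward_chunk gets the same treatment for its boolean mask: the "absent" value is now an explicit zero-size paddle.ones([0, 0, 0], dtype=paddle.bool) built by the caller, which downstream code can recognize by its emptiness rather than its contents. A small sketch, assuming that convention:

import paddle

att_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)  # "no mask supplied"
print(att_mask.shape[0] == 0)  # True: callee can treat this as mask-less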

@@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
                 x: paddle.Tensor,
                 mask: paddle.Tensor,
                 pos_emb: paddle.Tensor,
-                mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
-                att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-                cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+                mask_pad: paddle.Tensor,
+                att_cache: paddle.Tensor,
+                cnn_cache: paddle.Tensor,
                 ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
@@ -105,7 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
         if self.normalize_before:
             x = self.norm1(x)
-        x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
+        x_att, new_att_cache = self.self_attn(x, x, x, mask, paddle.empty([0]), cache=att_cache)
         if self.concat_after:
             x_concat = paddle.concat((x, x_att), axis=-1)
@@ -193,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
                 x: paddle.Tensor,
                 mask: paddle.Tensor,
                 pos_emb: paddle.Tensor,
-                mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
-                att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-                cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+                mask_pad: paddle.Tensor,
+                att_cache: paddle.Tensor,
+                cnn_cache: paddle.Tensor,
                 ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
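
On the callee side, zero-size placeholders only work if each function tests for them. A hedged sketch of the wenet-style check such code typically performs (the split layout below is illustrative, not this repo's exact implementation):

import paddle

def apply_cache(key: paddle.Tensor, value: paddle.Tensor,
                cache: paddle.Tensor):
    # A real cache has a non-zero leading dimension; the placeholder
    # paddle.zeros([0, 0, 0, 0]) does not.
    if cache.shape[0] > 0:
        key_cache, value_cache = paddle.split(cache, 2, axis=-1)
        key = paddle.concat([key_cache, key], axis=2)      # prepend along time
        value = paddle.concat([value_cache, value], axis=2)
    return key, value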
