Merge pull request #2372 from Zth9730/fix_dp_init

[s2t] DataParallel init method changed, fixed conformer could not multi-gpu training and don't affect dy2st
pull/2390/head
Hui Zhang 2 years ago committed by GitHub
commit 07f566e0a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.

@ -86,7 +86,7 @@ class MultiHeadedAttention(nn.Layer):
self,
value: paddle.Tensor,
scores: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)
) -> paddle.Tensor:
"""Compute attention context vector.
Args:
@ -127,15 +127,14 @@ class MultiHeadedAttention(nn.Layer):
return self.linear_out(x) # (batch, time1, d_model)
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
pos_emb: paddle.Tensor, # paddle.empty([0])
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
@ -244,15 +243,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
return x
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
pos_emb: paddle.Tensor, # paddle.empty([0])
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).

@ -108,8 +108,8 @@ class ConvolutionModule(nn.Layer):
def forward(
self,
x: paddle.Tensor,
mask_pad: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:

@ -121,16 +121,11 @@ class DecoderLayer(nn.Layer):
if self.concat_after:
tgt_concat = paddle.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
0])
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
if not self.normalize_before:
x = self.norm1(x)
@ -139,15 +134,11 @@ class DecoderLayer(nn.Layer):
x = self.norm2(x)
if self.concat_after:
x_concat = paddle.cat(
(x, self.src_attn(x, memory, memory, memory_mask,
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
(x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask,
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
self.src_attn(x, memory, memory, memory_mask)[0])
if not self.normalize_before:
x = self.norm2(x)

@ -175,9 +175,7 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
paddle.zeros([0, 0, 0, 0]),
paddle.zeros([0, 0, 0, 0]))
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
@ -190,9 +188,9 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]),
att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:
@ -255,7 +253,6 @@ class BaseEncoder(nn.Layer):
xs,
att_mask,
pos_emb,
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i + 1]
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
@ -328,8 +325,7 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, att_cache, cnn_cache,
paddle.ones([0, 0, 0], dtype=paddle.bool))
chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
outputs.append(y)
offset += y.shape[1]

@ -76,10 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
@ -106,8 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, paddle.empty([0]), cache=att_cache)
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:

@ -19,6 +19,10 @@ from pathlib import Path
import paddle
from paddle import distributed as dist
world_size = dist.get_world_size()
if world_size > 1:
dist.init_parallel_env()
from visualdl import LogWriter
from paddlespeech.s2t.training.reporter import ObsScope
@ -122,9 +126,6 @@ class Trainer():
else:
raise Exception("invalid device")
if self.parallel:
self.init_parallel()
self.checkpoint = Checkpoint(
kbest_n=self.config.checkpoint.kbest_n,
latest_n=self.config.checkpoint.latest_n)
@ -173,11 +174,6 @@ class Trainer():
"""
return self.args.ngpu > 1
def init_parallel(self):
"""Init environment for multiprocess training.
"""
dist.init_parallel_env()
@mp_tools.rank_zero_only
def save(self, tag=None, infos: dict=None):
"""Save checkpoint (model parameters and optimizer states).

@ -480,8 +480,7 @@ class PaddleASRConnectionHanddler:
self.offset,
required_cache_size,
att_cache=self.att_cache,
cnn_cache=self.cnn_cache,
att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool))
cnn_cache=self.cnn_cache)
outputs.append(y)
# update the global offset, in decoding frame unit

Loading…
Cancel
Save