Merge pull request #2336 from Zth9730/fix_multigpu_train

[s2t] fix format test=asr
pull/2347/head
贾晓 2 years ago committed by GitHub
commit 0b544ee84e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -255,6 +255,7 @@ class BaseEncoder(nn.Layer):
xs,
att_mask,
pos_emb,
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i + 1]
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )

@ -195,8 +195,7 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:

@ -476,8 +476,12 @@ class PaddleASRConnectionHanddler:
# forward chunk
(y, self.att_cache,
self.cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size, self.att_cache,
self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool))
chunk_xs,
self.offset,
required_cache_size,
att_cache=self.att_cache,
cnn_cache=self.cnn_cache,
att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y)
# update the global offset, in decoding frame unit

Loading…
Cancel
Save