|
|
@ -255,6 +255,7 @@ class BaseEncoder(nn.Layer):
|
|
|
|
xs,
|
|
|
|
xs,
|
|
|
|
att_mask,
|
|
|
|
att_mask,
|
|
|
|
pos_emb,
|
|
|
|
pos_emb,
|
|
|
|
|
|
|
|
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool),
|
|
|
|
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
|
|
|
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
|
|
|
cnn_cache=cnn_cache[i:i + 1]
|
|
|
|
cnn_cache=cnn_cache[i:i + 1]
|
|
|
|
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
|
|
|
|
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
|
|
|
|