|
|
|
@ -114,10 +114,7 @@ class DecoderLayer(nn.Layer):
|
|
|
|
|
], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}"
|
|
|
|
|
tgt_q = tgt[:, -1:, :]
|
|
|
|
|
residual = residual[:, -1:, :]
|
|
|
|
|
# TODO(Hui Zhang): slice not support bool type
|
|
|
|
|
# tgt_q_mask = tgt_mask[:, -1:, :]
|
|
|
|
|
tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast(
|
|
|
|
|
paddle.bool)
|
|
|
|
|
tgt_q_mask = tgt_mask[:, -1:, :]
|
|
|
|
|
|
|
|
|
|
if self.concat_after:
|
|
|
|
|
tgt_concat = paddle.cat(
|
|
|
|
|