bool type slice

pull/2425/head
Hui Zhang 3 years ago
parent c2c8a662b1
commit 3d7ca93861

@ -114,10 +114,7 @@ class DecoderLayer(nn.Layer):
], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
# TODO(Hui Zhang): slice not support bool type
# tgt_q_mask = tgt_mask[:, -1:, :]
tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast(
paddle.bool)
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = paddle.cat(

Loading…
Cancel
Save