diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index 37b124e84..cb7261107 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -114,10 +114,7 @@ class DecoderLayer(nn.Layer): ], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}" tgt_q = tgt[:, -1:, :] residual = residual[:, -1:, :] - # TODO(Hui Zhang): slice not support bool type - # tgt_q_mask = tgt_mask[:, -1:, :] - tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast( - paddle.bool) + tgt_q_mask = tgt_mask[:, -1:, :] if self.concat_after: tgt_concat = paddle.cat(