diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 0238f4db..8cf17a6a 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -147,7 +147,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1): seq_range = paddle.arange(0, maxlen, dtype=paddle.int64) seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen]) seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand + mask = seq_range_expand >= seq_length_expand.cast(seq_range_expand.dtype) if xs is not None: assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)