fix input dtype of elementwise_mul op from bool to int64 (#3054)

pull/3056/head
TianYuan 1 year ago committed by GitHub
parent 31a4562ae8
commit d5720e4e7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -559,8 +559,9 @@ class VITSGenerator(nn.Layer):
y_lengths = paddle.cast(
paddle.clip(paddle.sum(dur, [1, 2]), min=1), dtype='int64')
y_mask = make_non_pad_mask(y_lengths).unsqueeze(1)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
tmp_a = paddle.cast(paddle.unsqueeze(x_mask, 2), dtype='int64')
tmp_b = paddle.cast(paddle.unsqueeze(y_mask, -1), dtype='int64')
attn_mask = tmp_a * tmp_b
attn = self._generate_path(dur, attn_mask)
# expand the length to match with the feature sequence

Loading…
Cancel
Save