[Fix] vctk type promotion

pull/3943/head
megemini 10 months ago
parent a34bf501a5
commit 26ba0ff684

@ -1115,7 +1115,8 @@ class MLMLoss(nn.Layer):
paddle.reshape(xs_pad, (-1, self.odim))), paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1) axis=-1)
mlm_loss = paddle.sum((loss * paddle.reshape( mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) mlm_loss_pos.astype(loss.dtype),
[-1]))) / paddle.sum((mlm_loss_pos.astype(loss.dtype)) + 1e-10)
text_mlm_loss = None text_mlm_loss = None

@ -466,7 +466,7 @@ def phones_masking(xs_pad: paddle.Tensor,
for s, e in zip(masked_start, masked_end): for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1 masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
masked_pos = paddle.cast(masked_pos, 'bool') masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos return masked_pos

Loading…
Cancel
Save