diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 1f9399b75..e675dcab7 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -1114,6 +1114,7 @@ class MLMLoss(nn.Layer): paddle.reshape(after_outs, (-1, self.odim)), paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) + mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype) mlm_loss = paddle.sum((loss * paddle.reshape( mlm_loss_pos, [-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10)