diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 12bf952d6..20e85a59c 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -1114,9 +1114,9 @@ class MLMLoss(nn.Layer): paddle.reshape(after_outs, (-1, self.odim)), paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) + mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype) mlm_loss = paddle.sum((loss * paddle.reshape( - mlm_loss_pos.astype(loss.dtype), - [-1]))) / paddle.sum((mlm_loss_pos.astype(loss.dtype)) + 1e-10) + mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) text_mlm_loss = None