From ce2b7b82c6ee5dc892d9caf62134614a260787f5 Mon Sep 17 00:00:00 2001 From: megemini Date: Mon, 9 Dec 2024 15:39:53 +0800 Subject: [PATCH] [Fix] type promotion --- paddlespeech/t2s/modules/losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 12bf952d6..20e85a59c 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -1114,9 +1114,9 @@ class MLMLoss(nn.Layer): paddle.reshape(after_outs, (-1, self.odim)), paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) + mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype) mlm_loss = paddle.sum((loss * paddle.reshape( - mlm_loss_pos.astype(loss.dtype), - [-1]))) / paddle.sum((mlm_loss_pos.astype(loss.dtype)) + 1e-10) + mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) text_mlm_loss = None