From 2d7683f514637385d9c07fdf34b3cabba8703617 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Tue, 3 Dec 2024 21:54:52 +0800 Subject: [PATCH] fix aishell3-vc0 --- paddlespeech/t2s/modules/losses.py | 3 ++- paddlespeech/t2s/modules/nets_utils.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index b4d78364c..1f9399b75 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -1115,7 +1115,8 @@ class MLMLoss(nn.Layer): paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) mlm_loss = paddle.sum((loss * paddle.reshape( - mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) + mlm_loss_pos, + [-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10) text_mlm_loss = None diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 0a66a1c88..8051165f4 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -465,7 +465,7 @@ def phones_masking(xs_pad: paddle.Tensor, for s, e in zip(masked_start, masked_end): masked_pos[idx, s:e] = 1 non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) - masked_pos = masked_pos * non_eos_mask + masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype) masked_pos = paddle.cast(masked_pos, 'bool') return masked_pos @@ -549,10 +549,11 @@ def phones_text_masking(xs_pad: paddle.Tensor, for s, e in zip(masked_start, masked_end): masked_pos[idx, s:e] = 1 non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2]) - masked_pos = masked_pos * non_eos_mask + masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype) non_eos_text_mask = paddle.reshape( text_mask, shape=paddle.shape(text_pad)[:2]) - text_masked_pos = text_masked_pos * non_eos_text_mask + text_masked_pos = text_masked_pos * non_eos_text_mask.astype( + text_masked_pos.dtype) masked_pos = paddle.cast(masked_pos, 'bool') text_masked_pos = paddle.cast(text_masked_pos, 'bool')