fix aishell3-vc0

pull/3928/head
liyulingyue 10 months ago
parent 62bac0a1d1
commit 2d7683f514

@ -1115,7 +1115,8 @@ class MLMLoss(nn.Layer):
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
mlm_loss_pos,
[-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10)
text_mlm_loss = None

@ -465,7 +465,7 @@ def phones_masking(xs_pad: paddle.Tensor,
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos
@ -549,10 +549,11 @@ def phones_text_masking(xs_pad: paddle.Tensor,
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
non_eos_text_mask = paddle.reshape(
text_mask, shape=paddle.shape(text_pad)[:2])
text_masked_pos = text_masked_pos * non_eos_text_mask
text_masked_pos = text_masked_pos * non_eos_text_mask.astype(
text_masked_pos.dtype)
masked_pos = paddle.cast(masked_pos, 'bool')
text_masked_pos = paddle.cast(text_masked_pos, 'bool')

Loading…
Cancel
Save