[Hackathon 7th] 修复 `vctk` 的 `ernie_sat` 训练时出现的类型提升问题 (#3943)

* [Fix] vctk type promotion

* [Fix] type promotion
pull/3946/head
megemini 9 months ago committed by GitHub
parent b84e86d718
commit e4038b4b6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1114,6 +1114,7 @@ class MLMLoss(nn.Layer):
paddle.reshape(after_outs, (-1, self.odim)), paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))), paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1) axis=-1)
mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype)
mlm_loss = paddle.sum((loss * paddle.reshape( mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, mlm_loss_pos,
[-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10) [-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10)

Loading…
Cancel
Save