[Fix] type promotion

pull/3943/head
megemini 9 months ago
parent 26ba0ff684
commit ce2b7b82c6

@ -1114,9 +1114,9 @@ class MLMLoss(nn.Layer):
paddle.reshape(after_outs, (-1, self.odim)), paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))), paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1) axis=-1)
mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype)
mlm_loss = paddle.sum((loss * paddle.reshape( mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos.astype(loss.dtype), mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
[-1]))) / paddle.sum((mlm_loss_pos.astype(loss.dtype)) + 1e-10)
text_mlm_loss = None text_mlm_loss = None

Loading…
Cancel
Save