|
|
|
@ -1114,9 +1114,9 @@ class MLMLoss(nn.Layer):
|
|
|
|
|
paddle.reshape(after_outs, (-1, self.odim)),
|
|
|
|
|
paddle.reshape(xs_pad, (-1, self.odim))),
|
|
|
|
|
axis=-1)
|
|
|
|
|
mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype)
|
|
|
|
|
mlm_loss = paddle.sum((loss * paddle.reshape(
|
|
|
|
|
mlm_loss_pos.astype(loss.dtype),
|
|
|
|
|
[-1]))) / paddle.sum((mlm_loss_pos.astype(loss.dtype)) + 1e-10)
|
|
|
|
|
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
|
|
|
|
|
|
|
|
|
|
text_mlm_loss = None
|
|
|
|
|
|
|
|
|
|