masked_fill by multiply, remove while

pull/2425/head
Hui Zhang 3 years ago
parent feb27e2a84
commit b20bf7d5de

@ -166,15 +166,9 @@ def broadcast_shape(shp1, shp2):
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
bshape = broadcast_shape(xs.shape, mask.shape)
mask.stop_gradient = True mask.stop_gradient = True
tmp = paddle.ones(shape=[len(bshape)], dtype='int32') mask = mask.astype(xs.dtype)
for index in range(len(bshape)): return xs * (1.0 - mask) + mask * value
tmp[index] = bshape[index]
mask = mask.broadcast_to(tmp)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill'): if not hasattr(paddle.Tensor, 'masked_fill'):

Loading…
Cancel
Save