|
|
|
@ -166,8 +166,19 @@ def broadcast_shape(shp1, shp2):
|
|
|
|
|
def masked_fill(xs: paddle.Tensor,
|
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
|
value: Union[float, int]):
|
|
|
|
|
mask = mask.astype(xs.dtype)
|
|
|
|
|
return xs * (1.0 - mask) + mask * value
|
|
|
|
|
# will be nan when value is `inf`.
|
|
|
|
|
# mask = mask.astype(xs.dtype)
|
|
|
|
|
# return xs * (1.0 - mask) + mask * value
|
|
|
|
|
|
|
|
|
|
bshape = broadcast_shape(xs.shape, mask.shape)
|
|
|
|
|
mask.stop_gradient = True
|
|
|
|
|
# tmp = paddle.ones(shape=[len(bshape)], dtype='int32')
|
|
|
|
|
# for index in range(len(bshape)):
|
|
|
|
|
# tmp[index] = bshape[index]
|
|
|
|
|
mask = mask.broadcast_to(bshape)
|
|
|
|
|
trues = paddle.full_like(xs, fill_value=value)
|
|
|
|
|
xs = paddle.where(mask, trues, xs)
|
|
|
|
|
return xs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(paddle.Tensor, 'masked_fill'):
|
|
|
|
|