|
|
@ -166,7 +166,6 @@ def broadcast_shape(shp1, shp2):
|
|
|
|
def masked_fill(xs: paddle.Tensor,
|
|
|
|
def masked_fill(xs: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
value: Union[float, int]):
|
|
|
|
value: Union[float, int]):
|
|
|
|
mask.stop_gradient = True
|
|
|
|
|
|
|
|
mask = mask.astype(xs.dtype)
|
|
|
|
mask = mask.astype(xs.dtype)
|
|
|
|
return xs * (1.0 - mask) + mask * value
|
|
|
|
return xs * (1.0 - mask) + mask * value
|
|
|
|
|
|
|
|
|
|
|
|