|
|
@ -159,9 +159,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
|
|
|
|
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
|
|
|
|
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
|
|
|
|
if convert_dtype_to_string(xs.dtype) == paddle.bool:
|
|
|
|
if convert_dtype_to_string(xs.dtype) == paddle.bool:
|
|
|
|
xs = xs.astype(paddle.int)
|
|
|
|
xs = xs.astype(paddle.int)
|
|
|
|
return xs.equal(
|
|
|
|
return xs.equal(ys)
|
|
|
|
paddle.to_tensor(
|
|
|
|
|
|
|
|
ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(paddle.Tensor, 'eq'):
|
|
|
|
if not hasattr(paddle.Tensor, 'eq'):
|
|
|
@ -219,13 +217,22 @@ def is_broadcastable(shp1, shp2):
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def broadcast_shape(shp1, shp2):
|
|
|
|
|
|
|
|
result = []
|
|
|
|
|
|
|
|
for a, b in zip(shp1[::-1], shp2[::-1]):
|
|
|
|
|
|
|
|
result.append(max(a, b))
|
|
|
|
|
|
|
|
return result[::-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def masked_fill(xs: paddle.Tensor,
|
|
|
|
def masked_fill(xs: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
value: Union[float, int]):
|
|
|
|
value: Union[float, int]):
|
|
|
|
assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape,
|
|
|
|
bshape = broadcast_shape(xs.shape, mask.shape)
|
|
|
|
mask.shape)
|
|
|
|
mask.stop_gradient = True
|
|
|
|
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
|
|
|
|
tmp = paddle.ones(shape=[len(bshape)], dtype='int32')
|
|
|
|
mask = mask.broadcast_to(bshape)
|
|
|
|
for index in range(len(bshape)):
|
|
|
|
|
|
|
|
tmp[index] = bshape[index]
|
|
|
|
|
|
|
|
mask = mask.broadcast_to(tmp)
|
|
|
|
trues = paddle.ones_like(xs) * value
|
|
|
|
trues = paddle.ones_like(xs) * value
|
|
|
|
xs = paddle.where(mask, trues, xs)
|
|
|
|
xs = paddle.where(mask, trues, xs)
|
|
|
|
return xs
|
|
|
|
return xs
|
|
|
|