@@ -22,7 +22,8 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 #TODO(Hui Zhang): remove fluid import
+from paddle.fluid import core
 logger = logging.getLogger(__name__)

 ########### hack logging #############
@@ -44,10 +45,51 @@ paddle.int = 'int32'
 paddle.int64 = 'int64'
 paddle.long = 'int64'
 paddle.uint8 = 'uint8'
 paddle.uint16 = 'uint16'
 paddle.complex64 = 'complex64'
 paddle.complex128 = 'complex128'
 paddle.cdouble = 'complex128'
+
+
+def convert_dtype_to_string(tensor_dtype):
+    """
+    Convert a Paddle tensor data type to its dtype string.
+    Args:
+        tensor_dtype(core.VarDesc.VarType): the tensor data type in Paddle.
+    Returns:
+        str: the corresponding dtype string, e.g. 'float32'.
+    """
+    dtype = tensor_dtype
+    if dtype == core.VarDesc.VarType.FP32:
+        return paddle.float32
+    elif dtype == core.VarDesc.VarType.FP64:
+        return paddle.float64
+    elif dtype == core.VarDesc.VarType.FP16:
+        return paddle.float16
+    elif dtype == core.VarDesc.VarType.INT32:
+        return paddle.int32
+    elif dtype == core.VarDesc.VarType.INT16:
+        return paddle.int16
+    elif dtype == core.VarDesc.VarType.INT64:
+        return paddle.int64
+    elif dtype == core.VarDesc.VarType.BOOL:
+        return paddle.bool
+    elif dtype == core.VarDesc.VarType.BF16:
+        # since there is still no support for bfloat16 in NumPy,
+        # uint16 is used for casting bfloat16
+        return paddle.uint16
+    elif dtype == core.VarDesc.VarType.UINT8:
+        return paddle.uint8
+    elif dtype == core.VarDesc.VarType.INT8:
+        return paddle.int8
+    elif dtype == core.VarDesc.VarType.COMPLEX64:
+        return paddle.complex64
+    elif dtype == core.VarDesc.VarType.COMPLEX128:
+        return paddle.complex128
+    else:
+        raise ValueError("Not supported tensor dtype %s" % dtype)
+
+
 if not hasattr(paddle, 'softmax'):
     logger.warn("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
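
A note on the hunk above: the aliases rebind names such as paddle.long to plain
dtype strings, which Paddle APIs accept wherever a dtype argument is expected,
while the new convert_dtype_to_string maps a live tensor's core.VarDesc.VarType
back to one of those strings. A minimal usage sketch, assuming this patch
module has already been imported so the aliases and helper are in scope:

    import paddle

    x = paddle.zeros([2, 3])              # x.dtype is core.VarDesc.VarType.FP32
    s = convert_dtype_to_string(x.dtype)  # -> 'float32' (the patched paddle.float32)
    y = paddle.to_tensor([1, 2], dtype=paddle.long)  # paddle.long == 'int64' here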
@@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):


 def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
-    return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
+    return xs.equal(
+        paddle.to_tensor(
+            ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))


 if not hasattr(paddle.Tensor, 'eq'):
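
The eq fix above exists because xs.dtype is a core.VarDesc.VarType, which
paddle.to_tensor apparently did not accept directly at the time of this patch,
so the dtype is round-tripped through its string name before the scalar is
wrapped. A minimal sketch of the patched method, assuming the module has been
imported so paddle.Tensor.eq is registered:

    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0])
    m = x.eq(2.0)   # wraps 2.0 via to_tensor(..., dtype='float32', place=x.place)
    # m is a bool tensor: [False, True, False]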
@@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
     paddle.Tensor.view_as = view_as


+def is_broadcastable(shp1, shp2):
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape) == True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
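
masked_fill mirrors torch.Tensor.masked_fill: the hunk above relaxes the old
equal-shape assert so the mask only has to be broadcastable against xs, checked
dimension by dimension from the right, after which the mask is expanded to the
joint shape and paddle.where picks the fill value at masked positions. A
minimal usage sketch under those assumptions:

    import paddle

    xs = paddle.zeros([2, 3, 4])
    mask = paddle.to_tensor([True, False, True, False])  # shape [4] broadcasts
    out = masked_fill(xs, mask, float('-inf'))           # shape [2, 3, 4]
    # is_broadcastable([2, 3, 4], [4]) -> True; is_broadcastable([2, 3], [4]) -> False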
@@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
 def masked_fill_(xs: paddle.Tensor,
                  mask: paddle.Tensor,
                  value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape) == True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     ret = paddle.where(mask, trues, xs)
     paddle.assign(ret, output=xs)
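
masked_fill_ follows the PyTorch convention that a trailing underscore marks an
in-place variant: paddle.where still produces a new tensor, so the result is
copied back into xs with paddle.assign(ret, output=xs). A minimal sketch,
assuming the module has been imported so the patched method is registered:

    import paddle

    xs = paddle.zeros([2, 2])
    mask = paddle.to_tensor([[True, False], [False, True]])
    masked_fill_(xs, mask, 1.0)   # mutates xs via paddle.assign
    # xs is now [[1., 0.], [0., 1.]]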
@@ -414,4 +471,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):

 if not hasattr(paddle.jit, 'export'):
     logger.warn("register user export to paddle.jit, remove this when fixed!")
-    setattr(paddle.jit, 'export', paddle.jit.to_static)
+    setattr(paddle.jit, 'export', paddle.jit.to_static)
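
The final hunk registers paddle.jit.export as an alias of paddle.jit.to_static,
mirroring torch.jit.export-style annotations; the -/+ pair differs only in some
whitespace lost in extraction. A minimal sketch of the guard-and-register
pattern used throughout this file:

    import paddle

    if not hasattr(paddle.jit, 'export'):
        setattr(paddle.jit, 'export', paddle.jit.to_static)

    @paddle.jit.export   # now just another name for paddle.jit.to_static
    def scale(x):
        return x * 2.0

    y = scale(paddle.ones([2]))   # runs through the to_static-converted function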