Merge pull request #1788 from zh794390558/jit

[asr] patch func to var
Hui Zhang 3 years ago committed by GitHub
commit c336d98a7f
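Reviewer context: the patched module already monkey-patches torch-style helpers onto paddle.Tensor, but after paddle.jit.to_static traces a model into a static graph, tensors are paddle.static.Variable objects, which never received those patches. This commit registers each helper on paddle.static.Variable as well. A minimal sketch of the pattern, using the view helper (its body is assumed here; only the registration lines appear in the diff below):

import paddle

def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    # torch-style view(), assumed to be a thin wrapper over reshape
    return xs.reshape(args)

if not hasattr(paddle.Tensor, 'view'):
    paddle.Tensor.view = view           # dygraph tensors
    paddle.static.Variable.view = view  # static-graph tensors under to_static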

@@ -131,12 +131,14 @@ if not hasattr(paddle.Tensor, 'long'):
         "override long of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.long = func_long
+    paddle.static.Variable.long = func_long
 
 if not hasattr(paddle.Tensor, 'numel'):
     logger.debug(
         "override numel of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.numel = paddle.numel
+    paddle.static.Variable.numel = paddle.numel
 
 
 def new_full(x: paddle.Tensor,
@@ -151,6 +153,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
         "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.new_full = new_full
+    paddle.static.Variable.new_full = new_full
 
 
 def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
@@ -166,6 +169,7 @@ if not hasattr(paddle.Tensor, 'eq'):
         "override eq of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.eq = eq
+    paddle.static.Variable.eq = eq
 
 if not hasattr(paddle, 'eq'):
     logger.debug(
@@ -182,6 +186,7 @@ if not hasattr(paddle.Tensor, 'contiguous'):
         "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.contiguous = contiguous
+    paddle.static.Variable.contiguous = contiguous
 
 
 def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
@@ -200,6 +205,7 @@ logger.debug(
     "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
 )
 paddle.Tensor.size = size
+paddle.static.Variable.size = size
 
 
 def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
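Since `to_static` does not handle a `size` property (per the debug message above), `size` is registered as a method. A usage sketch, assuming the helper mirrors torch.Tensor.size, returning the full shape with no arguments and one dimension's extent with an index (requires the patched module to be imported first):

import paddle

x = paddle.zeros([3, 5])
print(x.size())   # [3, 5] -- full shape, torch-style
print(x.size(1))  # 5      -- extent of dimension 1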
@@ -209,6 +215,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'view'):
     logger.debug("register user view to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view = view
+    paddle.static.Variable.view = view
 
 
 def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
@@ -219,6 +226,7 @@ if not hasattr(paddle.Tensor, 'view_as'):
     logger.debug(
         "register user view_as to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view_as = view_as
+    paddle.static.Variable.view_as = view_as
 
 
 def is_broadcastable(shp1, shp2):
@@ -246,6 +254,7 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
     logger.debug(
         "register user masked_fill to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill = masked_fill
+    paddle.static.Variable.masked_fill = masked_fill
 
 
 def masked_fill_(xs: paddle.Tensor,
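For reference, masked_fill follows the torch semantics implied by the is_broadcastable check above: the boolean mask is broadcast against the tensor, and positions where it is True are replaced with the fill value. A self-contained sketch of that behavior (an assumed implementation, not necessarily the exact helper body):

import paddle

def masked_fill(xs, mask, value):
    # broadcast the mask to the common shape, then fill True positions
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)
    fill = paddle.full_like(xs, value)
    return paddle.where(mask, fill, xs)

x = paddle.zeros([2, 3])
m = paddle.to_tensor([[True], [False]])  # broadcastable over the last dim
print(masked_fill(x, m, -1e9))           # row 0 filled, row 1 untouched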
@@ -264,6 +273,7 @@ if not hasattr(paddle.Tensor, 'masked_fill_'):
     logger.debug(
         "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill_ = masked_fill_
+    paddle.static.Variable.masked_fill_ = masked_fill_
 
 
 def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
@@ -276,6 +286,7 @@ if not hasattr(paddle.Tensor, 'fill_'):
     logger.debug(
         "register user fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.fill_ = fill_
+    paddle.static.Variable.fill_ = fill_
 
 
 def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
@@ -286,6 +297,7 @@ if not hasattr(paddle.Tensor, 'repeat'):
     logger.debug(
         "register user repeat to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.repeat = repeat
+    paddle.static.Variable.repeat = repeat
 
 if not hasattr(paddle.Tensor, 'softmax'):
     logger.debug(
@@ -310,6 +322,8 @@ if not hasattr(paddle.Tensor, 'type_as'):
     logger.debug(
         "register user type_as to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'type_as', type_as)
+    setattr(paddle.static.Variable, 'type_as', type_as)
 
 
 def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
@@ -325,6 +339,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'to'):
     logger.debug("register user to to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'to', to)
+    setattr(paddle.static.Variable, 'to', to)
 
 
 def func_float(x: paddle.Tensor) -> paddle.Tensor:
@@ -335,6 +350,7 @@ if not hasattr(paddle.Tensor, 'float'):
     logger.debug(
         "register user float to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'float', func_float)
+    setattr(paddle.static.Variable, 'float', func_float)
 
 
 def func_int(x: paddle.Tensor) -> paddle.Tensor:
@@ -344,6 +360,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'int'):
     logger.debug("register user int to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'int', func_int)
+    setattr(paddle.static.Variable, 'int', func_int)
 
 
 def tolist(x: paddle.Tensor) -> List[Any]:
@@ -354,6 +371,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
     logger.debug(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
+    setattr(paddle.static.Variable, 'tolist', tolist)
 
 ########### hack paddle.nn #############
 from paddle.nn import Layer
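Taken together, the Variable registrations let torch-style method calls survive static-graph conversion. A hypothetical end-to-end check (TinyNet is illustrative and not part of this commit; it assumes the patched module has been imported so view/float are registered):

import paddle

class TinyNet(paddle.nn.Layer):
    def forward(self, x):
        # under to_static, x is a paddle.static.Variable, so these calls
        # resolve through the Variable patches added in this commit
        return x.view(2, -1).float()

net = paddle.jit.to_static(TinyNet())
y = net(paddle.ones([2, 3, 4], dtype='int32'))  # shape [2, 12], float32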
