|
|
|
@ -12,8 +12,8 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import logging
|
|
|
|
|
from typeing import Union
|
|
|
|
|
from typeing import Any
|
|
|
|
|
from typing import Union
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle import nn
|
|
|
|
@ -21,6 +21,7 @@ from paddle.nn import functional as F
|
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
logger.warn = logging.warning
|
|
|
|
|
|
|
|
|
|
# TODO(Hui Zhang): remove this hack
|
|
|
|
|
paddle.bool = 'bool'
|
|
|
|
@ -52,11 +53,10 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(paddle.Tensor, 'size'):
|
|
|
|
|
logger.warn(
|
|
|
|
|
"override size of paddle.Tensor if exists or register, remove this when fixed!"
|
|
|
|
|
)
|
|
|
|
|
paddle.Tensor.size = size
|
|
|
|
|
# logger.warn(
|
|
|
|
|
# "override size of paddle.Tensor if exists or register, remove this when fixed!"
|
|
|
|
|
# )
|
|
|
|
|
# paddle.Tensor.size = size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def masked_fill(xs: paddle.Tensor,
|
|
|
|
|