not hack size since it exists

pull/556/head
Hui Zhang 5 years ago
parent df1d44f5d6
commit bc6da7a123

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typeing import Union from typing import Union
from typeing import Any from typing import Any
import paddle import paddle
from paddle import nn from paddle import nn
@ -21,6 +21,7 @@ from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.warn = logging.warning
# TODO(Hui Zhang): remove this hack # TODO(Hui Zhang): remove this hack
paddle.bool = 'bool' paddle.bool = 'bool'
@ -52,11 +53,10 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return s return s
if not hasattr(paddle.Tensor, 'size'): # logger.warn(
logger.warn( # "override size of paddle.Tensor if exists or register, remove this when fixed!"
"override size of paddle.Tensor if exists or register, remove this when fixed!" # )
) # paddle.Tensor.size = size
paddle.Tensor.size = size
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,

@ -272,6 +272,6 @@ def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor,
Returns: Returns:
paddle.Tensor: (batch_size * beam_size). paddle.Tensor: (batch_size * beam_size).
""" """
beam_size = pred.size(-1) beam_size = pred.shape[-1]
finished = flag.repeat([1, beam_size]) finished = flag.repeat(1, beam_size)
return pred.masked_fill_(finished, eos) return pred.masked_fill_(finished, eos)

Loading…
Cancel
Save