not hack size since it exists

pull/556/head
Hui Zhang 5 years ago
parent df1d44f5d6
commit bc6da7a123

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typeing import Union
from typeing import Any
from typing import Union
from typing import Any
import paddle
from paddle import nn
@ -21,6 +21,7 @@ from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
logger.warn = logging.warning
# TODO(Hui Zhang): remove this hack
paddle.bool = 'bool'
@ -52,11 +53,10 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return s
if not hasattr(paddle.Tensor, 'size'):
logger.warn(
"override size of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.size = size
# logger.warn(
# "override size of paddle.Tensor if exists or register, remove this when fixed!"
# )
# paddle.Tensor.size = size
def masked_fill(xs: paddle.Tensor,

@ -272,6 +272,6 @@ def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor,
Returns:
paddle.Tensor: (batch_size * beam_size).
"""
beam_size = pred.size(-1)
finished = flag.repeat([1, beam_size])
beam_size = pred.shape[-1]
finished = flag.repeat(1, beam_size)
return pred.masked_fill_(finished, eos)

Loading…
Cancel
Save