Merge pull request #816 from PaddlePaddle/bool_mul

float_mul_bool type promote, rhs type promote to lhs type
pull/819/head
Jackwaterveg 3 years ago committed by GitHub
commit 7d0204e9fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -106,11 +106,9 @@ class ConvBn(nn.Layer):
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
return x, x_len

@ -308,7 +308,8 @@ class RNNStack(nn.Layer):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
return x, x_len

@ -113,11 +113,9 @@ class ConvBn(nn.Layer):
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
return x, x_len

@ -46,7 +46,6 @@ class CTCLoss(nn.Layer):
# warp-ctc need activation with shape [T, B, V + 1]
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2])
# (TODO:Hui Zhang) ctc loss does not support int64 labels
ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss(
logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)

@ -308,7 +308,7 @@ class RNNStack(nn.Layer):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
return x, x_len

Loading…
Cancel
Save