|
|
@ -309,6 +309,6 @@ class RNNStack(nn.Layer):
|
|
|
|
masks = make_non_pad_mask(x_len) #[B, T]
|
|
|
|
masks = make_non_pad_mask(x_len) #[B, T]
|
|
|
|
masks = masks.unsqueeze(-1) # [B, T, 1]
|
|
|
|
masks = masks.unsqueeze(-1) # [B, T, 1]
|
|
|
|
# TODO(Hui Zhang): not support bool multiply
|
|
|
|
# TODO(Hui Zhang): not support bool multiply
|
|
|
|
masks = masks.type_as(x)
|
|
|
|
masks = masks.astype(x.dtype)
|
|
|
|
x = x.multiply(masks)
|
|
|
|
x = x.multiply(masks)
|
|
|
|
return x, x_len
|
|
|
|
return x, x_len
|
|
|
|