diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 4507365d6..6663bcf87 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -166,8 +166,19 @@ def broadcast_shape(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): - mask = mask.astype(xs.dtype) - return xs * (1.0 - mask) + mask * value + # will be nan when value is `inf`. + # mask = mask.astype(xs.dtype) + # return xs * (1.0 - mask) + mask * value + + bshape = broadcast_shape(xs.shape, mask.shape) + mask.stop_gradient = True + # tmp = paddle.ones(shape=[len(bshape)], dtype='int32') + # for index in range(len(bshape)): + # tmp[index] = bshape[index] + mask = mask.broadcast_to(bshape) + trues = paddle.full_like(xs, fill_value=value) + xs = paddle.where(mask, trues, xs) + return xs if not hasattr(paddle.Tensor, 'masked_fill'):