eliminate mul

pull/2425/head
Hui Zhang 2 years ago
parent b7388ce25a
commit c4a5ae3825

@ -152,8 +152,8 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
# return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0]) # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])
B = ys_pad.shape[0] B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos _eos = paddle.full([B, 1], eos, dtype=ys_pad.dtype)
ys_in = paddle.cat([_sos, ys_pad], dim=1) ys_in = paddle.cat([_sos, ys_pad], dim=1)
mask_pad = (ys_in == ignore_id) mask_pad = (ys_in == ignore_id)
ys_in = ys_in.masked_fill(mask_pad, eos) ys_in = ys_in.masked_fill(mask_pad, eos)

Loading…
Cancel
Save