|
|
|
@ -152,8 +152,8 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
|
|
|
|
|
# return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])
|
|
|
|
|
|
|
|
|
|
B = ys_pad.shape[0]
|
|
|
|
|
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
|
|
|
|
|
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
|
|
|
|
|
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
|
|
|
|
|
_eos = paddle.full([B, 1], eos, dtype=ys_pad.dtype)
|
|
|
|
|
ys_in = paddle.cat([_sos, ys_pad], dim=1)
|
|
|
|
|
mask_pad = (ys_in == ignore_id)
|
|
|
|
|
ys_in = ys_in.masked_fill(mask_pad, eos)
|
|
|
|
|