|
|
|
@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
|
|
|
|
|
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
|
|
|
|
|
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
|
|
|
|
|
B = ys_pad.shape[0]
|
|
|
|
|
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
|
|
|
|
|
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
|
|
|
|
|
max_len = paddle.max(ys_lens)
|
|
|
|
|
index_range = paddle.arange(0, max_len, 1)
|
|
|
|
|
seq_len_expand = ys_lens.unsqueeze(1)
|
|
|
|
@ -279,6 +279,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
|
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
|
# >>> [2, 2, 2]])
|
|
|
|
|
eos = paddle.full([1], eos, dtype=r_hyps.dtype)
|
|
|
|
|
r_hyps = paddle.where(seq_mask, r_hyps, eos)
|
|
|
|
|
# >>> r_hyps
|
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
|