|
|
@ -65,14 +65,16 @@ def reverse_pad_list_with_sos_eos(r_hyps,
|
|
|
|
max_len = paddle.max(r_hyps_lens)
|
|
|
|
max_len = paddle.max(r_hyps_lens)
|
|
|
|
index_range = paddle.arange(0, max_len, 1)
|
|
|
|
index_range = paddle.arange(0, max_len, 1)
|
|
|
|
seq_len_expand = r_hyps_lens.unsqueeze(1)
|
|
|
|
seq_len_expand = r_hyps_lens.unsqueeze(1)
|
|
|
|
seq_mask = seq_len_expand > index_range # (beam, max_len)
|
|
|
|
seq_mask = seq_len_expand > index_range.astype(
|
|
|
|
|
|
|
|
seq_len_expand.dtype) # (beam, max_len)
|
|
|
|
|
|
|
|
|
|
|
|
index = (seq_len_expand - 1) - index_range # (beam, max_len)
|
|
|
|
index = (seq_len_expand - 1) - index_range.astype(
|
|
|
|
|
|
|
|
seq_len_expand.dtype) # (beam, max_len)
|
|
|
|
# >>> index
|
|
|
|
# >>> index
|
|
|
|
# >>> tensor([[ 2, 1, 0],
|
|
|
|
# >>> tensor([[ 2, 1, 0],
|
|
|
|
# >>> [ 2, 1, 0],
|
|
|
|
# >>> [ 2, 1, 0],
|
|
|
|
# >>> [ 0, -1, -2]])
|
|
|
|
# >>> [ 0, -1, -2]])
|
|
|
|
index = index * seq_mask
|
|
|
|
index = index * seq_mask.astype(index.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
# >>> index
|
|
|
|
# >>> index
|
|
|
|
# >>> tensor([[2, 1, 0],
|
|
|
|
# >>> tensor([[2, 1, 0],
|
|
|
@ -103,7 +105,8 @@ def reverse_pad_list_with_sos_eos(r_hyps,
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
# >>> [2, 2, 2]])
|
|
|
|
# >>> [2, 2, 2]])
|
|
|
|
r_hyps = paddle.where(seq_mask, r_hyps, eos)
|
|
|
|
r_hyps = paddle.where(seq_mask, r_hyps,
|
|
|
|
|
|
|
|
paddle.to_tensor(eos, dtype=r_hyps.dtype))
|
|
|
|
# >>> r_hyps
|
|
|
|
# >>> r_hyps
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|