fix unit error > Type promotion

pull/4000/head
liyulingyue 6 months ago
parent 0e51e5d8fc
commit b3359f9ee3

@ -65,14 +65,16 @@ def reverse_pad_list_with_sos_eos(r_hyps,
max_len = paddle.max(r_hyps_lens) max_len = paddle.max(r_hyps_lens)
index_range = paddle.arange(0, max_len, 1) index_range = paddle.arange(0, max_len, 1)
seq_len_expand = r_hyps_lens.unsqueeze(1) seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len) seq_mask = seq_len_expand > index_range.astype(
seq_len_expand.dtype) # (beam, max_len)
index = (seq_len_expand - 1) - index_range # (beam, max_len) index = (seq_len_expand - 1) - index_range.astype(
seq_len_expand.dtype) # (beam, max_len)
# >>> index # >>> index
# >>> tensor([[ 2, 1, 0], # >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0], # >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]]) # >>> [ 0, -1, -2]])
index = index * seq_mask index = index * seq_mask.astype(index.dtype)
# >>> index # >>> index
# >>> tensor([[2, 1, 0], # >>> tensor([[2, 1, 0],
@ -103,7 +105,8 @@ def reverse_pad_list_with_sos_eos(r_hyps,
# >>> tensor([[3, 2, 1], # >>> tensor([[3, 2, 1],
# >>> [4, 8, 9], # >>> [4, 8, 9],
# >>> [2, 2, 2]]) # >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, eos) r_hyps = paddle.where(seq_mask, r_hyps,
paddle.to_tensor(eos, dtype=r_hyps.dtype))
# >>> r_hyps # >>> r_hyps
# >>> tensor([[3, 2, 1], # >>> tensor([[3, 2, 1],
# >>> [4, 8, 9], # >>> [4, 8, 9],

Loading…
Cancel
Save