From b3359f9ee380285720552cc96c2d6afbf0018af9 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Wed, 12 Mar 2025 20:05:17 +0800 Subject: [PATCH] fix unit error > Type promotion --- tests/unit/asr/reverse_pad_list.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/unit/asr/reverse_pad_list.py b/tests/unit/asr/reverse_pad_list.py index 215ed5ceb..1b63890a0 100644 --- a/tests/unit/asr/reverse_pad_list.py +++ b/tests/unit/asr/reverse_pad_list.py @@ -65,14 +65,16 @@ def reverse_pad_list_with_sos_eos(r_hyps, max_len = paddle.max(r_hyps_lens) index_range = paddle.arange(0, max_len, 1) seq_len_expand = r_hyps_lens.unsqueeze(1) - seq_mask = seq_len_expand > index_range # (beam, max_len) + seq_mask = seq_len_expand > index_range.astype( + seq_len_expand.dtype) # (beam, max_len) - index = (seq_len_expand - 1) - index_range # (beam, max_len) + index = (seq_len_expand - 1) - index_range.astype( + seq_len_expand.dtype) # (beam, max_len) # >>> index # >>> tensor([[ 2, 1, 0], # >>> [ 2, 1, 0], # >>> [ 0, -1, -2]]) - index = index * seq_mask + index = index * seq_mask.astype(index.dtype) # >>> index # >>> tensor([[2, 1, 0], @@ -103,7 +105,8 @@ def reverse_pad_list_with_sos_eos(r_hyps, # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, 2, 2]]) - r_hyps = paddle.where(seq_mask, r_hyps, eos) + r_hyps = paddle.where(seq_mask, r_hyps, + paddle.to_tensor(eos, dtype=r_hyps.dtype)) # >>> r_hyps # >>> tensor([[3, 2, 1], # >>> [4, 8, 9],