fix bug of pad_sequence in u2,test=asr (#1153)

pull/1155/head
Jackwaterveg 3 years ago committed by GitHub
parent 7ce65e13d3
commit 0151f2463f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -579,10 +579,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
num_decoding_left_chunks, simulate_streaming)
assert len(hyps) == beam_size
hyps_pad = pad_sequence([
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
for hyp in hyps
], True, self.ignore_id) # (beam_size, max_hyps_len)
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.ctc.blank_id,)
hyp_content = paddle.to_tensor(hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)

Loading…
Cancel
Save