fix bug of pad_sequence in u2,test=asr (#1153)

pull/1155/head
Jackwaterveg 3 years ago committed by GitHub
parent 7ce65e13d3
commit 0151f2463f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -579,10 +579,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
num_decoding_left_chunks, simulate_streaming) num_decoding_left_chunks, simulate_streaming)
assert len(hyps) == beam_size assert len(hyps) == beam_size
hyps_pad = pad_sequence([ hyp_list = []
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long) for hyp in hyps:
for hyp in hyps hyp_content = hyp[0]
], True, self.ignore_id) # (beam_size, max_hyps_len) # Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.ctc.blank_id,)
hyp_content = paddle.to_tensor(hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
hyps_lens = paddle.to_tensor( hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device, [len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,) dtype=paddle.long) # (beam_size,)

Loading…
Cancel
Save