|
|
|
@ -835,8 +835,14 @@ class BeamSearchDecoder(TokenDecoder):
|
|
|
|
|
logprob, token = paddle.topk(
|
|
|
|
|
logprobs[idx], k=self.beam_size + 1)
|
|
|
|
|
for logprob, token in zip(logprob, token):
|
|
|
|
|
new_logprob = (sum_logprobs[idx] + logprob).tolist()[0]
|
|
|
|
|
sequence = tuple(prefix + [token.tolist()[0]])
|
|
|
|
|
# after Paddle 3.0, tolist in 0-D tensor will return a float/int value instead of a list
|
|
|
|
|
new_logprob = (sum_logprobs[idx] + logprob).tolist()
|
|
|
|
|
new_logprob = new_logprob if isinstance(
|
|
|
|
|
new_logprob, float) else new_logprob[0]
|
|
|
|
|
new_token = token.tolist()
|
|
|
|
|
new_token = new_token if isinstance(new_token,
|
|
|
|
|
int) else new_token[0]
|
|
|
|
|
sequence = tuple(prefix + [new_token])
|
|
|
|
|
scores[sequence] = new_logprob
|
|
|
|
|
sources[sequence] = idx
|
|
|
|
|
|
|
|
|
|