@@ -971,8 +971,14 @@ class ApplyTimestampRules(LogitFilter):
         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
         for k in range(tokens.shape[0]):
-            timestamp_logprob = paddle.logsumexp(
-                logprobs[k, self.tokenizer.timestamp_begin:], axis=-1)
+            # When using paddle.logsumexp on a 32GB Tesla V100 GPU, we encountered CUDA error 700.
+            # To bypass this issue in CI, we have decomposed the operation into separate steps.
+            # This introduces a precision difference of about 2e-6.
+            # TODO: revert this after logsumexp is fixed.
+            timestamp_logprob = paddle.exp(
+                logprobs[k, self.tokenizer.timestamp_begin:])
+            timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
+            timestamp_logprob = paddle.log(timestamp_logprob)
             max_text_token_logprob = paddle.max(
                 logprobs[k, :self.tokenizer.timestamp_begin])
             if timestamp_logprob > max_text_token_logprob:
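
The decomposition is numerically safe here because the sliced values are log-probabilities, which are all <= 0, so paddle.exp stays within [0, 1] and cannot overflow; the only cost is the ~2e-6 rounding difference noted in the comments. A minimal sanity-check sketch of the equivalence (not part of the patch; the vocabulary size and timestamp_begin value below are illustrative stand-ins for the tokenizer's actual values):

import paddle
import paddle.nn.functional as F

# Stand-in logits for a single decoding step (51865 mimics the multilingual vocab size).
logits = paddle.randn([51865])
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
timestamp_begin = 50364  # illustrative stand-in for self.tokenizer.timestamp_begin

# Fused op vs. the decomposed exp -> sum -> log used in the patch.
fused = paddle.logsumexp(logprobs[timestamp_begin:], axis=-1)
decomposed = paddle.log(paddle.sum(paddle.exp(logprobs[timestamp_begin:]), axis=-1))

# Log-probabilities are <= 0, so exp() cannot overflow; any mismatch is pure rounding.
print(float(paddle.abs(fused - decomposed)))  # expected on the order of 1e-6 or smaller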