separate paddle.logsumexp (#3897)

pull/3898/head
zxcd 10 months ago committed by GitHub
parent 89bfd44293
commit d32ced7f1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -971,8 +971,14 @@ class ApplyTimestampRules(LogitFilter):
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
for k in range(tokens.shape[0]):
timestamp_logprob = paddle.logsumexp(
logprobs[k, self.tokenizer.timestamp_begin:], axis=-1)
# When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
# To bypass this issue in CI, we have decomposed the operation into separate steps.
# It will raise 2e-6 difference in precision.
# TODO: revert this after logsumexp been fixed.
timestamp_logprob = paddle.exp(
logprobs[k, self.tokenizer.timestamp_begin:])
timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
timestamp_logprob = paddle.log(timestamp_logprob)
max_text_token_logprob = paddle.max(
logprobs[k, :self.tokenizer.timestamp_begin])
if timestamp_logprob > max_text_token_logprob:

Loading…
Cancel
Save