[Fix] type promotion

pull/3944/head
megemini 9 months ago
parent a34bf501a5
commit bc83ac946f

@ -37,7 +37,7 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
out = wav_sum / lengths.astype(wav_sum.dtype)
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)
else:

Loading…
Cancel
Save