tests/unit/tts/test_losses.py: Add gradient tests and update precision calculation methods

- Change precision threshold to ’1e-5‘
- Use relative error instead of absolute error
pull/3954/head
suzakuwcx 9 months ago
parent 37f60d6c2a
commit c1a8f996f3
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -10,6 +10,10 @@ def test_dac_losses():
loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')
recons.audio_data.stop_gradient = False
signal.audio_data.stop_gradient = False
loss_fn_1 = MultiScaleSTFTLoss()
loss_fn_2 = MultiMelSpectrogramLoss(
n_mels=[5, 10, 20, 40, 80, 160, 320],
@ -18,22 +22,57 @@ def test_dac_losses():
pow=1.0,
mel_fmin=[0, 0, 0, 0, 0, 0, 0],
mel_fmax=[None, None, None, None, None, None, None])
#
# Test AudioSignal
#
loss_1 = loss_fn_1(recons, signal)
loss_1.backward()
loss_1_grad = signal.audio_data.grad.sum()
assert abs((loss_1.item() - loss_origin['stft/loss'].item()) /
loss_1.item()) < 1e-5
assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item())
/ loss_1_grad.item()) < 1e-5
signal.audio_data.clear_grad()
recons.audio_data.clear_grad()
loss_2 = loss_fn_2(recons, signal)
loss_2.backward()
loss_2_grad = signal.audio_data.grad.sum()
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
loss_2.item()) < 1e-5
assert abs(
loss_fn_1(recons, signal).item() - loss_origin['stft/loss']
.item()) < 1e-5
assert abs(
loss_fn_2(recons, signal).item() - loss_origin['mel/loss']
.item()) < 1e-5
(signal.audio_data.grad.sum().item() -
loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5
signal.audio_data.clear_grad()
recons.audio_data.clear_grad()
#
# Test Tensor
#
assert abs(
loss_fn_1(recons.audio_data, signal.audio_data).item() -
loss_origin['stft/loss'].item()) < 1e-3
assert abs(
loss_fn_2(recons.audio_data, signal.audio_data).item() -
loss_origin['mel/loss'].item()) < 1e-3
loss_1 = loss_fn_1(recons.audio_data, signal.audio_data)
loss_1.backward()
loss_1_grad = signal.audio_data.grad.sum()
assert abs(loss_1.item() - loss_origin['stft/loss']
.item()) / loss_1.item() < 1e-5
assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum()
.item()) / loss_1_grad.item() < 1e-5
signal.audio_data.clear_grad()
recons.audio_data.clear_grad()
loss_2 = loss_fn_2(recons.audio_data, signal.audio_data)
loss_2.backward()
loss_2_grad = signal.audio_data.grad.sum()
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
loss_2.item()) < 1e-5
assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) /
loss_2_grad.item()) < 1e-5

Loading…
Cancel
Save