tests/unit/tts/test_losses.py: Add error message on assert failed

pull/3954/head
suzakuwcx 8 months ago
parent c1a8f996f3
commit fd5365c5b6
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -31,10 +31,13 @@ def test_dac_losses():
loss_1.backward()
loss_1_grad = signal.audio_data.grad.sum()
assert abs((loss_1.item() - loss_origin['stft/loss'].item()) /
loss_1.item()) < 1e-5
assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item())
/ loss_1_grad.item()) < 1e-5
assert abs(
(loss_1.item() - loss_origin['stft/loss'].item()) /
loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
assert abs(
(loss_1_grad.item() - loss_origin['stft/grad'].sum().item()
) / loss_1_grad.
item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"
signal.audio_data.clear_grad()
recons.audio_data.clear_grad()
@ -43,11 +46,13 @@ def test_dac_losses():
loss_2.backward()
loss_2_grad = signal.audio_data.grad.sum()
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
loss_2.item()) < 1e-5
assert abs(
(loss_2.item() - loss_origin['mel/loss'].item()) / loss_2.
item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
assert abs(
(signal.audio_data.grad.sum().item() -
loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5
loss_origin['mel/grad'].sum().item()) / loss_2_grad.
item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"
signal.audio_data.clear_grad()
recons.audio_data.clear_grad()
@ -60,10 +65,11 @@ def test_dac_losses():
loss_1.backward()
loss_1_grad = signal.audio_data.grad.sum()
assert abs(loss_1.item() - loss_origin['stft/loss']
.item()) / loss_1.item() < 1e-5
assert abs(loss_1.item() - loss_origin['stft/loss'].item(
)) / loss_1.item() < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum()
.item()) / loss_1_grad.item() < 1e-5
.item()) / loss_1_grad.item(
) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"
signal.audio_data.clear_grad()
recons.audio_data.clear_grad()
@ -72,7 +78,10 @@ def test_dac_losses():
loss_2.backward()
loss_2_grad = signal.audio_data.grad.sum()
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
loss_2.item()) < 1e-5
assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) /
loss_2_grad.item()) < 1e-5
assert abs(
(loss_2.item() - loss_origin['mel/loss'].item()) / loss_2.
item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
assert abs(
(loss_2_grad.item() - loss_origin['mel/grad'].sum().item()
) / loss_2_grad.
item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"

Loading…
Cancel
Save