tests/unit/tts/test_losses.py: Add error message on assert failed

pull/3954/head
suzakuwcx 8 months ago
parent c1a8f996f3
commit fd5365c5b6
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -31,10 +31,13 @@ def test_dac_losses():
loss_1.backward() loss_1.backward()
loss_1_grad = signal.audio_data.grad.sum() loss_1_grad = signal.audio_data.grad.sum()
assert abs((loss_1.item() - loss_origin['stft/loss'].item()) / assert abs(
loss_1.item()) < 1e-5 (loss_1.item() - loss_origin['stft/loss'].item()) /
assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item()) loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
/ loss_1_grad.item()) < 1e-5 assert abs(
(loss_1_grad.item() - loss_origin['stft/grad'].sum().item()
) / loss_1_grad.
item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"
signal.audio_data.clear_grad() signal.audio_data.clear_grad()
recons.audio_data.clear_grad() recons.audio_data.clear_grad()
@ -43,11 +46,13 @@ def test_dac_losses():
loss_2.backward() loss_2.backward()
loss_2_grad = signal.audio_data.grad.sum() loss_2_grad = signal.audio_data.grad.sum()
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / assert abs(
loss_2.item()) < 1e-5 (loss_2.item() - loss_origin['mel/loss'].item()) / loss_2.
item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
assert abs( assert abs(
(signal.audio_data.grad.sum().item() - (signal.audio_data.grad.sum().item() -
loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5 loss_origin['mel/grad'].sum().item()) / loss_2_grad.
item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"
signal.audio_data.clear_grad() signal.audio_data.clear_grad()
recons.audio_data.clear_grad() recons.audio_data.clear_grad()
@ -60,10 +65,11 @@ def test_dac_losses():
loss_1.backward() loss_1.backward()
loss_1_grad = signal.audio_data.grad.sum() loss_1_grad = signal.audio_data.grad.sum()
assert abs(loss_1.item() - loss_origin['stft/loss'] assert abs(loss_1.item() - loss_origin['stft/loss'].item(
.item()) / loss_1.item() < 1e-5 )) / loss_1.item() < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum() assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum()
.item()) / loss_1_grad.item() < 1e-5 .item()) / loss_1_grad.item(
) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"
signal.audio_data.clear_grad() signal.audio_data.clear_grad()
recons.audio_data.clear_grad() recons.audio_data.clear_grad()
@ -72,7 +78,10 @@ def test_dac_losses():
loss_2.backward() loss_2.backward()
loss_2_grad = signal.audio_data.grad.sum() loss_2_grad = signal.audio_data.grad.sum()
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / assert abs(
loss_2.item()) < 1e-5 (loss_2.item() - loss_origin['mel/loss'].item()) / loss_2.
assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) / item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
loss_2_grad.item()) < 1e-5 assert abs(
(loss_2_grad.item() - loss_origin['mel/grad'].sum().item()
) / loss_2_grad.
item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"

Loading…
Cancel
Save