diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py index 4948b2065..26b1b20c2 100644 --- a/tests/unit/tts/test_losses.py +++ b/tests/unit/tts/test_losses.py @@ -31,10 +31,13 @@ def test_dac_losses(): loss_1.backward() loss_1_grad = signal.audio_data.grad.sum() - assert abs((loss_1.item() - loss_origin['stft/loss'].item()) / - loss_1.item()) < 1e-5 - assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item()) - / loss_1_grad.item()) < 1e-5 + assert abs( + (loss_1.item() - loss_origin['stft/loss'].item()) / + loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'" + assert abs( + (loss_1_grad.item() - loss_origin['stft/grad'].sum().item() + ) / loss_1_grad. + item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'" signal.audio_data.clear_grad() recons.audio_data.clear_grad() @@ -43,11 +46,13 @@ def test_dac_losses(): loss_2.backward() loss_2_grad = signal.audio_data.grad.sum() - assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / - loss_2.item()) < 1e-5 + assert abs( + (loss_2.item() - loss_origin['mel/loss'].item()) / loss_2. + item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'" assert abs( (signal.audio_data.grad.sum().item() - - loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5 + loss_origin['mel/grad'].sum().item()) / loss_2_grad. + item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'" signal.audio_data.clear_grad() recons.audio_data.clear_grad() @@ -60,10 +65,11 @@ def test_dac_losses(): loss_1.backward() loss_1_grad = signal.audio_data.grad.sum() - assert abs(loss_1.item() - loss_origin['stft/loss'] - .item()) / loss_1.item() < 1e-5 + assert abs(loss_1.item() - loss_origin['stft/loss'].item( + )) / loss_1.item() < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'" assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum() - .item()) / loss_1_grad.item() < 1e-5 + .item()) / loss_1_grad.item( + ) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'" signal.audio_data.clear_grad() recons.audio_data.clear_grad() @@ -72,7 +78,10 @@ def test_dac_losses(): loss_2.backward() loss_2_grad = signal.audio_data.grad.sum() - assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / - loss_2.item()) < 1e-5 - assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) / - loss_2_grad.item()) < 1e-5 + assert abs( + (loss_2.item() - loss_origin['mel/loss'].item()) / loss_2. + item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'" + assert abs( + (loss_2_grad.item() - loss_origin['mel/grad'].sum().item() + ) / loss_2_grad. + item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"