|
|
@ -31,10 +31,13 @@ def test_dac_losses():
|
|
|
|
loss_1.backward()
|
|
|
|
loss_1.backward()
|
|
|
|
loss_1_grad = signal.audio_data.grad.sum()
|
|
|
|
loss_1_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
|
|
assert abs((loss_1.item() - loss_origin['stft/loss'].item()) /
|
|
|
|
assert abs(
|
|
|
|
loss_1.item()) < 1e-5
|
|
|
|
(loss_1.item() - loss_origin['stft/loss'].item()) /
|
|
|
|
assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item())
|
|
|
|
loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
|
|
|
|
/ loss_1_grad.item()) < 1e-5
|
|
|
|
assert abs(
|
|
|
|
|
|
|
|
(loss_1_grad.item() - loss_origin['stft/grad'].sum().item()
|
|
|
|
|
|
|
|
) / loss_1_grad.
|
|
|
|
|
|
|
|
item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"
|
|
|
|
|
|
|
|
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
@ -43,11 +46,13 @@ def test_dac_losses():
|
|
|
|
loss_2.backward()
|
|
|
|
loss_2.backward()
|
|
|
|
loss_2_grad = signal.audio_data.grad.sum()
|
|
|
|
loss_2_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
|
|
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
|
|
|
|
assert abs(
|
|
|
|
loss_2.item()) < 1e-5
|
|
|
|
(loss_2.item() - loss_origin['mel/loss'].item()) / loss_2.
|
|
|
|
|
|
|
|
item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
|
|
|
|
assert abs(
|
|
|
|
assert abs(
|
|
|
|
(signal.audio_data.grad.sum().item() -
|
|
|
|
(signal.audio_data.grad.sum().item() -
|
|
|
|
loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5
|
|
|
|
loss_origin['mel/grad'].sum().item()) / loss_2_grad.
|
|
|
|
|
|
|
|
item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"
|
|
|
|
|
|
|
|
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
@ -60,10 +65,11 @@ def test_dac_losses():
|
|
|
|
loss_1.backward()
|
|
|
|
loss_1.backward()
|
|
|
|
loss_1_grad = signal.audio_data.grad.sum()
|
|
|
|
loss_1_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
|
|
assert abs(loss_1.item() - loss_origin['stft/loss']
|
|
|
|
assert abs(loss_1.item() - loss_origin['stft/loss'].item(
|
|
|
|
.item()) / loss_1.item() < 1e-5
|
|
|
|
)) / loss_1.item() < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
|
|
|
|
assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum()
|
|
|
|
assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum()
|
|
|
|
.item()) / loss_1_grad.item() < 1e-5
|
|
|
|
.item()) / loss_1_grad.item(
|
|
|
|
|
|
|
|
) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"
|
|
|
|
|
|
|
|
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
@ -72,7 +78,10 @@ def test_dac_losses():
|
|
|
|
loss_2.backward()
|
|
|
|
loss_2.backward()
|
|
|
|
loss_2_grad = signal.audio_data.grad.sum()
|
|
|
|
loss_2_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
|
|
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
|
|
|
|
assert abs(
|
|
|
|
loss_2.item()) < 1e-5
|
|
|
|
(loss_2.item() - loss_origin['mel/loss'].item()) / loss_2.
|
|
|
|
assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) /
|
|
|
|
item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
|
|
|
|
loss_2_grad.item()) < 1e-5
|
|
|
|
assert abs(
|
|
|
|
|
|
|
|
(loss_2_grad.item() - loss_origin['mel/grad'].sum().item()
|
|
|
|
|
|
|
|
) / loss_2_grad.
|
|
|
|
|
|
|
|
item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"
|
|
|
|