|
|
|
@ -10,6 +10,10 @@ def test_dac_losses():
|
|
|
|
|
loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
|
|
|
|
|
recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
|
|
|
|
|
signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')
|
|
|
|
|
|
|
|
|
|
recons.audio_data.stop_gradient = False
|
|
|
|
|
signal.audio_data.stop_gradient = False
|
|
|
|
|
|
|
|
|
|
loss_fn_1 = MultiScaleSTFTLoss()
|
|
|
|
|
loss_fn_2 = MultiMelSpectrogramLoss(
|
|
|
|
|
n_mels=[5, 10, 20, 40, 80, 160, 320],
|
|
|
|
@ -18,22 +22,57 @@ def test_dac_losses():
|
|
|
|
|
pow=1.0,
|
|
|
|
|
mel_fmin=[0, 0, 0, 0, 0, 0, 0],
|
|
|
|
|
mel_fmax=[None, None, None, None, None, None, None])
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# Test AudioSignal
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
loss_1 = loss_fn_1(recons, signal)
|
|
|
|
|
loss_1.backward()
|
|
|
|
|
loss_1_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
assert abs((loss_1.item() - loss_origin['stft/loss'].item()) /
|
|
|
|
|
loss_1.item()) < 1e-5
|
|
|
|
|
assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item())
|
|
|
|
|
/ loss_1_grad.item()) < 1e-5
|
|
|
|
|
|
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
|
|
|
|
|
|
|
loss_2 = loss_fn_2(recons, signal)
|
|
|
|
|
loss_2.backward()
|
|
|
|
|
loss_2_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
|
|
|
|
|
loss_2.item()) < 1e-5
|
|
|
|
|
assert abs(
|
|
|
|
|
loss_fn_1(recons, signal).item() - loss_origin['stft/loss']
|
|
|
|
|
.item()) < 1e-5
|
|
|
|
|
assert abs(
|
|
|
|
|
loss_fn_2(recons, signal).item() - loss_origin['mel/loss']
|
|
|
|
|
.item()) < 1e-5
|
|
|
|
|
(signal.audio_data.grad.sum().item() -
|
|
|
|
|
loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5
|
|
|
|
|
|
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# Test Tensor
|
|
|
|
|
#
|
|
|
|
|
assert abs(
|
|
|
|
|
loss_fn_1(recons.audio_data, signal.audio_data).item() -
|
|
|
|
|
loss_origin['stft/loss'].item()) < 1e-3
|
|
|
|
|
assert abs(
|
|
|
|
|
loss_fn_2(recons.audio_data, signal.audio_data).item() -
|
|
|
|
|
loss_origin['mel/loss'].item()) < 1e-3
|
|
|
|
|
|
|
|
|
|
loss_1 = loss_fn_1(recons.audio_data, signal.audio_data)
|
|
|
|
|
loss_1.backward()
|
|
|
|
|
loss_1_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
assert abs(loss_1.item() - loss_origin['stft/loss']
|
|
|
|
|
.item()) / loss_1.item() < 1e-5
|
|
|
|
|
assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum()
|
|
|
|
|
.item()) / loss_1_grad.item() < 1e-5
|
|
|
|
|
|
|
|
|
|
signal.audio_data.clear_grad()
|
|
|
|
|
recons.audio_data.clear_grad()
|
|
|
|
|
|
|
|
|
|
loss_2 = loss_fn_2(recons.audio_data, signal.audio_data)
|
|
|
|
|
loss_2.backward()
|
|
|
|
|
loss_2_grad = signal.audio_data.grad.sum()
|
|
|
|
|
|
|
|
|
|
assert abs((loss_2.item() - loss_origin['mel/loss'].item()) /
|
|
|
|
|
loss_2.item()) < 1e-5
|
|
|
|
|
assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) /
|
|
|
|
|
loss_2_grad.item()) < 1e-5
|
|
|
|
|