fix codestyle

pull/3988/head
cchenhaifeng 7 months ago
parent fef33fb42a
commit 400686439d

@ -30,14 +30,14 @@ def test_multi_scale_stft_loss():
x, y = get_input() x, y = get_input()
loss = MultiScaleSTFTLoss() loss = MultiScaleSTFTLoss()
pd_loss = loss(x, y) pd_loss = loss(x, y)
np.allclose(pd_loss.numpy(), 7.562150, rtol=1e-06) assert np.abs(pd_loss.numpy() - 7.562150) < 1e-06
def test_sisdr_loss(): def test_sisdr_loss():
x, y = get_input() x, y = get_input()
loss = SISDRLoss() loss = SISDRLoss()
pd_loss = loss(x, y) pd_loss = loss(x, y)
np.allclose(pd_loss.numpy(), -145.377640, rtol=1e-06) assert np.abs(pd_loss.numpy() - (-145.377640)) < 1e-06
def test_gan_loss(): def test_gan_loss():
@ -52,10 +52,10 @@ def test_gan_loss():
x, y = get_input() x, y = get_input()
loss = GANLoss(My_discriminator0()) loss = GANLoss(My_discriminator0())
pd_loss0, pd_loss1 = loss(x, y) pd_loss0, pd_loss1 = loss(x, y)
np.allclose(pd_loss0.numpy(), -0.102722, rtol=1e-06) assert np.abs(pd_loss0.numpy() - (-0.102722)) < 1e-06
np.allclose(pd_loss1.numpy(), -0.001027, rtol=1e-06) assert np.abs(pd_loss1.numpy() - (-0.001027)) < 1e-06
loss = GANLoss(My_discriminator1()) loss = GANLoss(My_discriminator1())
pd_loss0, _ = loss.generator_loss(x, y) pd_loss0, _ = loss.generator_loss(x, y)
np.allclose(pd_loss0.numpy(), 1.000199, rtol=1e-06) assert np.abs(pd_loss0.numpy() - 1.000199) < 1e-06
pd_loss = loss.discriminator_loss(x, y) pd_loss = loss.discriminator_loss(x, y)
np.allclose(pd_loss.numpy(), 1.000200, rtol=1e-06) assert np.abs(pd_loss.numpy() - 1.000200) < 1e-06

Loading…
Cancel
Save