You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/audio/tests/audiotools/core/test_fftconv.py

86 lines
2.8 KiB

# MIT License, Copyright (c) 2020 Alexandre Défossez.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_fftconv.py)
import random
import sys
import unittest
import paddle
import paddle.nn.functional as F
from audio.audiotools.core import fft_conv1d
from audio.audiotools.core import FFTConv1D
# Default upper bound on the relative L2 error accepted by the assertions
# below, expressed in percent (the error ratio is multiplied by 100 before
# being compared against this value).
TOLERANCE = 1e-4
class _BaseTest(unittest.TestCase):
    """Common fixture and assertion helpers shared by the FFT convolution tests."""

    def setUp(self):
        """Seed Paddle's and Python's random generators before each test."""
        seed = 1234
        paddle.seed(seed)
        random.seed(seed)

    def assertSimilar(self, a, b, msg=None, tol=TOLERANCE):
        """Assert that ``a`` is close to ``b`` in relative L2 norm.

        The error is ``100 * ||a - b||_2 / ||b||_2`` and must not exceed
        ``tol``; ``msg`` is forwarded to the underlying assertion.
        """
        error_norm = paddle.norm(a - b, p=2)
        reference_norm = paddle.norm(b, p=2)
        relative_percent = 100 * error_norm / reference_norm
        self.assertLessEqual(relative_percent.numpy(), tol, msg)

    def compare_paddle(self, *args, msg=None, tol=TOLERANCE, **kwargs):
        """Run ``F.conv1d`` and ``fft_conv1d`` on the same arguments and compare.

        The two outputs must have the same shape and be similar within
        ``tol`` (see ``assertSimilar``). ``msg`` is attached to both checks.
        """
        expected = F.conv1d(*args, **kwargs)
        actual = fft_conv1d(*args, **kwargs)
        self.assertEqual(list(actual.shape), list(expected.shape), msg)
        self.assertSimilar(actual, expected, msg, tol)
class TestFFTConv1d(_BaseTest):
    """Tests for the functional ``fft_conv1d`` and the ``FFTConv1D`` layer."""

    def test_same_as_paddle(self):
        """Compare ``fft_conv1d`` against ``F.conv1d`` on random configurations.

        Five random configurations are drawn, and each one is checked with
        strides 1, 2 and 5. The failure message carries the full
        configuration so that a failing case can be reproduced.
        """
        for _ in range(5):
            # The order of the random draws is significant: with the fixed
            # seed from setUp it determines which configurations are tested.
            kernel_size = random.randrange(4, 128)
            batch_size = random.randrange(1, 6)
            length = random.randrange(kernel_size, 1024)
            chin = random.randrange(1, 12)
            chout = random.randrange(1, 12)
            use_bias = random.random() < 0.5
            if random.random() < 0.5:
                padding = 0
            else:
                padding = random.randrange(kernel_size // 2, 2 * kernel_size)

            x = paddle.randn([batch_size, chin, length])
            w = paddle.randn([chout, chin, kernel_size])
            bias = paddle.randn([chout]) if use_bias else None

            # Explicit description of the configuration for failure messages.
            # It includes batch_size and padding, without which a failing
            # case could not be rebuilt from the message alone.
            state = {
                "length": length,
                "kernel_size": kernel_size,
                "chin": chin,
                "chout": chout,
                "bias": use_bias,
                "batch_size": batch_size,
                "padding": padding,
            }
            for stride in [1, 2, 5]:
                state["stride"] = stride
                self.compare_paddle(
                    x, w, bias, stride, padding, msg=repr(state))

    def test_small_input(self):
        """Check behaviour when the input is no longer than the kernel."""
        # Kernel (32) longer than the input (19): must raise RuntimeError.
        x = paddle.randn([1, 5, 19])
        w = paddle.randn([10, 5, 32])
        with self.assertRaises(RuntimeError):
            fft_conv1d(x, w)

        # Kernel exactly as long as the input: a single output sample.
        x = paddle.randn([1, 5, 19])
        w = paddle.randn([10, 5, 19])
        self.assertEqual(list(fft_conv1d(x, w).shape), [1, 10, 1])

    def test_module(self):
        """Run ``FFTConv1D`` with and without bias and check the output shape."""
        x = paddle.randn([16, 4, 1024])
        expected_shape = [16, 5, 1024 - 8 + 1]

        mod = FFTConv1D(4, 5, 8, bias_attr=True)
        self.assertEqual(list(mod(x).shape), expected_shape)

        mod = FFTConv1D(4, 5, 8, bias_attr=False)
        self.assertEqual(list(mod(x).shape), expected_shape)

    def test_dynamic_graph(self):
        """Check the output shape of a forward pass through ``FFTConv1D``."""
        x = paddle.randn([16, 4, 1024])
        mod = FFTConv1D(4, 5, 8, bias_attr=True)
        # Output length is input length minus kernel size plus one.
        self.assertEqual(list(mod(x).shape), [16, 5, 1024 - 8 + 1])
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()