|
|
|
@ -1,5 +1,4 @@
|
|
|
|
|
# File under the MIT license, see https://github.com/your_repo/your_license for details.
|
|
|
|
|
# Author: your_name, current_year
|
|
|
|
|
import random
|
|
|
|
|
import sys
|
|
|
|
|
import unittest
|
|
|
|
@ -21,14 +20,9 @@ class _BaseTest(unittest.TestCase):
|
|
|
|
|
delta = 100 * paddle.norm(a - b, p=2) / paddle.norm(b, p=2)
|
|
|
|
|
self.assertLessEqual(delta.numpy(), tol, msg)
|
|
|
|
|
|
|
|
|
|
def compare_paddle(self,
|
|
|
|
|
*args,
|
|
|
|
|
block_ratio=10,
|
|
|
|
|
msg=None,
|
|
|
|
|
tol=TOLERANCE,
|
|
|
|
|
**kwargs):
|
|
|
|
|
def compare_paddle(self, *args, msg=None, tol=TOLERANCE, **kwargs):
|
|
|
|
|
y_ref = F.conv1d(*args, **kwargs)
|
|
|
|
|
y = fft_conv1d(*args, block_ratio=block_ratio, **kwargs)
|
|
|
|
|
y = fft_conv1d(*args, **kwargs)
|
|
|
|
|
self.assertEqual(list(y.shape), list(y_ref.shape), msg)
|
|
|
|
|
self.assertSimilar(y, y_ref, msg, tol)
|
|
|
|
|
|
|
|
|
@ -41,7 +35,6 @@ class TestFFTConv1d(_BaseTest):
|
|
|
|
|
length = random.randrange(kernel_size, 1024)
|
|
|
|
|
chin = random.randrange(1, 12)
|
|
|
|
|
chout = random.randrange(1, 12)
|
|
|
|
|
block_ratio = random.choice([5, 10, 20])
|
|
|
|
|
bias = random.random() < 0.5
|
|
|
|
|
if random.random() < 0.5:
|
|
|
|
|
padding = 0
|
|
|
|
@ -49,9 +42,7 @@ class TestFFTConv1d(_BaseTest):
|
|
|
|
|
padding = random.randrange(kernel_size // 2, 2 * kernel_size)
|
|
|
|
|
x = paddle.randn([batch_size, chin, length])
|
|
|
|
|
w = paddle.randn([chout, chin, kernel_size])
|
|
|
|
|
keys = [
|
|
|
|
|
"length", "kernel_size", "chin", "chout", "block_ratio", "bias"
|
|
|
|
|
]
|
|
|
|
|
keys = ["length", "kernel_size", "chin", "chout", "bias"]
|
|
|
|
|
loc = locals()
|
|
|
|
|
state = {key: loc[key] for key in keys}
|
|
|
|
|
if bias:
|
|
|
|
@ -61,13 +52,7 @@ class TestFFTConv1d(_BaseTest):
|
|
|
|
|
for stride in [1, 2, 5]:
|
|
|
|
|
state["stride"] = stride
|
|
|
|
|
self.compare_paddle(
|
|
|
|
|
x,
|
|
|
|
|
w,
|
|
|
|
|
bias,
|
|
|
|
|
stride,
|
|
|
|
|
padding,
|
|
|
|
|
block_ratio=block_ratio,
|
|
|
|
|
msg=repr(state))
|
|
|
|
|
x, w, bias, stride, padding, msg=repr(state))
|
|
|
|
|
|
|
|
|
|
def test_small_input(self):
|
|
|
|
|
x = paddle.randn([1, 5, 19])
|
|
|
|
@ -79,16 +64,16 @@ class TestFFTConv1d(_BaseTest):
|
|
|
|
|
w = paddle.randn([10, 5, 19])
|
|
|
|
|
self.assertEqual(list(fft_conv1d(x, w).shape), [1, 10, 1])
|
|
|
|
|
|
|
|
|
|
def test_block_ratio(self):
|
|
|
|
|
x = paddle.randn([1, 5, 1024])
|
|
|
|
|
w = paddle.randn([10, 5, 19])
|
|
|
|
|
ref = fft_conv1d(x, w)
|
|
|
|
|
for block_ratio in [1, 5, 10, 20]:
|
|
|
|
|
y = fft_conv1d(x, w, block_ratio=block_ratio)
|
|
|
|
|
self.assertSimilar(y, ref, msg=str(block_ratio))
|
|
|
|
|
# def test_block_ratio(self):
|
|
|
|
|
# x = paddle.randn([1, 5, 1024])
|
|
|
|
|
# w = paddle.randn([10, 5, 19])
|
|
|
|
|
# ref = fft_conv1d(x, w)
|
|
|
|
|
# for block_ratio in [1, 5, 10, 20]:
|
|
|
|
|
# y = fft_conv1d(x, w, block_ratio=block_ratio)
|
|
|
|
|
# self.assertSimilar(y, ref, msg=str(block_ratio))
|
|
|
|
|
|
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
|
|
|
y = fft_conv1d(x, w, block_ratio=0.9)
|
|
|
|
|
# with self.assertRaises(RuntimeError):
|
|
|
|
|
# y = fft_conv1d(x, w, block_ratio=0.9)
|
|
|
|
|
|
|
|
|
|
def test_module(self):
|
|
|
|
|
x = paddle.randn([16, 4, 1024])
|
|
|
|
|