|
|
@ -92,14 +92,14 @@ class TestFFTConv1d(_BaseTest):
|
|
|
|
|
|
|
|
|
|
|
|
def test_module(self):
|
|
|
|
def test_module(self):
|
|
|
|
x = paddle.randn([16, 4, 1024])
|
|
|
|
x = paddle.randn([16, 4, 1024])
|
|
|
|
mod = FFTConv1D(4, 5, 8, bias=True)
|
|
|
|
mod = FFTConv1D(4, 5, 8, bias_attr=True)
|
|
|
|
mod(x)
|
|
|
|
mod(x)
|
|
|
|
mod = FFTConv1D(4, 5, 8, bias=False)
|
|
|
|
mod = FFTConv1D(4, 5, 8, bias_attr=False)
|
|
|
|
mod(x)
|
|
|
|
mod(x)
|
|
|
|
|
|
|
|
|
|
|
|
def test_dynamic_graph(self):
|
|
|
|
def test_dynamic_graph(self):
|
|
|
|
x = paddle.randn([16, 4, 1024])
|
|
|
|
x = paddle.randn([16, 4, 1024])
|
|
|
|
mod = FFTConv1D(4, 5, 8, bias=True)
|
|
|
|
mod = FFTConv1D(4, 5, 8, bias_attr=True)
|
|
|
|
self.assertEqual(list(mod(x).shape), [16, 5, 1024 - 8 + 1])
|
|
|
|
self.assertEqual(list(mod(x).shape), [16, 5, 1024 - 8 + 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|