pull/3947/head
drryanhuang 9 months ago
parent 136f7c5c4c
commit f3c00632c7

@@ -20,35 +20,37 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+from ...utils import satisfy_paddle_version

 __all__ = [
     "fft_conv1d",
     "FFTConv1D",
 ]


-def __unfold(_input, kernel_size: int, stride: int):
+def __unfold(x, kernel_size: int, stride: int):
     """1D only unfolding similar to the one from PaddlePaddle.

     Notes
     ------
-    Given an _input tensor of size `[*, T]` this will return
+    Given a tensor `x` of size `[*, T]` this will return
     a tensor `[*, F, K]` with `K` the kernel size, and `F` the number
     of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`.
-    This will automatically pad the _input to cover at least once all entries in `_input`.
+    This will automatically pad `x` to cover at least once all entries in `x`.

     Args:
-        _input (Tensor):
+        x (Tensor):
             tensor for which to return the frames.
         kernel_size (int):
             size of each frame.
         stride (int):
             stride between each frame.
     """
-    shape = list(_input.shape)
+    shape = list(x.shape)
     length = shape.pop(-1)
     n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
     tgt_length = (n_frames - 1) * stride + kernel_size
-    padded = F.pad(_input, (0, tgt_length - length), data_format="NCL")
+    padded = F.pad(x, (0, tgt_length - length), data_format="NCL")
     strides: typing.List[int] = []
     for dim in range(padded.dim()):
         strides.append(padded.strides[dim])
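The framing above builds the `[*, F, K]` views with `Tensor.strides` and `paddle.as_strided` (Paddle >= 2.6). A minimal sketch of the same trick on toy values, assuming the Paddle 2.6 `as_strided(x, shape, stride)` signature with element-unit strides:

```python
import paddle

# Sketch only: frame a [B, C, T] signal into overlapping windows via strides.
x = paddle.arange(10, dtype="float32").reshape([1, 1, 10])  # [B, C, T]
kernel_size, stride = 4, 2
n_frames = (10 - kernel_size) // stride + 1  # 4 frames; no padding needed here

frames = paddle.as_strided(
    x,
    shape=[1, 1, n_frames, kernel_size],
    stride=list(x.strides[:-1]) + [x.strides[-1] * stride, x.strides[-1]])
print(frames)  # frame i is a view onto x[..., i*2 : i*2+4], no copy made
```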
@@ -58,7 +60,7 @@ def __unfold(_input, kernel_size: int, stride: int):


 def fft_conv1d(
-        _input: paddle.Tensor,
+        x: paddle.Tensor,
         weight: paddle.Tensor,
         bias: Optional[paddle.Tensor]=None,
         stride: int=1,
@@ -77,8 +79,8 @@ def fft_conv1d(
     more memory than the default Conv1d implementation.

     Args:
-        _input (Tensor):
-            _input signal of shape `[B, C, T]`.
+        x (Tensor):
+            input signal of shape `[B, C, T]`.
         weight (Tensor):
             weight of the convolution `[D, C, K]` with `D` the number of output channels.
         bias (Tensor or None):
@@ -86,17 +88,17 @@ def fft_conv1d(
         stride (int):
             stride of convolution.
         padding (int):
-            padding to apply to the _input.
+            padding to apply to `x`.
         block_ratio (float):
-            can be tuned for speed. The _input is splitted in chunks with a size of `int(block_ratio * kernel_size)`.
+            can be tuned for speed. `x` is split into chunks of size `int(block_ratio * kernel_size)`.

     Shape:
-        - Inputs: `_input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
+        - Inputs: `x` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
         - Output: `(*, T)`
     """
-    _input = F.pad(_input, (padding, padding), data_format="NCL")
-    batch, _, length = _input.shape
+    x = F.pad(x, (padding, padding), data_format="NCL")
+    batch, _, length = x.shape
     out_channels, _, kernel_size = weight.shape

     if length < kernel_size:
@@ -118,8 +120,8 @@ def fft_conv1d(
     weight_z = paddle.fft.rfft(weight, axis=-1)

-    # We pad the _input and get the different frames, on which
-    frames = __unfold(_input, block_size, fold_stride)
+    # We pad `x` and get the different frames, on which
+    frames = __unfold(x, block_size, fold_stride)
     frames_z = paddle.fft.rfft(frames, axis=-1)

     weight_z_coml = paddle.conj(weight_z)
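The `paddle.conj(weight_z)` multiply is the key identity here: pointwise `X(f) * conj(W(f))` in the frequency domain equals circular cross-correlation in time, which is the operation Conv1D actually computes. A minimal NumPy sketch of that identity (illustrative sizes, not the PR's code):

```python
import numpy as np

# Circular cross-correlation via rfft: x (length 32) against kernel k (length 8).
x = np.random.randn(32)
k = np.random.randn(8)
k_padded = np.zeros(32)
k_padded[:8] = k

corr_fft = np.fft.irfft(np.fft.rfft(x) * np.conj(np.fft.rfft(k_padded)), n=32)
# The first 32 - 8 + 1 = 25 outputs are wrap-free and match direct correlation.
assert np.allclose(corr_fft[:25], np.correlate(x, k, mode="valid"))
```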
@@ -152,7 +154,7 @@ class FFTConv1D(paddle.nn.Layer):
     Args:
         in_channels (int):
-            number of _input channels.
+            number of input channels.
         out_channels (int):
             number of output channels.
         kernel_size (int):
@@ -160,7 +162,7 @@ class FFTConv1D(paddle.nn.Layer):
         stride (int):
             stride of convolution.
         padding (int):
-            padding to apply to the _input.
+            padding to apply to the input.
         bias_attr (bool):
             if True, use a bias term.
@@ -200,16 +202,13 @@ class FFTConv1D(paddle.nn.Layer):
         else:
             self.bias = None

-    def forward(self, _input: paddle.Tensor):
-        return fft_conv1d(_input, self.weight, self.bias, self.stride,
-                          self.padding)
+    def forward(self, x: paddle.Tensor):
+        return fft_conv1d(x, self.weight, self.bias, self.stride, self.padding)


 # Currently, the unfold API in Paddle is extremely slow, so __unfold is implemented
 # using the `.strides` and `.as_strided` APIs. However, these are only supported in
 # Paddle 2.6 and above, so F.conv1d and Conv1D are used as replacements on older versions.
-version = paddle.__version__
-if version < '2.6':
+if not satisfy_paddle_version('2.6'):
     fft_conv1d = F.conv1d
     FFTConv1D = nn.Conv1D
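Swapping the lexical check for `satisfy_paddle_version` also fixes a latent bug: as strings, `'2.10' < '2.6'` is True, so a future Paddle 2.10 would wrongly take the fallback path. The real helper lives in `...utils`; purely as an assumption about its semantics, a sketch might look like:

```python
import paddle

def satisfy_paddle_version_sketch(min_version: str) -> bool:
    # Hypothetical stand-in for ...utils.satisfy_paddle_version: compare
    # version components numerically instead of lexically.
    current = tuple(int(p) for p in paddle.__version__.split(".")[:2])
    # Note: dev builds report '0.0.0' and would take the fallback path.
    return current >= tuple(int(p) for p in min_version.split(".")[:2])
```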

@@ -59,7 +59,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))

     def test_fft_conv1d_vs_conv1d_no_padding(self):
         x, conv1d, fft_conv1d = self._init_models(
@@ -68,7 +68,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))

     def test_fft_conv1d_vs_conv1d_large_kernel(self):
         kernel_size = 256
@@ -79,7 +79,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))

     def test_fft_conv1d_vs_conv1d_stride_2(self):
         x, conv1d, fft_conv1d = self._init_models(
@@ -88,7 +88,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))

     def test_fft_conv1d_vs_conv1d_different_input_length(self):
         input_length = 1024
@@ -99,7 +99,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))

     def test_fft_conv1d_vs_conv1d_no_bias(self):
         conv1d = paddle.nn.Conv1D(
@@ -121,7 +121,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))


 if __name__ == '__main__':
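For reference, a standalone equivalence check in the spirit of these tests, using the functional `fft_conv1d` from the implementation above and the tightened `atol=1e-6` (shapes are illustrative):

```python
import numpy as np
import paddle

# Share weights between a stock Conv1D and the PR's functional fft_conv1d.
conv = paddle.nn.Conv1D(8, 16, 32, stride=1, padding=16)
x = paddle.randn([4, 8, 1024])
ref = conv(x)
out = fft_conv1d(x, conv.weight, conv.bias, stride=1, padding=16)
print(np.allclose(ref.numpy(), out.numpy(), atol=1e-6))  # expect True
```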
