From f3c00632c736a5acabf501c758f314555895fe77 Mon Sep 17 00:00:00 2001
From: drryanhuang
Date: Wed, 18 Dec 2024 11:19:15 +0000
Subject: [PATCH] fix sth

---
 paddlespeech/t2s/modules/fftconv1d.py | 47 +++++++++++++--------------
 tests/unit/tts/test_fftconv1d.py      | 12 +++----
 2 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/paddlespeech/t2s/modules/fftconv1d.py b/paddlespeech/t2s/modules/fftconv1d.py
index 83877ffad..cbdb84bda 100644
--- a/paddlespeech/t2s/modules/fftconv1d.py
+++ b/paddlespeech/t2s/modules/fftconv1d.py
@@ -20,35 +20,37 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
+from ...utils import satisfy_paddle_version
+
 __all__ = [
     "fft_conv1d",
     "FFTConv1D",
 ]
 
 
-def __unfold(_input, kernel_size: int, stride: int):
+def __unfold(x, kernel_size: int, stride: int):
     """1D only unfolding similar to the one from Paddlepaddle.
 
     Notes
     ------
-    Given an _input tensor of size `[*, T]` this will return
+    Given a tensor `x` of size `[*, T]` this will return
     a tensor `[*, F, K]` with `K` the kernel size, and `F` the number of frames.
     The i-th frame is a view onto `i * stride: i * stride + kernel_size`.
-    This will automatically pad the _input to cover at least once all entries in `_input`.
+    This will automatically pad `x` so that every entry of `x` is covered by at least one frame.
 
     Args:
-        _input (Tensor):
+        x (Tensor):
             tensor for which to return the frames.
         kernel_size (int):
             size of each frame.
         stride (int):
             stride between each frame.
     """
-    shape = list(_input.shape)
+    shape = list(x.shape)
     length = shape.pop(-1)
     n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
     tgt_length = (n_frames - 1) * stride + kernel_size
-    padded = F.pad(_input, (0, tgt_length - length), data_format="NCL")
+    padded = F.pad(x, (0, tgt_length - length), data_format="NCL")
     strides: typing.List[int] = []
     for dim in range(padded.dim()):
         strides.append(padded.strides[dim])
@@ -58,7 +60,7 @@
 
 
 def fft_conv1d(
-        _input: paddle.Tensor,
+        x: paddle.Tensor,
         weight: paddle.Tensor,
         bias: Optional[paddle.Tensor]=None,
         stride: int=1,
@@ -77,8 +79,8 @@
     more memory than the default Conv1d implementation.
 
     Args:
-        _input (Tensor):
-            _input signal of shape `[B, C, T]`.
+        x (Tensor):
+            input signal of shape `[B, C, T]`.
         weight (Tensor):
             weight of the convolution `[D, C, K]` with `D` the number of output channels.
         bias (Tensor or None):
@@ -86,17 +88,17 @@
         stride (int):
             stride of convolution.
         padding (int):
-            padding to apply to the _input.
+            padding to apply to `x`.
         block_ratio (float):
-            can be tuned for speed. The _input is splitted in chunks with a size of `int(block_ratio * kernel_size)`.
+            can be tuned for speed. `x` is split into chunks of size `int(block_ratio * kernel_size)`.
 
     Shape:
-        - Inputs: `_input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
+        - Inputs: `x` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
         - Output: `(*, T)`
 
     """
-    _input = F.pad(_input, (padding, padding), data_format="NCL")
-    batch, _, length = _input.shape
+    x = F.pad(x, (padding, padding), data_format="NCL")
+    batch, _, length = x.shape
     out_channels, _, kernel_size = weight.shape
 
     if length < kernel_size:
@@ -118,8 +120,8 @@
 
     weight_z = paddle.fft.rfft(weight, axis=-1)
 
-    # We pad the _input and get the different frames, on which
-    frames = __unfold(_input, block_size, fold_stride)
+    # We pad `x` and split it into overlapping frames, on which the FFT is applied.
+    frames = __unfold(x, block_size, fold_stride)
 
     frames_z = paddle.fft.rfft(frames, axis=-1)
     weight_z_coml = paddle.conj(weight_z)
@@ -152,7 +154,7 @@
 
     Args:
         in_channels (int):
-            number of _input channels.
+            number of input channels.
         out_channels (int):
             number of output channels.
         kernel_size (int):
@@ -160,7 +162,7 @@
         stride (int):
             stride of convolution.
         padding (int):
-            padding to apply to the _input.
+            padding to apply to `x`.
         bias_attr (bool):
             if True, use a bias term.
 
@@ -200,16 +202,13 @@
         else:
             self.bias = None
 
-    def forward(self, _input: paddle.Tensor):
-        return fft_conv1d(_input, self.weight, self.bias, self.stride,
-                          self.padding)
+    def forward(self, x: paddle.Tensor):
+        return fft_conv1d(x, self.weight, self.bias, self.stride, self.padding)
 
 
 # Currently, the API unfold in Paddle is extremely slow, so __unfold is implemented
 # using the `.strides` and `.as_strided` APIs. However, these are only supported in
 # Paddle version 2.6 and above, so F.conv1d and Conv1D are used as replacements.
-version = paddle.__version__
-
-if version < '2.6':
+if not satisfy_paddle_version('2.6'):
     fft_conv1d = F.conv1d
     FFTConv1D = nn.Conv1D
diff --git a/tests/unit/tts/test_fftconv1d.py b/tests/unit/tts/test_fftconv1d.py
index 518600142..88ea397ec 100644
--- a/tests/unit/tts/test_fftconv1d.py
+++ b/tests/unit/tts/test_fftconv1d.py
@@ -59,7 +59,7 @@ class TestFFTConv1D(unittest.TestCase):
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
 
     def test_fft_conv1d_vs_conv1d_no_padding(self):
         x, conv1d, fft_conv1d = self._init_models(
@@ -68,7 +68,7 @@
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
 
     def test_fft_conv1d_vs_conv1d_large_kernel(self):
         kernel_size = 256
@@ -79,7 +79,7 @@
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
 
     def test_fft_conv1d_vs_conv1d_stride_2(self):
         x, conv1d, fft_conv1d = self._init_models(
@@ -88,7 +88,7 @@
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
 
     def test_fft_conv1d_vs_conv1d_different_input_length(self):
         input_length = 1024
@@ -99,7 +99,7 @@
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
 
     def test_fft_conv1d_vs_conv1d_no_bias(self):
         conv1d = paddle.nn.Conv1D(
@@ -121,7 +121,7 @@
         out_conv1d = conv1d(x)
         out_fft_conv1d = fft_conv1d(x)
         self.assertTrue(
-            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-5))
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
 
 
 if __name__ == '__main__':
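
As a quick local sanity check of the equivalence the tightened tests assert,
something along these lines should work (a minimal sketch, not taken from the
patch; the batch size, channel counts, kernel size, and seed are illustrative):

    import numpy as np
    import paddle

    from paddlespeech.t2s.modules.fftconv1d import FFTConv1D

    paddle.seed(0)
    x = paddle.randn([2, 4, 1024])  # [B, C, T]

    # Reference layer and the FFT-based layer under test.
    conv = paddle.nn.Conv1D(4, 8, kernel_size=16, stride=1, padding=8)
    fft_conv = FFTConv1D(4, 8, kernel_size=16, stride=1, padding=8)

    # Copy the parameters so both layers compute the same convolution
    # (assumes the default constructor creates a bias, as the unit tests do).
    fft_conv.weight.set_value(conv.weight.numpy())
    if fft_conv.bias is not None:
        fft_conv.bias.set_value(conv.bias.numpy())

    # atol=1e-6 mirrors the tolerance the tests are tightened to.
    assert np.allclose(conv(x).numpy(), fft_conv(x).numpy(), atol=1e-6)

Note that on Paddle < 2.6 the satisfy_paddle_version guard rebinds FFTConv1D
to nn.Conv1D, so the comparison is trivially exact there; the FFT overlap-add
path is only exercised on Paddle 2.6 and above.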