diff --git a/paddlespeech/t2s/modules/__init__.py b/paddlespeech/t2s/modules/__init__.py
index 1e3312002..88e74db7a 100644
--- a/paddlespeech/t2s/modules/__init__.py
+++ b/paddlespeech/t2s/modules/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .conv import *
+from .fftconv1d import *
 from .geometry import *
 from .losses import *
 from .positional_encoding import *
diff --git a/paddlespeech/t2s/modules/fftconv1d.py b/paddlespeech/t2s/modules/fftconv1d.py
new file mode 100644
index 000000000..cbdb84bda
--- /dev/null
+++ b/paddlespeech/t2s/modules/fftconv1d.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import typing
+from typing import Optional
+from typing import Sequence
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ...utils import satisfy_paddle_version
+
+__all__ = [
+    "fft_conv1d",
+    "FFTConv1D",
+]
+
+
+def __unfold(x, kernel_size: int, stride: int):
+    """1D-only unfolding, similar to PaddlePaddle's `unfold`.
+
+    Notes
+    ------
+    Given a tensor `x` of size `[*, T]` this will return
+    a tensor `[*, F, K]` with `K` the kernel size, and `F` the number
+    of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`.
+    This will automatically pad `x` so that every entry is covered by at least one frame.
+
+    Args:
+        x (Tensor):
+            tensor for which to return the frames.
+        kernel_size (int):
+            size of each frame.
+        stride (int):
+            stride between each frame.
+    """
+    shape = list(x.shape)
+    length = shape.pop(-1)
+    n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
+    tgt_length = (n_frames - 1) * stride + kernel_size
+    padded = F.pad(x, (0, tgt_length - length), data_format="NCL")
+    strides: typing.List[int] = []
+    for dim in range(padded.dim()):
+        strides.append(padded.strides[dim])
+    assert strides.pop(-1) == 1, "data should be contiguous"
+    strides = strides + [stride, 1]
+    return padded.as_strided(shape + [n_frames, kernel_size], strides)
+
+
+def fft_conv1d(
+        x: paddle.Tensor,
+        weight: paddle.Tensor,
+        bias: Optional[paddle.Tensor]=None,
+        stride: int=1,
+        padding: int=0,
+        block_ratio: float=5, ):
+    """
+    Same as `paddle.nn.functional.conv1d` but using FFT for the convolution.
+    Please check the PaddlePaddle documentation for more information.
+
+    Notes
+    ------
+    This function is faster than `paddle.nn.functional.conv1d` only in specific cases.
+    Typically, the kernel size should be of the order of 256 to see any real gain,
+    for a stride of 1.
+    Dilation and groups are not supported at the moment. This function might use
+    more memory than the default Conv1D implementation.
+
+    Args:
+        x (Tensor):
+            input signal of shape `[B, C, T]`.
+        weight (Tensor):
+            weight of the convolution `[D, C, K]` with `D` the number of output channels.
+        bias (Tensor or None):
+            if not None, bias term for the convolution.
+        stride (int):
+            stride of the convolution.
+        padding (int):
+            padding to apply to `x`.
+        block_ratio (float):
+            can be tuned for speed. `x` is split into chunks of size
+            `int(block_ratio * kernel_size)`.
+
+    Shape:
+
+        - Inputs: `x` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
+        - Output: `[B, D, T_out]` with `T_out = (T + 2 * padding - K) // stride + 1`.
+    """
+    x = F.pad(x, (padding, padding), data_format="NCL")
+    batch, _, length = x.shape
+    out_channels, _, kernel_size = weight.shape
+
+    if length < kernel_size:
+        raise RuntimeError(
+            f"Input should be at least as large as the kernel size {kernel_size}, "
+            f"but it is only {length} samples long.")
+    if block_ratio < 1:
+        raise RuntimeError("Block ratio must be at least 1.")
+
+    block_size: int = min(int(kernel_size * block_ratio), length)
+    fold_stride = block_size - kernel_size + 1
+
+    # Pad the kernel to the block size before taking its FFT.
+    weight = F.pad(
+        weight, (0, block_size - weight.shape[-1]),
+        mode="constant",
+        value=0.0,
+        data_format="NCL")
+
+    weight_z = paddle.fft.rfft(weight, axis=-1)
+
+    # Pad `x` and extract the overlapping frames on which the block FFTs are computed.
+    frames = __unfold(x, block_size, fold_stride)
+
+    frames_z = paddle.fft.rfft(frames, axis=-1)
+    weight_z_conj = paddle.conj(weight_z)
+    out_z = paddle.einsum("bcft,dct->bdft", frames_z, weight_z_conj)
+    out = paddle.fft.irfft(out_z, n=block_size, axis=-1)
+
+    # The last bit is invalid, because the FFT computes a circular convolution.
+    out = out[..., :-kernel_size + 1]
+    out = out.reshape([batch, out_channels, -1])
+    out = out[..., ::stride]
+    target_length = (length - kernel_size) // stride + 1
+    out = out[..., :target_length]
+    if bias is not None:
+        out += bias[:, None]
+    return out
+
+
+class FFTConv1D(paddle.nn.Layer):
+    """
+    Same as `paddle.nn.Conv1D` but based on a custom FFT-based convolution.
+    Please check the PaddlePaddle documentation for more information on `paddle.nn.Conv1D`.
+
+    Notes
+    ------
+    This module is faster than `paddle.nn.Conv1D` only in specific cases.
+    Typically, `kernel_size` should be of the order of 256 to see any real gain,
+    for a stride of 1.
+    Dilation and groups are not supported at the moment. This module might use
+    more memory than the default Conv1D implementation.
+
+    Args:
+        in_channels (int):
+            number of input channels.
+        out_channels (int):
+            number of output channels.
+        kernel_size (int):
+            kernel size of the convolution.
+        stride (int):
+            stride of the convolution.
+        padding (int):
+            padding to apply to the input.
+        bias_attr (bool):
+            if True, use a bias term.
+
+    Examples:
+        >>> fftconv = FFTConv1D(12, 24, 128, 4)
+        >>> x = paddle.randn([4, 12, 1024])
+        >>> print(list(fftconv(x).shape))
+        [4, 24, 225]
+    """
+
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int,
+            stride: int=1,
+            padding: int=0,
+            bias_attr: bool=True, ):
+        super(FFTConv1D, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+
+        # Create a Conv1D layer only to initialize the weight and bias parameters.
+        conv = paddle.nn.Conv1D(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias_attr=bias_attr)
+        self.weight = conv.weight
+        if bias_attr:
+            self.bias = conv.bias
+        else:
+            self.bias = None
+
+    def forward(self, x: paddle.Tensor):
+        return fft_conv1d(x, self.weight, self.bias, self.stride, self.padding)
+
+
+# Currently, Paddle's unfold API is extremely slow, so __unfold is implemented with
+# the `.strides` and `.as_strided` APIs. However, these are only supported in Paddle
+# 2.6 and above, so F.conv1d and nn.Conv1D are used as replacements on older versions.
+if not satisfy_paddle_version('2.6'):
+    fft_conv1d = F.conv1d
+    FFTConv1D = nn.Conv1D
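The identity that fft_conv1d relies on (cross-correlation of each frame with the kernel, done by multiplying the frame spectrum with the conjugate kernel spectrum and discarding the circular tail) can be checked in isolation. The following minimal sketch is not part of the patch; it only uses NumPy to mirror what happens inside one block:

    import numpy as np

    rng = np.random.default_rng(0)
    kernel_size, block_size = 5, 25
    frame = rng.standard_normal(block_size)
    kernel = rng.standard_normal(kernel_size)

    # Direct "valid" cross-correlation, i.e. what Conv1D computes per frame.
    direct = np.array([
        np.dot(frame[i:i + kernel_size], kernel)
        for i in range(block_size - kernel_size + 1)
    ])

    # FFT route: pad the kernel to the block size, multiply the frame spectrum
    # with the conjugate kernel spectrum, transform back, drop the circular tail.
    kernel_padded = np.pad(kernel, (0, block_size - kernel_size))
    fft_out = np.fft.irfft(
        np.fft.rfft(frame) * np.conj(np.fft.rfft(kernel_padded)), n=block_size)
    fft_out = fft_out[:block_size - kernel_size + 1]

    assert np.allclose(direct, fft_out)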
diff --git a/tests/unit/ci.sh b/tests/unit/ci.sh
index 72b4678d6..daf40f721 100644
--- a/tests/unit/ci.sh
+++ b/tests/unit/ci.sh
@@ -14,6 +14,7 @@ function main(){
     cd ${speech_ci_path}/tts
     python test_data_table.py
     python test_enfrontend.py
+    python test_fftconv1d.py
     python test_mixfrontend.py
     echo "End TTS"
diff --git a/tests/unit/tts/test_fftconv1d.py b/tests/unit/tts/test_fftconv1d.py
new file mode 100644
index 000000000..88ea397ec
--- /dev/null
+++ b/tests/unit/tts/test_fftconv1d.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import numpy as np
+import paddle
+from paddle.nn import Conv1D
+
+from paddlespeech.t2s.modules import fft_conv1d
+from paddlespeech.t2s.modules import FFTConv1D
+
+
+class TestFFTConv1D(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 4
+        self.in_channels = 3
+        self.out_channels = 16
+        self.kernel_size = 5
+        self.stride = 1
+        self.padding = 1
+        self.input_length = 32
+
+    def _init_models(self, in_channels, out_channels, kernel_size, stride,
+                     padding):
+        x = paddle.randn([self.batch_size, in_channels, self.input_length])
+        conv1d = paddle.nn.Conv1D(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding)
+        fft_conv1d = FFTConv1D(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding)
+        fft_conv1d.weight.set_value(conv1d.weight.numpy())
+        if conv1d.bias is not None:
+            fft_conv1d.bias.set_value(conv1d.bias.numpy())
+        return x, conv1d, fft_conv1d
+
+    def test_fft_conv1d_vs_conv1d_default(self):
+        x, conv1d, fft_conv1d = self._init_models(
+            self.in_channels, self.out_channels, self.kernel_size, self.stride,
+            self.padding)
+        out_conv1d = conv1d(x)
+        out_fft_conv1d = fft_conv1d(x)
+        self.assertTrue(
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
+
+    def test_fft_conv1d_vs_conv1d_no_padding(self):
+        x, conv1d, fft_conv1d = self._init_models(
+            self.in_channels, self.out_channels, self.kernel_size, self.stride,
+            0)
+        out_conv1d = conv1d(x)
+        out_fft_conv1d = fft_conv1d(x)
+        self.assertTrue(
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
+
+    def test_fft_conv1d_vs_conv1d_large_kernel(self):
+        kernel_size = 256
+        padding = kernel_size - 1
+        x, conv1d, fft_conv1d = self._init_models(
+            self.in_channels, self.out_channels, kernel_size, self.stride,
+            padding)
+        out_conv1d = conv1d(x)
+        out_fft_conv1d = fft_conv1d(x)
+        self.assertTrue(
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
+
+    def test_fft_conv1d_vs_conv1d_stride_2(self):
+        x, conv1d, fft_conv1d = self._init_models(
+            self.in_channels, self.out_channels, self.kernel_size, 2,
+            self.padding)
+        out_conv1d = conv1d(x)
+        out_fft_conv1d = fft_conv1d(x)
+        self.assertTrue(
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
+
+    def test_fft_conv1d_vs_conv1d_different_input_length(self):
+        input_length = 1024
+        x, conv1d, fft_conv1d = self._init_models(
+            self.in_channels, self.out_channels, self.kernel_size, self.stride,
+            self.padding)
+        x = paddle.randn([self.batch_size, self.in_channels, input_length])
+        out_conv1d = conv1d(x)
+        out_fft_conv1d = fft_conv1d(x)
+        self.assertTrue(
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
+
+    def test_fft_conv1d_vs_conv1d_no_bias(self):
+        conv1d = paddle.nn.Conv1D(
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            bias_attr=False)
+        fft_conv1d = FFTConv1D(
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            bias_attr=False)
+        fft_conv1d.weight.set_value(conv1d.weight.numpy())
+        x = paddle.randn([self.batch_size, self.in_channels, self.input_length])
+        out_conv1d = conv1d(x)
+        out_fft_conv1d = fft_conv1d(x)
+        self.assertTrue(
+            np.allclose(out_conv1d.numpy(), out_fft_conv1d.numpy(), atol=1e-6))
+
+
+if __name__ == '__main__':
+    unittest.main()
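Beyond the unit test, reviewers who want to try the new module interactively can compare it against the built-in convolution along the lines below. This is a sketch, not part of the patch, and assumes Paddle >= 2.6 so that the FFT path (rather than the Conv1D fallback) is exercised:

    import paddle
    import paddle.nn.functional as F

    from paddlespeech.t2s.modules import FFTConv1D, fft_conv1d

    x = paddle.randn([2, 8, 4096])
    layer = FFTConv1D(8, 16, kernel_size=256)

    # FFT-based result, via the layer and via the functional form.
    y_layer = layer(x)
    y_func = fft_conv1d(x, layer.weight, layer.bias, stride=1, padding=0)

    # Reference result from the built-in convolution, using the same parameters.
    y_ref = F.conv1d(x, layer.weight, layer.bias, stride=1, padding=0)

    print(y_layer.shape)                                       # [2, 16, 3841]
    print(paddle.allclose(y_ref, y_layer, atol=1e-5).item())   # True, up to FFT rounding
    print(paddle.allclose(y_func, y_layer, atol=1e-5).item())  # True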