diff --git a/paddleaudio/paddleaudio/features/layers.py b/paddleaudio/paddleaudio/features/layers.py index 4a2c1673a..6afd234a0 100644 --- a/paddleaudio/paddleaudio/features/layers.py +++ b/paddleaudio/paddleaudio/features/layers.py @@ -37,6 +37,7 @@ class Spectrogram(nn.Layer): hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', dtype: str=paddle.float32): @@ -54,6 +55,7 @@ class Spectrogram(nn.Layer): The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. The default value is 'hann' + power (float): Exponent for the magnitude spectrogram. The default value is 2.0. center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. If False, frame t begins at x[t * hop_length] The default value is True @@ -68,6 +70,9 @@ class Spectrogram(nn.Layer): """ super(Spectrogram, self).__init__() + assert power > 0, 'Power of spectrogram must be > 0.' + self.power = power + if win_length is None: win_length = n_fft @@ -85,7 +90,7 @@ class Spectrogram(nn.Layer): def forward(self, x): stft = self._stft(x) - spectrogram = paddle.square(paddle.abs(stft)) + spectrogram = paddle.pow(paddle.abs(stft), self.power) return spectrogram @@ -96,6 +101,7 @@ class MelSpectrogram(nn.Layer): hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', n_mels: int=64, @@ -120,6 +126,7 @@ class MelSpectrogram(nn.Layer): The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. The default value is 'hann' + power (float): Exponent for the magnitude spectrogram. The default value is 2.0. center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. If False, frame t begins at x[t * hop_length] The default value is True @@ -142,6 +149,7 @@ class MelSpectrogram(nn.Layer): hop_length=hop_length, win_length=win_length, window=window, + power=power, center=center, pad_mode=pad_mode, dtype=dtype) @@ -176,6 +184,7 @@ class LogMelSpectrogram(nn.Layer): hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', n_mels: int=64, @@ -214,11 +223,8 @@ class LogMelSpectrogram(nn.Layer): htk (bool): whether to use HTK formula in computing fbank matrix. norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. - ref_value (float): the reference value. If smaller than 1.0, the db level - amin (float): the minimum value of input magnitude, below which the input of the signal will be pulled up accordingly. - Otherwise, the db level is pushed down. - magnitude is clipped(to amin). For numerical stability, set amin to a larger value, - e.g., 1e-3. + ref_value (float): the reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. + amin (float): the minimum value of input magnitude, below which the input magnitude is clipped(to amin). top_db (float): the maximum db value of resulting spectrum, above which the spectrum is clipped(to top_db). dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical @@ -232,6 +238,7 @@ class LogMelSpectrogram(nn.Layer): hop_length=hop_length, win_length=win_length, window=window, + power=power, center=center, pad_mode=pad_mode, n_mels=n_mels, @@ -246,7 +253,6 @@ class LogMelSpectrogram(nn.Layer): self.top_db = top_db def forward(self, x): - # import ipdb; ipdb.set_trace() mel_feature = self._melspectrogram(x) log_mel_feature = power_to_db( mel_feature, @@ -264,6 +270,7 @@ class MFCC(nn.Layer): hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', n_mels: int=64, @@ -291,6 +298,7 @@ class MFCC(nn.Layer): The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. The default value is 'hann' + power (float): Exponent for the magnitude spectrogram. The default value is 2.0. center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. If False, frame t begins at x[t * hop_length] The default value is True @@ -303,11 +311,8 @@ class MFCC(nn.Layer): htk (bool): whether to use HTK formula in computing fbank matrix. norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. - ref_value (float): the reference value. If smaller than 1.0, the db level - amin (float): the minimum value of input magnitude, below which the input of the signal will be pulled up accordingly. - Otherwise, the db level is pushed down. - magnitude is clipped(to amin). For numerical stability, set amin to a larger value, - e.g., 1e-3. + ref_value (float): the reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. + amin (float): the minimum value of input magnitude, below which the input magnitude is clipped(to amin). top_db (float): the maximum db value of resulting spectrum, above which the spectrum is clipped(to top_db). dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical @@ -322,6 +327,7 @@ class MFCC(nn.Layer): hop_length=hop_length, win_length=win_length, window=window, + power=power, center=center, pad_mode=pad_mode, n_mels=n_mels, diff --git a/paddleaudio/setup.py b/paddleaudio/setup.py index 7623443a6..b0ccd2eba 100644 --- a/paddleaudio/setup.py +++ b/paddleaudio/setup.py @@ -11,19 +11,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob +import os + import setuptools +from setuptools.command.install import install +from setuptools.command.test import test # set the version here VERSION = '0.2.0' +# Inspired by the example at https://pytest.org/latest/goodpractises.html +class TestCommand(test): + def finalize_options(self): + test.finalize_options(self) + self.test_args = [] + self.test_suite = True + + def run(self): + self.run_benchmark() + super(TestCommand, self).run() + + def run_tests(self): + # Run nose ensuring that argv simulates running nosetests directly + import nose + nose.run_exit(argv=['nosetests', '-w', 'tests']) + + def run_benchmark(self): + for benchmark_item in glob.glob('tests/benchmark/*py'): + os.system(f'pytest {benchmark_item}') + + +class InstallCommand(install): + def run(self): + install.run(self) + + def write_version_py(filename='paddleaudio/__init__.py'): import paddleaudio if hasattr(paddleaudio, "__version__") and paddleaudio.__version__ == VERSION: return with open(filename, "a") as f: - f.write(f"\n__version__ = '{VERSION}'\n") + f.write(f"__version__ = '{VERSION}'") def remove_version_py(filename='paddleaudio/__init__.py'): @@ -61,6 +92,14 @@ setuptools.setup( 'colorlog', 'dtaidistance >= 2.3.6', 'mcd >= 0.4', - ], ) + ], + setup_requires=[ + 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1', + 'torchaudio==0.10.2', 'pytest-benchmark' + ], + cmdclass={ + 'install': InstallCommand, + 'test': TestCommand, + }, ) remove_version_py() diff --git a/paddleaudio/tests/backends/__init__.py b/paddleaudio/tests/backends/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddleaudio/tests/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddleaudio/tests/backends/base.py b/paddleaudio/tests/backends/base.py new file mode 100644 index 000000000..a67191887 --- /dev/null +++ b/paddleaudio/tests/backends/base.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import urllib.request + +mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav' + + +class BackendTest(unittest.TestCase): + def setUp(self): + self.initWavInput() + + def initWavInput(self): + self.files = [] + for url in [mono_channel_wav, multi_channels_wav]: + if not os.path.isfile(os.path.basename(url)): + urllib.request.urlretrieve(url, os.path.basename(url)) + self.files.append(os.path.basename(url)) + + def initParmas(self): + raise NotImplementedError diff --git a/paddleaudio/tests/backends/soundfile/__init__.py b/paddleaudio/tests/backends/soundfile/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddleaudio/tests/backends/soundfile/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddleaudio/tests/backends/soundfile/test_io.py b/paddleaudio/tests/backends/soundfile/test_io.py new file mode 100644 index 000000000..0f7580a40 --- /dev/null +++ b/paddleaudio/tests/backends/soundfile/test_io.py @@ -0,0 +1,73 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import filecmp +import os +import unittest + +import numpy as np +import soundfile as sf + +import paddleaudio +from ..base import BackendTest + + +class TestIO(BackendTest): + def test_load_mono_channel(self): + sf_data, sf_sr = sf.read(self.files[0]) + pa_data, pa_sr = paddleaudio.load( + self.files[0], normal=False, dtype='float64') + + self.assertEqual(sf_data.dtype, pa_data.dtype) + self.assertEqual(sf_sr, pa_sr) + np.testing.assert_array_almost_equal(sf_data, pa_data) + + def test_load_multi_channels(self): + sf_data, sf_sr = sf.read(self.files[1]) + sf_data = sf_data.T # Channel dim first + pa_data, pa_sr = paddleaudio.load( + self.files[1], mono=False, normal=False, dtype='float64') + + self.assertEqual(sf_data.dtype, pa_data.dtype) + self.assertEqual(sf_sr, pa_sr) + np.testing.assert_array_almost_equal(sf_data, pa_data) + + def test_save_mono_channel(self): + waveform, sr = np.random.randint( + low=-32768, high=32768, size=(48000), dtype=np.int16), 16000 + sf_tmp_file = 'sf_tmp.wav' + pa_tmp_file = 'pa_tmp.wav' + + sf.write(sf_tmp_file, waveform, sr) + paddleaudio.save(waveform, sr, pa_tmp_file) + + self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) + for file in [sf_tmp_file, pa_tmp_file]: + os.remove(file) + + def test_save_multi_channels(self): + waveform, sr = np.random.randint( + low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000 + sf_tmp_file = 'sf_tmp.wav' + pa_tmp_file = 'pa_tmp.wav' + + sf.write(sf_tmp_file, waveform.T, sr) + paddleaudio.save(waveform.T, sr, pa_tmp_file) + + self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) + for file in [sf_tmp_file, pa_tmp_file]: + os.remove(file) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/benchmark/README.md b/paddleaudio/tests/benchmark/README.md new file mode 100644 index 000000000..b9034100d --- /dev/null +++ b/paddleaudio/tests/benchmark/README.md @@ -0,0 +1,39 @@ +# 1. Prepare +First, install `pytest-benchmark` via pip. +```sh +pip install pytest-benchmark +``` + +# 2. Run +Run the specific script for profiling. +```sh +pytest melspectrogram.py +``` + +Result: +```sh +========================================================================== test session starts ========================================================================== +platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0 +benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) +rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio +plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0 +collected 4 items + +melspectrogram.py .... [100%] + + +-------------------------------------------------------------------------------------------------- benchmark: 4 tests ------------------------------------------------------------------------------------------------- +Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1 +test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1 +test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1 +test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1 +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Legend: + Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile. + OPS: Operations Per Second, computed as 1 / Mean +========================================================================== 4 passed in 21.12s =========================================================================== + +``` diff --git a/paddleaudio/tests/benchmark/log_melspectrogram.py b/paddleaudio/tests/benchmark/log_melspectrogram.py new file mode 100644 index 000000000..5230acd42 --- /dev/null +++ b/paddleaudio/tests/benchmark/log_melspectrogram.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import urllib.request + +import librosa +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +if not os.path.isfile(os.path.basename(wav_url)): + urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) + +waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) +waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) + +# Feature conf +mel_conf = { + 'sr': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, +} + +mel_conf_torchaudio = { + 'sample_rate': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, + 'norm': 'slaney', + 'mel_scale': 'slaney', +} + + +def enable_cpu_device(): + paddle.set_device('cpu') + + +def enable_gpu_device(): + paddle.set_device('gpu') + + +log_mel_extractor = paddleaudio.features.LogMelSpectrogram( + **mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype) + + +def log_melspectrogram(): + return log_mel_extractor(waveform_tensor).squeeze(0) + + +def test_log_melspect_cpu(benchmark): + enable_cpu_device() + feature_paddleaudio = benchmark(log_melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_log_melspect_gpu(benchmark): + enable_gpu_device() + feature_paddleaudio = benchmark(log_melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=2) + + +mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram( + **mel_conf_torchaudio, f_min=0.0) +amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0) + + +def melspectrogram_torchaudio(): + return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0) + + +def log_melspectrogram_torchaudio(): + mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch) + return amplitude_to_DB(mel_specgram).squeeze(0) + + +def test_log_melspect_cpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB + + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu') + waveform_tensor_torch = waveform_tensor_torch.to('cpu') + amplitude_to_DB = amplitude_to_DB.to('cpu') + + feature_paddleaudio = benchmark(log_melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_log_melspect_gpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB + + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda') + waveform_tensor_torch = waveform_tensor_torch.to('cuda') + amplitude_to_DB = amplitude_to_DB.to('cuda') + + feature_torchaudio = benchmark(log_melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_torchaudio.cpu(), decimal=2) diff --git a/paddleaudio/tests/benchmark/melspectrogram.py b/paddleaudio/tests/benchmark/melspectrogram.py new file mode 100644 index 000000000..e0b79b45a --- /dev/null +++ b/paddleaudio/tests/benchmark/melspectrogram.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import urllib.request + +import librosa +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +if not os.path.isfile(os.path.basename(wav_url)): + urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) + +waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) +waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) + +# Feature conf +mel_conf = { + 'sr': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, +} + +mel_conf_torchaudio = { + 'sample_rate': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, + 'norm': 'slaney', + 'mel_scale': 'slaney', +} + + +def enable_cpu_device(): + paddle.set_device('cpu') + + +def enable_gpu_device(): + paddle.set_device('gpu') + + +mel_extractor = paddleaudio.features.MelSpectrogram( + **mel_conf, f_min=0.0, dtype=waveform_tensor.dtype) + + +def melspectrogram(): + return mel_extractor(waveform_tensor).squeeze(0) + + +def test_melspect_cpu(benchmark): + enable_cpu_device() + feature_paddleaudio = benchmark(melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_melspect_gpu(benchmark): + enable_gpu_device() + feature_paddleaudio = benchmark(melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram( + **mel_conf_torchaudio, f_min=0.0) + + +def melspectrogram_torchaudio(): + return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0) + + +def test_melspect_cpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu') + waveform_tensor_torch = waveform_tensor_torch.to('cpu') + feature_paddleaudio = benchmark(melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_melspect_gpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda') + waveform_tensor_torch = waveform_tensor_torch.to('cuda') + feature_torchaudio = benchmark(melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_torchaudio.cpu(), decimal=3) diff --git a/paddleaudio/tests/benchmark/mfcc.py b/paddleaudio/tests/benchmark/mfcc.py new file mode 100644 index 000000000..2572ff33d --- /dev/null +++ b/paddleaudio/tests/benchmark/mfcc.py @@ -0,0 +1,122 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import urllib.request + +import librosa +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +if not os.path.isfile(os.path.basename(wav_url)): + urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) + +waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) +waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) + +# Feature conf +mel_conf = { + 'sr': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, +} +mfcc_conf = { + 'n_mfcc': 20, + 'top_db': 80.0, +} +mfcc_conf.update(mel_conf) + +mel_conf_torchaudio = { + 'sample_rate': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, + 'norm': 'slaney', + 'mel_scale': 'slaney', +} +mfcc_conf_torchaudio = { + 'sample_rate': sr, + 'n_mfcc': 20, +} + + +def enable_cpu_device(): + paddle.set_device('cpu') + + +def enable_gpu_device(): + paddle.set_device('gpu') + + +mfcc_extractor = paddleaudio.features.MFCC( + **mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype) + + +def mfcc(): + return mfcc_extractor(waveform_tensor).squeeze(0) + + +def test_mfcc_cpu(benchmark): + enable_cpu_device() + feature_paddleaudio = benchmark(mfcc) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_mfcc_gpu(benchmark): + enable_gpu_device() + feature_paddleaudio = benchmark(mfcc) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +del mel_conf_torchaudio['sample_rate'] +mfcc_extractor_torchaudio = torchaudio.transforms.MFCC( + **mfcc_conf_torchaudio, melkwargs=mel_conf_torchaudio) + + +def mfcc_torchaudio(): + return mfcc_extractor_torchaudio(waveform_tensor_torch).squeeze(0) + + +def test_mfcc_cpu_torchaudio(benchmark): + global waveform_tensor_torch, mfcc_extractor_torchaudio + + mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu') + waveform_tensor_torch = waveform_tensor_torch.to('cpu') + + feature_paddleaudio = benchmark(mfcc_torchaudio) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_mfcc_gpu_torchaudio(benchmark): + global waveform_tensor_torch, mfcc_extractor_torchaudio + + mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cuda') + waveform_tensor_torch = waveform_tensor_torch.to('cuda') + + feature_torchaudio = benchmark(mfcc_torchaudio) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_torchaudio.cpu(), decimal=3) diff --git a/paddleaudio/tests/features/__init__.py b/paddleaudio/tests/features/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddleaudio/tests/features/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddleaudio/tests/features/base.py b/paddleaudio/tests/features/base.py new file mode 100644 index 000000000..725e1e2e7 --- /dev/null +++ b/paddleaudio/tests/features/base.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import urllib.request + +import numpy as np +import paddle + +from paddleaudio import load + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' + + +class FeatTest(unittest.TestCase): + def setUp(self): + self.initParmas() + self.initWavInput() + self.setUpDevice() + + def setUpDevice(self, device='cpu'): + paddle.set_device(device) + + def initWavInput(self, url=wav_url): + if not os.path.isfile(os.path.basename(url)): + urllib.request.urlretrieve(url, os.path.basename(url)) + self.waveform, self.sr = load(os.path.abspath(os.path.basename(url))) + self.waveform = self.waveform.astype( + np.float32 + ) # paddlespeech.s2t.transform.spectrogram only supports float32 + dim = len(self.waveform.shape) + + assert dim in [1, 2] + if dim == 1: + self.waveform = np.expand_dims(self.waveform, 0) + + def initParmas(self): + raise NotImplementedError diff --git a/paddleaudio/tests/features/test_istft.py b/paddleaudio/tests/features/test_istft.py new file mode 100644 index 000000000..23371200b --- /dev/null +++ b/paddleaudio/tests/features/test_istft.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle + +from .base import FeatTest +from paddleaudio.functional.window import get_window +from paddlespeech.s2t.transform.spectrogram import IStft +from paddlespeech.s2t.transform.spectrogram import Stft + + +class TestIstft(FeatTest): + def initParmas(self): + self.n_fft = 512 + self.hop_length = 128 + self.window_str = 'hann' + + def test_istft(self): + ps_stft = Stft(self.n_fft, self.hop_length) + ps_res = ps_stft( + self.waveform.T).squeeze(1).T # (n_fft//2 + 1, n_frmaes) + x = paddle.to_tensor(ps_res) + + ps_istft = IStft(self.hop_length) + ps_res = ps_istft(ps_res.T) + + window = get_window( + self.window_str, self.n_fft, dtype=self.waveform.dtype) + pd_res = paddle.signal.istft( + x, self.n_fft, self.hop_length, window=window) + + np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_kaldi.py b/paddleaudio/tests/features/test_kaldi.py new file mode 100644 index 000000000..6e826aaa7 --- /dev/null +++ b/paddleaudio/tests/features/test_kaldi.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio +from .base import FeatTest + + +class TestKaldi(FeatTest): + def initParmas(self): + self.window_size = 1024 + self.dtype = 'float32' + + def test_window(self): + t_hann_window = torch.hann_window( + self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}')) + t_hamm_window = torch.hamming_window( + self.window_size, + periodic=False, + alpha=0.54, + beta=0.46, + dtype=eval(f'torch.{self.dtype}')) + t_povey_window = torch.hann_window( + self.window_size, periodic=False, + dtype=eval(f'torch.{self.dtype}')).pow(0.85) + + p_hann_window = paddleaudio.functional.window.get_window( + 'hann', + self.window_size, + fftbins=False, + dtype=eval(f'paddle.{self.dtype}')) + p_hamm_window = paddleaudio.functional.window.get_window( + 'hamming', + self.window_size, + fftbins=False, + dtype=eval(f'paddle.{self.dtype}')) + p_povey_window = paddleaudio.functional.window.get_window( + 'hann', + self.window_size, + fftbins=False, + dtype=eval(f'paddle.{self.dtype}')).pow(0.85) + + np.testing.assert_array_almost_equal(t_hann_window, p_hann_window) + np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window) + np.testing.assert_array_almost_equal(t_povey_window, p_povey_window) + + def test_fbank(self): + ta_features = torchaudio.compliance.kaldi.fbank( + torch.from_numpy(self.waveform.astype(self.dtype))) + pa_features = paddleaudio.compliance.kaldi.fbank( + paddle.to_tensor(self.waveform.astype(self.dtype))) + np.testing.assert_array_almost_equal( + ta_features, pa_features, decimal=4) + + def test_mfcc(self): + ta_features = torchaudio.compliance.kaldi.mfcc( + torch.from_numpy(self.waveform.astype(self.dtype))) + pa_features = paddleaudio.compliance.kaldi.mfcc( + paddle.to_tensor(self.waveform.astype(self.dtype))) + np.testing.assert_array_almost_equal( + ta_features, pa_features, decimal=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_librosa.py b/paddleaudio/tests/features/test_librosa.py new file mode 100644 index 000000000..cf0c98c72 --- /dev/null +++ b/paddleaudio/tests/features/test_librosa.py @@ -0,0 +1,281 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import librosa +import numpy as np +import paddle + +import paddleaudio +from .base import FeatTest +from paddleaudio.functional.window import get_window + + +class TestLibrosa(FeatTest): + def initParmas(self): + self.n_fft = 512 + self.hop_length = 128 + self.n_mels = 40 + self.n_mfcc = 20 + self.fmin = 0.0 + self.window_str = 'hann' + self.pad_mode = 'reflect' + self.top_db = 80.0 + + def test_stft(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + feature_librosa = librosa.core.stft( + y=self.waveform, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=self.window_str, + center=True, + dtype=None, + pad_mode=self.pad_mode, ) + x = paddle.to_tensor(self.waveform).unsqueeze(0) + window = get_window(self.window_str, self.n_fft, dtype=x.dtype) + feature_paddle = paddle.signal.stft( + x=x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=window, + center=True, + pad_mode=self.pad_mode, + normalized=False, + onesided=True, ).squeeze(0) + + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddle, decimal=5) + + def test_istft(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # Get stft result from librosa. + stft_matrix = librosa.core.stft( + y=self.waveform, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=self.window_str, + center=True, + pad_mode=self.pad_mode, ) + + feature_librosa = librosa.core.istft( + stft_matrix=stft_matrix, + hop_length=self.hop_length, + win_length=None, + window=self.window_str, + center=True, + dtype=None, + length=None, ) + + x = paddle.to_tensor(stft_matrix).unsqueeze(0) + window = get_window( + self.window_str, + self.n_fft, + dtype=paddle.to_tensor(self.waveform).dtype) + feature_paddle = paddle.signal.istft( + x=x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=window, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, ).squeeze(0) + + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddle, decimal=5) + + def test_mel(self): + feature_librosa = librosa.filters.mel( + sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + fmax=None, + htk=False, + norm='slaney', + dtype=self.waveform.dtype, ) + feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix( + sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + fmax=None, + htk=False, + norm='slaney', + dtype=self.waveform.dtype, ) + x = paddle.to_tensor(self.waveform) + feature_functional = paddleaudio.functional.compute_fbank_matrix( + sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + f_min=self.fmin, + f_max=None, + htk=False, + norm='slaney', + dtype=x.dtype, ) + + np.testing.assert_array_almost_equal(feature_librosa, + feature_compliance) + np.testing.assert_array_almost_equal(feature_librosa, + feature_functional) + + def test_melspect(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.melspectrogram( + y=self.waveform, + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.melspectrogram( + x=self.waveform, + sr=self.sr, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin, + to_db=False) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. + feature_extractor = paddleaudio.features.MelSpectrogram( + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=5) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=5) + + def test_log_melspect(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.melspectrogram( + y=self.waveform, + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=None) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.melspectrogram( + x=self.waveform, + sr=self.sr, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. + feature_extractor = paddleaudio.features.LogMelSpectrogram( + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=5) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=4) + + def test_mfcc(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.mfcc( + y=self.waveform, + sr=self.sr, + S=None, + n_mfcc=self.n_mfcc, + dct_type=2, + norm='ortho', + lifter=0, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.mfcc( + x=self.waveform, + sr=self.sr, + n_mfcc=self.n_mfcc, + dct_type=2, + norm='ortho', + lifter=0, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin, + top_db=self.top_db) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. + feature_extractor = paddleaudio.features.MFCC( + sr=self.sr, + n_mfcc=self.n_mfcc, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + top_db=self.top_db, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=4) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_log_melspectrogram.py b/paddleaudio/tests/features/test_log_melspectrogram.py new file mode 100644 index 000000000..6bae2df3f --- /dev/null +++ b/paddleaudio/tests/features/test_log_melspectrogram.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle + +import paddleaudio +from .base import FeatTest +from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram + + +class TestLogMelSpectrogram(FeatTest): + def initParmas(self): + self.n_fft = 512 + self.hop_length = 128 + self.n_mels = 40 + + def test_log_melspect(self): + ps_melspect = LogMelSpectrogram(self.sr, self.n_mels, self.n_fft, + self.hop_length) + ps_res = ps_melspect(self.waveform.T).squeeze(1).T + + x = paddle.to_tensor(self.waveform) + # paddlespeech.s2t的特征存在幅度谱和功率谱滥用的情况 + ps_melspect = paddleaudio.features.LogMelSpectrogram( + self.sr, + self.n_fft, + self.hop_length, + power=1.0, + n_mels=self.n_mels, + f_min=0.0) + pa_res = (ps_melspect(x) / 10.0).squeeze(0).numpy() + + np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_spectrogram.py b/paddleaudio/tests/features/test_spectrogram.py new file mode 100644 index 000000000..50b21403b --- /dev/null +++ b/paddleaudio/tests/features/test_spectrogram.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle + +import paddleaudio +from .base import FeatTest +from paddlespeech.s2t.transform.spectrogram import Spectrogram + + +class TestSpectrogram(FeatTest): + def initParmas(self): + self.n_fft = 512 + self.hop_length = 128 + + def test_spectrogram(self): + ps_spect = Spectrogram(self.n_fft, self.hop_length) + ps_res = ps_spect(self.waveform.T).squeeze(1).T # Magnitude + + x = paddle.to_tensor(self.waveform) + pa_spect = paddleaudio.features.Spectrogram( + self.n_fft, self.hop_length, power=1.0) + pa_res = pa_spect(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_stft.py b/paddleaudio/tests/features/test_stft.py new file mode 100644 index 000000000..c64b5ebe6 --- /dev/null +++ b/paddleaudio/tests/features/test_stft.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle + +from .base import FeatTest +from paddleaudio.functional.window import get_window +from paddlespeech.s2t.transform.spectrogram import Stft + + +class TestStft(FeatTest): + def initParmas(self): + self.n_fft = 512 + self.hop_length = 128 + self.window_str = 'hann' + + def test_stft(self): + ps_stft = Stft(self.n_fft, self.hop_length) + ps_res = ps_stft( + self.waveform.T).squeeze(1).T # (n_fft//2 + 1, n_frmaes) + + x = paddle.to_tensor(self.waveform) + window = get_window(self.window_str, self.n_fft, dtype=x.dtype) + pd_res = paddle.signal.stft( + x, self.n_fft, self.hop_length, window=window).squeeze(0).numpy() + + np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5) + + +if __name__ == '__main__': + unittest.main()