From 7261d86344fb256edd3def1ce3b620afbb03f745 Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Thu, 15 Sep 2022 21:12:57 +0800 Subject: [PATCH] add test & benchmark --- audio/tests/backends/__init__.py | 13 + audio/tests/backends/common.py | 32 ++ audio/tests/backends/soundfile/__init__.py | 13 + audio/tests/backends/soundfile/base.py | 34 ++ audio/tests/backends/soundfile/common.py | 57 +++ audio/tests/backends/soundfile/info_test.py | 199 ++++++++++ audio/tests/backends/soundfile/load_test.py | 369 ++++++++++++++++++ audio/tests/backends/soundfile/save_test.py | 322 +++++++++++++++ audio/tests/backends/soundfile/test_io.py | 74 ++++ audio/tests/benchmark/README.md | 39 ++ audio/tests/benchmark/log_melspectrogram.py | 123 ++++++ audio/tests/benchmark/melspectrogram.py | 107 +++++ audio/tests/benchmark/mfcc.py | 121 ++++++ audio/tests/common_utils/__init__.py | 17 + audio/tests/common_utils/case_utils.py | 56 +++ .../tests/common_utils/parameterized_utils.py | 43 ++ audio/tests/common_utils/wav_utils.py | 102 +++++ audio/tests/features/__init__.py | 13 + audio/tests/features/base.py | 48 +++ audio/tests/features/test_kaldi.py | 81 ++++ audio/tests/features/test_librosa.py | 281 +++++++++++++ 21 files changed, 2144 insertions(+) create mode 100644 audio/tests/backends/__init__.py create mode 100644 audio/tests/backends/common.py create mode 100644 audio/tests/backends/soundfile/__init__.py create mode 100644 audio/tests/backends/soundfile/base.py create mode 100644 audio/tests/backends/soundfile/common.py create mode 100644 audio/tests/backends/soundfile/info_test.py create mode 100644 audio/tests/backends/soundfile/load_test.py create mode 100644 audio/tests/backends/soundfile/save_test.py create mode 100644 audio/tests/backends/soundfile/test_io.py create mode 100644 audio/tests/benchmark/README.md create mode 100644 audio/tests/benchmark/log_melspectrogram.py create mode 100644 audio/tests/benchmark/melspectrogram.py create mode 100644 audio/tests/benchmark/mfcc.py 
create mode 100644 audio/tests/common_utils/__init__.py create mode 100644 audio/tests/common_utils/case_utils.py create mode 100644 audio/tests/common_utils/parameterized_utils.py create mode 100644 audio/tests/common_utils/wav_utils.py create mode 100644 audio/tests/features/__init__.py create mode 100644 audio/tests/features/base.py create mode 100644 audio/tests/features/test_kaldi.py create mode 100644 audio/tests/features/test_librosa.py diff --git a/audio/tests/backends/__init__.py b/audio/tests/backends/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/audio/tests/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/audio/tests/backends/common.py b/audio/tests/backends/common.py new file mode 100644 index 000000000..79b922a91 --- /dev/null +++ b/audio/tests/backends/common.py @@ -0,0 +1,32 @@ + +def get_encoding(ext, dtype): + exts = { + "mp3", + "flac", + "vorbis", + } + encodings = { + "float32": "PCM_F", + "int32": "PCM_S", + "int16": "PCM_S", + "uint8": "PCM_U", + } + return ext.upper() if ext in exts else encodings[dtype] + + +def get_bit_depth(dtype): + bit_depths = { + "float32": 32, + "int32": 32, + "int16": 16, + "uint8": 8, + } + return bit_depths[dtype] + +def get_bits_per_sample(ext, dtype): + bits_per_samples = { + "flac": 24, + "mp3": 0, + "vorbis": 0, + } + return bits_per_samples.get(ext, get_bit_depth(dtype)) diff --git a/audio/tests/backends/soundfile/__init__.py b/audio/tests/backends/soundfile/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/audio/tests/backends/soundfile/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/audio/tests/backends/soundfile/base.py b/audio/tests/backends/soundfile/base.py new file mode 100644 index 000000000..a67191887 --- /dev/null +++ b/audio/tests/backends/soundfile/base.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import urllib.request + +mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav' + + +class BackendTest(unittest.TestCase): + def setUp(self): + self.initWavInput() + + def initWavInput(self): + self.files = [] + for url in [mono_channel_wav, multi_channels_wav]: + if not os.path.isfile(os.path.basename(url)): + urllib.request.urlretrieve(url, os.path.basename(url)) + self.files.append(os.path.basename(url)) + + def initParmas(self): + raise NotImplementedError diff --git a/audio/tests/backends/soundfile/common.py b/audio/tests/backends/soundfile/common.py new file mode 100644 index 000000000..42a07e1f0 --- /dev/null +++ b/audio/tests/backends/soundfile/common.py @@ -0,0 +1,57 @@ +import itertools +from unittest import skipIf + +from parameterized import parameterized +from paddleaudio._internal.module_utils import is_module_available + + +def name_func(func, _, params): + return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' + + +def dtype2subtype(dtype): + return { + "float64": "DOUBLE", + "float32": "FLOAT", + "int32": "PCM_32", + "int16": "PCM_16", + "uint8": "PCM_U8", + "int8": "PCM_S8", + }[dtype] + + +def skipIfFormatNotSupported(fmt): + fmts = [] + if is_module_available("soundfile"): + import soundfile + + fmts = soundfile.available_formats() + return skipIf(fmt not in 
fmts, f'"{fmt}" is not supported by soundfile') + return skipIf(True, '"soundfile" not available.') + + +def parameterize(*params): + return parameterized.expand(list(itertools.product(*params)), name_func=name_func) + + +def fetch_wav_subtype(dtype, encoding, bits_per_sample): + subtype = { + (None, None): dtype2subtype(dtype), + (None, 8): "PCM_U8", + ("PCM_U", None): "PCM_U8", + ("PCM_U", 8): "PCM_U8", + ("PCM_S", None): "PCM_32", + ("PCM_S", 16): "PCM_16", + ("PCM_S", 32): "PCM_32", + ("PCM_F", None): "FLOAT", + ("PCM_F", 32): "FLOAT", + ("PCM_F", 64): "DOUBLE", + ("ULAW", None): "ULAW", + ("ULAW", 8): "ULAW", + ("ALAW", None): "ALAW", + ("ALAW", 8): "ALAW", + }.get((encoding, bits_per_sample)) + if subtype: + return subtype + raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).") + diff --git a/audio/tests/backends/soundfile/info_test.py b/audio/tests/backends/soundfile/info_test.py new file mode 100644 index 000000000..94f167ed9 --- /dev/null +++ b/audio/tests/backends/soundfile/info_test.py @@ -0,0 +1,199 @@ +#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py + +import tarfile +import warnings +import unittest +from unittest.mock import patch + +import paddle +from paddleaudio._internal import module_utils as _mod_utils +from paddleaudio.backends import soundfile_backend +from tests.backends.common import get_bits_per_sample, get_encoding +from tests.common_utils import ( + get_wav_data, + nested_params, + save_wav, + TempDirMixin, +) + +from common import parameterize, skipIfFormatNotSupported + +import soundfile + + +class TestInfo(TempDirMixin, unittest.TestCase): + @parameterize( + ["float32", "int32"], + [8000, 16000], + [1, 2], + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`soundfile_backend.info` can check wav file correctly""" + duration = 1 + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, 
num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == get_bits_per_sample("wav", dtype) + assert info.encoding == get_encoding("wav", dtype) + + @parameterize([8000, 16000], [1, 2]) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, sample_rate, num_channels): + """`soundfile_backend.info` can check flac file correctly""" + duration = 1 + num_frames = sample_rate * duration + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + + path = self.get_temp_path("data.flac") + soundfile.write(path, data, sample_rate) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == 16 + assert info.encoding == "FLAC" + + #@parameterize([8000, 16000], [1, 2]) + #@skipIfFormatNotSupported("OGG") + #def test_ogg(self, sample_rate, num_channels): + #"""`soundfile_backend.info` can check ogg file correctly""" + #duration = 1 + #num_frames = sample_rate * duration + ##data = torch.randn(num_frames, num_channels).numpy() + #data = paddle.randn(shape=[num_frames, num_channels]).numpy() + #print(len(data)) + #path = self.get_temp_path("data.ogg") + #soundfile.write(path, data, sample_rate) + + #info = soundfile_backend.info(path) + #print(info) + #assert info.sample_rate == sample_rate + #print("info") + #print(info.num_frames) + #print("jiji") + #print(sample_rate*duration) + ##assert info.num_frames == sample_rate * duration + #assert info.num_channels == num_channels + #assert info.bits_per_sample == 0 + #assert info.encoding == "VORBIS" + + @nested_params( + [8000, 16000], + [1, 2], + [("PCM_24", 24), ("PCM_32", 32)], + ) + @skipIfFormatNotSupported("NIST") 
+ def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth): + """`soundfile_backend.info` can check sph file correctly""" + duration = 1 + num_frames = sample_rate * duration + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + path = self.get_temp_path("data.nist") + subtype, bits_per_sample = subtype_and_bit_depth + soundfile.write(path, data, sample_rate, subtype=subtype) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "PCM_S" + + def test_unknown_subtype_warning(self): + """soundfile_backend.info issues a warning when the subtype is unknown + + This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE + dict should be updated. + """ + + def _mock_info_func(_): + class MockSoundFileInfo: + samplerate = 8000 + frames = 356 + channels = 2 + subtype = "UNSEEN_SUBTYPE" + format = "UNKNOWN" + + return MockSoundFileInfo() + + with patch("soundfile.info", _mock_info_func): + with warnings.catch_warnings(record=True) as w: + info = soundfile_backend.info("foo") + assert len(w) == 1 + assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(w[-1].message) + assert info.bits_per_sample == 0 + + +class TestFileObject(TempDirMixin, unittest.TestCase): + def _test_fileobj(self, ext, subtype, bits_per_sample): + """Query audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + path = self.get_temp_path(f"test.{ext}") + + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + soundfile.write(path, data, sample_rate, subtype=subtype) + + with open(path, "rb") as fileobj: + info = soundfile_backend.info(fileobj) + 
assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "FLAC" if ext == "flac" else "PCM_S" + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj("wav", "PCM_16", 16) + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj("flac", "PCM_16", 16) + + def _test_tarobj(self, ext, subtype, bits_per_sample): + """Query compressed audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + audio_file = f"test.{ext}" + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path("archive.tar.gz") + + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + soundfile.write(audio_path, data, sample_rate, subtype=subtype) + + with tarfile.TarFile(archive_path, "w") as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, "r") as tarobj: + fileobj = tarobj.extractfile(audio_file) + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "FLAC" if ext == "flac" else "PCM_S" + + def test_tarobj_wav(self): + """Query compressed audio via file-like object works""" + self._test_tarobj("wav", "PCM_16", 16) + + @skipIfFormatNotSupported("FLAC") + def test_tarobj_flac(self): + """Query compressed audio via file-like object works""" + self._test_tarobj("flac", "PCM_16", 16) + +if __name__ == '__main__': + unittest.main() diff --git a/audio/tests/backends/soundfile/load_test.py b/audio/tests/backends/soundfile/load_test.py new file mode 100644 index 
000000000..d315703cb --- /dev/null +++ b/audio/tests/backends/soundfile/load_test.py @@ -0,0 +1,369 @@ +#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py + +import os +import tarfile +import unittest +from unittest.mock import patch +import numpy as np + +from parameterized import parameterized +import paddle +from paddleaudio._internal import module_utils as _mod_utils +from paddleaudio.backends import soundfile_backend +from tests.backends.common import get_bits_per_sample, get_encoding +from tests.common_utils import ( + get_wav_data, + load_wav, + nested_params, + normalize_wav, + save_wav, + TempDirMixin, +) + +from common import dtype2subtype, parameterize, skipIfFormatNotSupported + +import soundfile + + +def _get_mock_path( + ext: str, + dtype: str, + sample_rate: int, + num_channels: int, + num_frames: int, +): + return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}" + + +def _get_mock_params(path: str): + filename, ext = path.split(".") + parts = filename.split("_") + return { + "ext": ext, + "dtype": parts[0], + "sample_rate": int(parts[1]), + "num_channels": int(parts[2]), + "num_frames": int(parts[3]), + } + + +class SoundFileMock: + def __init__(self, path, mode): + assert mode == "r" + self.path = path + self._params = _get_mock_params(path) + self._start = None + + @property + def samplerate(self): + return self._params["sample_rate"] + + @property + def format(self): + if self._params["ext"] == "wav": + return "WAV" + if self._params["ext"] == "flac": + return "FLAC" + if self._params["ext"] == "ogg": + return "OGG" + if self._params["ext"] in ["sph", "nis", "nist"]: + return "NIST" + + @property + def subtype(self): + if self._params["ext"] == "ogg": + return "VORBIS" + return dtype2subtype(self._params["dtype"]) + + def _prepare_read(self, start, stop, frames): + assert stop is None + self._start = start + return frames + + def read(self, frames, dtype, always_2d): + 
assert always_2d + data = get_wav_data( + dtype, + self._params["num_channels"], + normalize=False, + num_frames=self._params["num_frames"], + channels_first=False, + ).numpy() + return data[self._start : self._start + frames] + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + pass + + +class MockedLoadTest(unittest.TestCase): + def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first): + """When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32""" + num_frames = 3 * sample_rate + path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames) + expected_dtype = paddle.float32 if normalize or ext not in ["wav", "nist"] else getattr(paddle, dtype) + with patch("soundfile.SoundFile", SoundFileMock): + found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first) + assert found.dtype == expected_dtype + assert sample_rate == sr + + @parameterize( + ["int32", "float32", "float64"], + [8000, 16000], + [1, 2], + [True, False], + [True, False], + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """Returns native dtype when normalize=False else float32""" + self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + [True, False], + [True, False], + ) + def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first): + """Returns float32 always""" + self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize([8000, 16000], [1, 2], [True, False], [True, False]) + def test_ogg(self, sample_rate, num_channels, normalize, channels_first): + """Returns float32 always""" + self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first) + + @parameterize([8000, 16000], [1, 2], [True, False], [True, False]) + def test_flac(self, 
sample_rate, num_channels, normalize, channels_first): + """`soundfile_backend.load` can load ogg format.""" + self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first) + + +class LoadTestBase(TempDirMixin, unittest.TestCase): + def assert_wav( + self, + dtype, + sample_rate, + num_channels, + normalize, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load wav format correctly. + + Wav data loaded with soundfile backend should match those with scipy + """ + path = self.get_temp_path("reference.wav") + num_frames = duration * sample_rate + data = get_wav_data( + dtype, + num_channels, + normalize=normalize, + num_frames=num_frames, + channels_first=channels_first, + ) + save_wav(path, data, sample_rate, channels_first=channels_first) + expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0] + data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first) + assert sr == sample_rate + np.testing.assert_array_almost_equal(data.numpy(), expected.numpy()) + + def assert_sphere( + self, + dtype, + sample_rate, + num_channels, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load SPHERE format correctly.""" + path = self.get_temp_path("reference.sph") + num_frames = duration * sample_rate + raw = get_wav_data( + dtype, + num_channels, + num_frames=num_frames, + normalize=False, + channels_first=False, + ) + soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST") + expected = normalize_wav(raw.t() if channels_first else raw) + data, sr = soundfile_backend.load(path, channels_first=channels_first) + assert sr == sample_rate + #self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(data.numpy(), expected.numpy()) + + def assert_flac( + self, + dtype, + sample_rate, + num_channels, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load FLAC format 
correctly.""" + path = self.get_temp_path("reference.flac") + num_frames = duration * sample_rate + raw = get_wav_data( + dtype, + num_channels, + num_frames=num_frames, + normalize=False, + channels_first=False, + ) + soundfile.write(path, raw, sample_rate) + expected = normalize_wav(raw.t() if channels_first else raw) + data, sr = soundfile_backend.load(path, channels_first=channels_first) + assert sr == sample_rate + #self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(data.numpy(), expected.numpy()) + + + +class TestLoad(LoadTestBase): + """Test the correctness of `soundfile_backend.load` for various formats""" + + @parameterize( + ["float32", "int32"], + [8000, 16000], + [1, 2], + [False, True], + [False, True], + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """`soundfile_backend.load` can load wav format correctly.""" + self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize( + ["int32"], + [16000], + [2], + [False], + ) + def test_wav_large(self, dtype, sample_rate, num_channels, normalize): + """`soundfile_backend.load` can load large wav file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours) + + @parameterize(["float32", "int32"], [4, 8, 16, 32], [False, True]) + def test_multiple_channels(self, dtype, num_channels, channels_first): + """`soundfile_backend.load` can load wav file with more than 2 channels.""" + sample_rate = 8000 + normalize = False + self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) + + #@parameterize(["int32"], [8000, 16000], [1, 2], [False, True]) + #@skipIfFormatNotSupported("NIST") + #def test_sphere(self, dtype, sample_rate, num_channels, channels_first): + #"""`soundfile_backend.load` can load sphere format correctly.""" + #self.assert_sphere(dtype, sample_rate, num_channels, channels_first) + + 
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True]) + #@skipIfFormatNotSupported("FLAC") + #def test_flac(self, dtype, sample_rate, num_channels, channels_first): + #"""`soundfile_backend.load` can load flac format correctly.""" + #self.assert_flac(dtype, sample_rate, num_channels, channels_first) + + +class TestLoadFormat(TempDirMixin, unittest.TestCase): + """Given `format` parameter, `so.load` can load files without extension""" + + original = None + path = None + + def _make_file(self, format_): + sample_rate = 8000 + path_with_ext = self.get_temp_path(f"test.{format_}") + data = get_wav_data("float32", num_channels=2).numpy().T + soundfile.write(path_with_ext, data, sample_rate) + expected = soundfile.read(path_with_ext, dtype="float32")[0].T + path = os.path.splitext(path_with_ext)[0] + os.rename(path_with_ext, path) + return path, expected + + def _test_format(self, format_): + """Providing format allows to read file without extension""" + path, expected = self._make_file(format_) + found, _ = soundfile_backend.load(path) + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found, expected) + + @parameterized.expand( + [ + ("WAV",), + ("wav",), + ] + ) + def test_wav(self, format_): + self._test_format(format_) + + @parameterized.expand( + [ + ("FLAC",), + ("flac",), + ] + ) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, format_): + self._test_format(format_) + + +class TestFileObject(TempDirMixin, unittest.TestCase): + def _test_fileobj(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f"test.{ext}") + + data = get_wav_data("float32", num_channels=2).numpy().T + soundfile.write(path, data, sample_rate) + expected = soundfile.read(path, dtype="float32")[0].T + + with open(path, "rb") as fileobj: + found, sr = soundfile_backend.load(fileobj) + assert sr == sample_rate + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(found, 
expected) + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj("wav") + + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj("flac") + + def _test_tarfile(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + audio_file = f"test.{ext}" + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path("archive.tar.gz") + + data = get_wav_data("float32", num_channels=2).numpy().T + soundfile.write(audio_path, data, sample_rate) + expected = soundfile.read(audio_path, dtype="float32")[0].T + + with tarfile.TarFile(archive_path, "w") as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, "r") as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = soundfile_backend.load(fileobj) + + assert sr == sample_rate + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(found.numpy(), expected) + + + def test_tarfile_wav(self): + """Loading audio via file-like object works""" + self._test_tarfile("wav") + + def test_tarfile_flac(self): + """Loading audio via file-like object works""" + self._test_tarfile("flac") + +if __name__ == '__main__': + unittest.main() diff --git a/audio/tests/backends/soundfile/save_test.py b/audio/tests/backends/soundfile/save_test.py new file mode 100644 index 000000000..28f0e5c79 --- /dev/null +++ b/audio/tests/backends/soundfile/save_test.py @@ -0,0 +1,322 @@ +import io +import unittest +from unittest.mock import patch + +from paddleaudio._internal import module_utils as _mod_utils +from paddleaudio.backends import soundfile_backend +from tests.common_utils import ( + get_wav_data, + load_wav, + nested_params, + normalize_wav, + save_wav, + TempDirMixin, +) + +from common import fetch_wav_subtype, parameterize, skipIfFormatNotSupported + +import paddle +import numpy as np + +import soundfile + + +class MockedSaveTest(unittest.TestCase): + 
@nested_params( + ["float32", "int32"], + [8000, 16000], + [1, 2], + [False, True], + [ + (None, None), + ("PCM_U", None), + ("PCM_U", 8), + ("PCM_S", None), + ("PCM_S", 16), + ("PCM_S", 32), + ("PCM_F", None), + ("PCM_F", 32), + ("PCM_F", 64), + ("ULAW", None), + ("ULAW", 8), + ("ALAW", None), + ("ALAW", 8), + ], + ) + @patch("soundfile.write") + def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write): + """soundfile_backend.save passes correct subtype to soundfile.write when WAV""" + filepath = "foo.wav" + input_tensor = get_wav_data( + dtype, + num_channels, + num_frames=3 * sample_rate, + normalize=dtype == "float32", + channels_first=channels_first, + ) + input_tensor = paddle.transpose(input_tensor, [1, 0]) + + encoding, bits_per_sample = enc_params + soundfile_backend.save( + filepath, + input_tensor, + sample_rate, + channels_first=channels_first, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) + + # on +Py3.8 call_args.kwargs is more descreptive + args = mocked_write.call_args[1] + assert args["file"] == filepath + assert args["samplerate"] == sample_rate + assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample) + assert args["format"] is None + tensor_result = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor + #self.assertEqual(args["data"], tensor_result.numpy()) + np.testing.assert_array_almost_equal(args["data"].numpy(), tensor_result.numpy()) + + + + @patch("soundfile.write") + def assert_non_wav( + self, + fmt, + dtype, + sample_rate, + num_channels, + channels_first, + mocked_write, + encoding=None, + bits_per_sample=None, + ): + """soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE""" + filepath = f"foo.{fmt}" + input_tensor = get_wav_data( + dtype, + num_channels, + num_frames=3 * sample_rate, + normalize=False, + channels_first=channels_first, + ) + input_tensor = paddle.transpose(input_tensor, [1, 0]) + + 
expected_data = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor + + soundfile_backend.save( + filepath, + input_tensor, + sample_rate, + channels_first, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) + + # on +Py3.8 call_args.kwargs is more descreptive + args = mocked_write.call_args[1] + assert args["file"] == filepath + assert args["samplerate"] == sample_rate + if fmt in ["sph", "nist", "nis"]: + assert args["format"] == "NIST" + else: + assert args["format"] is None + np.testing.assert_array_almost_equal(args["data"].numpy(), expected_data.numpy()) + #self.assertEqual(args["data"], expected_data) + + @nested_params( + ["sph", "nist", "nis"], + ["int32"], + [8000, 16000], + [1, 2], + [False, True], + [ + ("PCM_S", 8), + ("PCM_S", 16), + ("PCM_S", 24), + ("PCM_S", 32), + ("ULAW", 8), + ("ALAW", 8), + ("ALAW", 16), + ("ALAW", 24), + ("ALAW", 32), + ], + ) + def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + encoding, bits_per_sample = enc_params + self.assert_non_wav( + fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample + ) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + [False, True], + [8, 16, 24], + ) + def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + [False, True], + ) + def test_ogg(self, dtype, sample_rate, num_channels, channels_first): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + self.assert_non_wav("ogg", dtype, sample_rate, 
num_channels, channels_first) + + +class SaveTestBase(TempDirMixin, unittest.TestCase): + def assert_wav(self, dtype, sample_rate, num_channels, num_frames): + """`soundfile_backend.save` can save wav format.""" + path = self.get_temp_path("data.wav") + expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False) + soundfile_backend.save(path, expected, sample_rate) + found, sr = load_wav(path, normalize=False) + assert sample_rate == sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save non-wav format. + + Due to precision missmatch, and the lack of alternative way to decode the + resulting files without using soundfile, only meta data are validated. + """ + num_frames = sample_rate * 3 + path = self.get_temp_path(f"data.{fmt}") + expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False) + soundfile_backend.save(path, expected, sample_rate) + sinfo = soundfile.info(path) + assert sinfo.format == fmt.upper() + #assert sinfo.frames == num_frames this go wrong + assert sinfo.channels == num_channels + assert sinfo.samplerate == sample_rate + + def assert_flac(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save flac format.""" + self._assert_non_wav("flac", dtype, sample_rate, num_channels) + + def assert_sphere(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save sph format.""" + self._assert_non_wav("nist", dtype, sample_rate, num_channels) + + def assert_ogg(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save ogg format. + + As we cannot inspect the OGG format (it's lossy), we only check the metadata. 
+ """ + self._assert_non_wav("ogg", dtype, sample_rate, num_channels) + + +class TestSave(SaveTestBase): + @parameterize( + ["float32", "int32"], + [8000, 16000], + [1, 2], + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save wav format.""" + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterize( + ["float32", "int32"], + [4, 8, 16, 32], + ) + def test_multiple_channels(self, dtype, num_channels): + """`soundfile_backend.save` can save wav with more than 2 channels.""" + sample_rate = 8000 + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + ) + @skipIfFormatNotSupported("NIST") + def test_sphere(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save sph format.""" + self.assert_sphere(dtype, sample_rate, num_channels) + + @parameterize( + [8000, 16000], + [1, 2], + ) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, sample_rate, num_channels): + """`soundfile_backend.save` can save flac format.""" + self.assert_flac("float32", sample_rate, num_channels) + + @parameterize( + [8000, 16000], + [1, 2], + ) + @skipIfFormatNotSupported("OGG") + def test_ogg(self, sample_rate, num_channels): + """`soundfile_backend.save` can save ogg/vorbis format.""" + self.assert_ogg("float32", sample_rate, num_channels) + + +class TestSaveParams(TempDirMixin, unittest.TestCase): + """Test the correctness of optional parameters of `soundfile_backend.save`""" + + @parameterize([True, False]) + def test_channels_first(self, channels_first): + """channels_first swaps axes""" + path = self.get_temp_path("data.wav") + data = get_wav_data("int32", 2, channels_first=channels_first) + soundfile_backend.save(path, data, 8000, channels_first=channels_first) + found = load_wav(path)[0] + expected = data if channels_first else data.transpose([1, 0]) + #self.assertEqual(found, expected, atol=1e-4, rtol=1e-8) + 
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + +class TestFileObject(TempDirMixin, unittest.TestCase): + def _test_fileobj(self, ext): + """Saving audio to file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f"test.{ext}") + + subtype = "FLOAT" if ext == "wav" else None + data = get_wav_data("float32", num_channels=2) + soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype) + expected = soundfile.read(path, dtype="float32")[0] + + fileobj = io.BytesIO() + soundfile_backend.save(fileobj, data, sample_rate, format=ext) + fileobj.seek(0) + found, sr = soundfile.read(fileobj, dtype="float32") + + assert sr == sample_rate + #self.assertEqual(expected, found, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(found, expected) + + def test_fileobj_wav(self): + """Saving audio via file-like object works""" + self._test_fileobj("wav") + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Saving audio via file-like object works""" + self._test_fileobj("flac") + + @skipIfFormatNotSupported("NIST") + def test_fileobj_nist(self): + """Saving audio via file-like object works""" + self._test_fileobj("NIST") + + @skipIfFormatNotSupported("OGG") + def test_fileobj_ogg(self): + """Saving audio via file-like object works""" + self._test_fileobj("OGG") + +if __name__ == '__main__': + unittest.main() diff --git a/audio/tests/backends/soundfile/test_io.py b/audio/tests/backends/soundfile/test_io.py new file mode 100644 index 000000000..eed1b39fb --- /dev/null +++ b/audio/tests/backends/soundfile/test_io.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import filecmp +import os +import unittest + +import numpy as np +from paddleaudio.backends import soundfile_load as load +from paddleaudio.backends import soundfile_save as save +import soundfile as sf + +from base import BackendTest + + +class TestIO(BackendTest): + def test_load_mono_channel(self): + sf_data, sf_sr = sf.read(self.files[0]) + pa_data, pa_sr = load( + self.files[0], normal=False, dtype='float64') + + self.assertEqual(sf_data.dtype, pa_data.dtype) + self.assertEqual(sf_sr, pa_sr) + np.testing.assert_array_almost_equal(sf_data, pa_data) + + def test_load_multi_channels(self): + sf_data, sf_sr = sf.read(self.files[1]) + sf_data = sf_data.T # Channel dim first + pa_data, pa_sr = load( + self.files[1], mono=False, normal=False, dtype='float64') + + self.assertEqual(sf_data.dtype, pa_data.dtype) + self.assertEqual(sf_sr, pa_sr) + np.testing.assert_array_almost_equal(sf_data, pa_data) + + def test_save_mono_channel(self): + waveform, sr = np.random.randint( + low=-32768, high=32768, size=(48000), dtype=np.int16), 16000 + sf_tmp_file = 'sf_tmp.wav' + pa_tmp_file = 'pa_tmp.wav' + + sf.write(sf_tmp_file, waveform, sr) + save(waveform, sr, pa_tmp_file) + + self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) + for file in [sf_tmp_file, pa_tmp_file]: + os.remove(file) + + def test_save_multi_channels(self): + waveform, sr = np.random.randint( + low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000 + sf_tmp_file = 'sf_tmp.wav' + pa_tmp_file = 'pa_tmp.wav' + + sf.write(sf_tmp_file, waveform.T, sr) + save(waveform.T, sr, pa_tmp_file) + + 
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) + for file in [sf_tmp_file, pa_tmp_file]: + os.remove(file) + + +if __name__ == '__main__': + unittest.main() diff --git a/audio/tests/benchmark/README.md b/audio/tests/benchmark/README.md new file mode 100644 index 000000000..b9034100d --- /dev/null +++ b/audio/tests/benchmark/README.md @@ -0,0 +1,39 @@ +# 1. Prepare +First, install `pytest-benchmark` via pip. +```sh +pip install pytest-benchmark +``` + +# 2. Run +Run the specific script for profiling. +```sh +pytest melspectrogram.py +``` + +Result: +```sh +========================================================================== test session starts ========================================================================== +platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0 +benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) +rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio +plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0 +collected 4 items + +melspectrogram.py .... 
[100%] + + +-------------------------------------------------------------------------------------------------- benchmark: 4 tests ------------------------------------------------------------------------------------------------- +Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1 +test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1 +test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1 +test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1 +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Legend: + Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile. + OPS: Operations Per Second, computed as 1 / Mean +========================================================================== 4 passed in 21.12s =========================================================================== + +``` diff --git a/audio/tests/benchmark/log_melspectrogram.py b/audio/tests/benchmark/log_melspectrogram.py new file mode 100644 index 000000000..79b5406d2 --- /dev/null +++ b/audio/tests/benchmark/log_melspectrogram.py @@ -0,0 +1,123 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import urllib.request
+
+import librosa
+import numpy as np
+import paddle
+import paddleaudio
+import torch
+import torchaudio
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+if not os.path.isfile(os.path.basename(wav_url)):  # fetch the sample clip once into the working directory
+    urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
+
+waveform, sr = paddleaudio.backends.soundfile_load(os.path.abspath(os.path.basename(wav_url)))
+waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
+waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
+
+# Feature conf: unpacked into both the paddleaudio extractor and the librosa reference call
+mel_conf = {
+    'sr': sr,
+    'n_fft': 512,
+    'hop_length': 128,
+    'n_mels': 40,
+}
+
+mel_conf_torchaudio = {
+    'sample_rate': sr,
+    'n_fft': 512,
+    'hop_length': 128,
+    'n_mels': 40,
+    'norm': 'slaney',
+    'mel_scale': 'slaney',
+}
+
+
+def enable_cpu_device():
+    paddle.set_device('cpu')
+
+
+def enable_gpu_device():
+    paddle.set_device('gpu')
+
+
+log_mel_extractor = paddleaudio.features.LogMelSpectrogram(
+    **mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype)
+
+
+def log_melspectrogram():  # the callable timed by the paddleaudio benchmarks below
+    return log_mel_extractor(waveform_tensor).squeeze(0)
+
+
+def test_log_melspect_cpu(benchmark):  # `benchmark` is the pytest-benchmark fixture
+    enable_cpu_device()
+    feature_paddleaudio = benchmark(log_melspectrogram)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)  # librosa >= 0.10 only accepts the signal as keyword `y`
+    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_log_melspect_gpu(benchmark):
+    enable_gpu_device()
+    feature_paddleaudio = benchmark(log_melspectrogram)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)
+    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_paddleaudio, decimal=2)
+
+
+mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
+    **mel_conf_torchaudio, f_min=0.0)
+amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0)
+
+
+def melspectrogram_torchaudio():
+    return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
+
+
+def log_melspectrogram_torchaudio():  # the callable timed by the torchaudio benchmarks below
+    mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch)
+    return amplitude_to_DB(mel_specgram).squeeze(0)
+
+
+def test_log_melspect_cpu_torchaudio(benchmark):
+    global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
+
+    mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
+    waveform_tensor_torch = waveform_tensor_torch.to('cpu')
+    amplitude_to_DB = amplitude_to_DB.to('cpu')
+
+    feature_torchaudio = benchmark(log_melspectrogram_torchaudio)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)
+    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_torchaudio, decimal=3)
+
+
+def test_log_melspect_gpu_torchaudio(benchmark):
+    global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
+
+    mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
+    waveform_tensor_torch = waveform_tensor_torch.to('cuda')
+    amplitude_to_DB = amplitude_to_DB.to('cuda')
+
+    feature_torchaudio = benchmark(log_melspectrogram_torchaudio)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)
+    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_torchaudio.cpu(), decimal=2)
diff --git a/audio/tests/benchmark/melspectrogram.py b/audio/tests/benchmark/melspectrogram.py
new file mode 100644
index 000000000..34e65bcb5
--- /dev/null
+++ b/audio/tests/benchmark/melspectrogram.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import urllib.request
+
+import librosa
+import numpy as np
+import paddle
+import paddleaudio
+import torch
+import torchaudio
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+if not os.path.isfile(os.path.basename(wav_url)):  # fetch the sample clip once into the working directory
+    urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
+
+waveform, sr = paddleaudio.backends.soundfile_load(os.path.abspath(os.path.basename(wav_url)))
+waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
+waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
+
+# Feature conf: unpacked into both the paddleaudio extractor and the librosa reference call
+mel_conf = {
+    'sr': sr,
+    'n_fft': 512,
+    'hop_length': 128,
+    'n_mels': 40,
+}
+
+mel_conf_torchaudio = {
+    'sample_rate': sr,
+    'n_fft': 512,
+    'hop_length': 128,
+    'n_mels': 40,
+    'norm': 'slaney',
+    'mel_scale': 'slaney',
+}
+
+
+def enable_cpu_device():
+    paddle.set_device('cpu')
+
+
+def enable_gpu_device():
+    paddle.set_device('gpu')
+
+
+mel_extractor = paddleaudio.features.MelSpectrogram(
+    **mel_conf, f_min=0.0, dtype=waveform_tensor.dtype)
+
+
+def melspectrogram():  # the callable timed by the paddleaudio benchmarks
+    return mel_extractor(waveform_tensor).squeeze(0)
+
+
+def test_melspect_cpu(benchmark):  # `benchmark` is the pytest-benchmark fixture
+    enable_cpu_device()
+    feature_paddleaudio = benchmark(melspectrogram)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)  # librosa >= 0.10 only accepts the signal as keyword `y`
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_melspect_gpu(benchmark):
+    enable_gpu_device()
+    feature_paddleaudio = benchmark(melspectrogram)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_paddleaudio, decimal=3)
+
+
+mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
+    **mel_conf_torchaudio, f_min=0.0)
+
+
+def melspectrogram_torchaudio():  # the callable timed by the torchaudio benchmarks below
+    return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
+
+
+def test_melspect_cpu_torchaudio(benchmark):
+    global waveform_tensor_torch, mel_extractor_torchaudio
+    mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
+    waveform_tensor_torch = waveform_tensor_torch.to('cpu')
+    feature_torchaudio = benchmark(melspectrogram_torchaudio)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_torchaudio, decimal=3)
+
+
+def test_melspect_gpu_torchaudio(benchmark):
+    global waveform_tensor_torch, mel_extractor_torchaudio
+    mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
+    waveform_tensor_torch = waveform_tensor_torch.to('cuda')
+    feature_torchaudio = benchmark(melspectrogram_torchaudio)
+    feature_librosa = librosa.feature.melspectrogram(y=waveform, **mel_conf)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_torchaudio.cpu(), decimal=3)
diff --git a/audio/tests/benchmark/mfcc.py b/audio/tests/benchmark/mfcc.py
new file mode 100644
index 000000000..4173c4bec
--- /dev/null
+++ b/audio/tests/benchmark/mfcc.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import urllib.request
+
+import librosa
+import numpy as np
+import paddle
+import paddleaudio
+import torch
+import torchaudio
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+if not os.path.isfile(os.path.basename(wav_url)):  # fetch the sample clip once into the working directory
+    urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
+
+waveform, sr = paddleaudio.backends.soundfile_load(os.path.abspath(os.path.basename(wav_url)))
+waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
+waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
+
+# Feature conf: mel_conf is also unpacked into the librosa reference call
+mel_conf = {
+    'sr': sr,
+    'n_fft': 512,
+    'hop_length': 128,
+    'n_mels': 40,
+}
+mfcc_conf = {
+    'n_mfcc': 20,
+    'top_db': 80.0,
+}
+mfcc_conf.update(mel_conf)
+
+mel_conf_torchaudio = {
+    'sample_rate': sr,
+    'n_fft': 512,
+    'hop_length': 128,
+    'n_mels': 40,
+    'norm': 'slaney',
+    'mel_scale': 'slaney',
+}
+mfcc_conf_torchaudio = {
+    'sample_rate': sr,
+    'n_mfcc': 20,
+}
+
+
+def enable_cpu_device():
+    paddle.set_device('cpu')
+
+
+def enable_gpu_device():
+    paddle.set_device('gpu')
+
+
+mfcc_extractor = paddleaudio.features.MFCC(
+    **mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype)
+
+
+def mfcc():  # the callable timed by the paddleaudio benchmarks below
+    return mfcc_extractor(waveform_tensor).squeeze(0)
+
+
+def test_mfcc_cpu(benchmark):  # `benchmark` is the pytest-benchmark fixture
+    enable_cpu_device()
+    feature_paddleaudio = benchmark(mfcc)
+    feature_librosa = librosa.feature.mfcc(y=waveform, **mel_conf)  # librosa >= 0.10 only accepts the signal as keyword `y`
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_mfcc_gpu(benchmark):
+    enable_gpu_device()
+    feature_paddleaudio = benchmark(mfcc)
+    feature_librosa = librosa.feature.mfcc(y=waveform, **mel_conf)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_paddleaudio, decimal=3)
+
+
+del mel_conf_torchaudio['sample_rate']  # sample_rate is already supplied through mfcc_conf_torchaudio
+mfcc_extractor_torchaudio = torchaudio.transforms.MFCC(
+    **mfcc_conf_torchaudio, melkwargs=mel_conf_torchaudio)
+
+
+def mfcc_torchaudio():  # the callable timed by the torchaudio benchmarks below
+    return mfcc_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
+
+
+def test_mfcc_cpu_torchaudio(benchmark):
+    global waveform_tensor_torch, mfcc_extractor_torchaudio
+
+    mfcc_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu')  # rebind the declared global, not a stray local
+    waveform_tensor_torch = waveform_tensor_torch.to('cpu')
+
+    feature_torchaudio = benchmark(mfcc_torchaudio)
+    feature_librosa = librosa.feature.mfcc(y=waveform, **mel_conf)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_torchaudio, decimal=3)
+
+
+def test_mfcc_gpu_torchaudio(benchmark):
+    global waveform_tensor_torch, mfcc_extractor_torchaudio
+
+    mfcc_extractor_torchaudio = mfcc_extractor_torchaudio.to('cuda')  # rebind the declared global, not a stray local
+    waveform_tensor_torch = waveform_tensor_torch.to('cuda')
+
+    feature_torchaudio = benchmark(mfcc_torchaudio)
+    feature_librosa = librosa.feature.mfcc(y=waveform, **mel_conf)
+    np.testing.assert_array_almost_equal(
+        feature_librosa, feature_torchaudio.cpu(), decimal=3)
diff --git a/audio/tests/common_utils/__init__.py b/audio/tests/common_utils/__init__.py
new file mode 100644
index 000000000..32b785124
--- /dev/null
+++ b/audio/tests/common_utils/__init__.py
@@ -0,0 +1,17 @@
+from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav
+from .parameterized_utils import nested_params
+from .case_utils import (
+    TempDirMixin,
+    name_func
+)
+
+__all__ = [
+    "get_wav_data",
+    "load_wav",
+    "save_wav",
+    "normalize_wav",
+    # "get_sinusoid" is deliberately not exported: none of the imports above provides it.
+    "name_func",
+    "nested_params",
+    "TempDirMixin"
+]
diff --git 
a/audio/tests/common_utils/case_utils.py b/audio/tests/common_utils/case_utils.py
new file mode 100644
index 000000000..328c3de43
--- /dev/null
+++ b/audio/tests/common_utils/case_utils.py
@@ -0,0 +1,56 @@
+import functools
+import os.path
+import shutil
+import subprocess
+import sys
+import tempfile
+import time
+import unittest
+
+# Adapted from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/common_utils/case_utils.py
+
+import paddle
+
+def name_func(func, _, params):  # test name = function name + "_"-joined positional args
+    return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
+
+class TempDirMixin:
+    """Mixin to provide easy access to a temp dir shared by the tests of one class"""
+
+    temp_dir_ = None
+
+    @classmethod
+    def get_base_temp_dir(cls):
+        # If PADDLEAUDIO_TEST_TEMP_DIR is set, use it instead of a temporary directory.
+        # This is handy for debugging.
+        key = "PADDLEAUDIO_TEST_TEMP_DIR"
+        if key in os.environ:
+            return os.environ[key]
+        if cls.temp_dir_ is None:  # created lazily on first use
+            cls.temp_dir_ = tempfile.TemporaryDirectory()
+        return cls.temp_dir_.name
+
+    @classmethod
+    def tearDownClass(cls):
+        if cls.temp_dir_ is not None:
+            try:
+                cls.temp_dir_.cleanup()
+                cls.temp_dir_ = None
+            except PermissionError:
+                # On Windows there is a known issue with `shutil.rmtree`,
+                # which fails intermittently.
+                #
+                # https://github.com/python/cpython/issues/74168
+                #
+                # We observed this on CircleCI, where the Windows job raises
+                # PermissionError.
+                #
+                # Following the above thread, we ignore it.
+                pass
+        super().tearDownClass()
+
+    def get_temp_path(self, *paths):  # path under <base temp dir>/<test id>/; parent dirs are created
+        temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
+        path = os.path.join(temp_dir, *paths)
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        return path
diff --git a/audio/tests/common_utils/parameterized_utils.py b/audio/tests/common_utils/parameterized_utils.py
new file mode 100644
index 000000000..d27c27469
--- /dev/null
+++ b/audio/tests/common_utils/parameterized_utils.py
@@ -0,0 +1,43 @@
+import json
+from itertools import product
+import os
+
+from parameterized import param, parameterized
+
+def _name_func(func, _, params):  # default test-name builder used by nested_params
+    strs = []
+    for arg in params.args:
+        if isinstance(arg, tuple):
+            strs.append("_".join(str(a) for a in arg))
+        else:
+            strs.append(str(arg))
+    # sanitize the test name
+    name = "_".join(strs)
+    return parameterized.to_safe_name(f"{func.__name__}_{name}")
+
+
+def nested_params(*params_set, name_func=_name_func):
+    """Generate the cartesian product of the given list of parameters.
+
+    Args:
+        params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
+            all the parameters have to be specified with the class, only using kwargs.
+    """
+    flatten = [p for params in params_set for p in params]
+
+    # Parameters to be nested are given as list of plain objects
+    if all(not isinstance(p, param) for p in flatten):
+        args = list(product(*params_set))
+        return parameterized.expand(args, name_func=name_func)  # honor the caller-supplied name_func
+
+    # Parameters to be nested are given as list of `parameterized.param`
+    if not all(isinstance(p, param) for p in flatten):
+        raise TypeError("When using ``parameterized.param``, " "all the parameters have to be of the ``param`` type.")
+    if any(p.args for p in flatten):
+        raise ValueError(
+            "When using ``parameterized.param``, " "all the parameters have to be provided as keyword argument."
+        )
+    args = [param()]
+    for params in params_set:
+        args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
+    return parameterized.expand(args)  # kwargs-only params: name_func (which reads params.args) is not applied here
diff --git a/audio/tests/common_utils/wav_utils.py b/audio/tests/common_utils/wav_utils.py
new file mode 100644
index 000000000..25d0b1971
--- /dev/null
+++ b/audio/tests/common_utils/wav_utils.py
@@ -0,0 +1,102 @@
+from typing import Optional, Tuple
+
+import scipy.io.wavfile
+import paddle
+import numpy as np
+
+def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:  # scale integer PCM to float32; other dtypes pass through unchanged
+    if tensor.dtype == paddle.float32:
+        pass
+    elif tensor.dtype == paddle.int32:
+        tensor = paddle.cast(tensor, paddle.float32)
+        tensor[tensor > 0] /= 2147483647.0
+        tensor[tensor < 0] /= 2147483648.0
+    elif tensor.dtype == paddle.int16:
+        tensor = paddle.cast(tensor, paddle.float32)
+        tensor[tensor > 0] /= 32767.0
+        tensor[tensor < 0] /= 32768.0
+    elif tensor.dtype == paddle.uint8:
+        tensor = paddle.cast(tensor, paddle.float32) - 128
+        tensor[tensor > 0] /= 127.0
+        tensor[tensor < 0] /= 128.0
+    return tensor
+
+
+def get_wav_data(
+    dtype: str,
+    num_channels: int,
+    *,
+    num_frames: Optional[int] = None,
+    normalize: bool = True,
+    channels_first: bool = True,
+):
+    """Generate a linear signal of the given dtype and num_channels
+
+    Data range is
+        [-1.0, 1.0] for float32 and float64,
+        [-2147483648, 2147483647] for int32
+        (int16 and uint8 are not generated here: the paddle.linspace branches are
+        commented out below, so those dtypes raise NotImplementedError)
+
+    num_frames sets the number of samples of the linear ramp.
+    Default values are 256 for uint8, else 1 << 16.
+    1 << 16 as default is so that int16 value range is completely covered.
+    """
+    dtype_ = getattr(paddle, dtype)
+
+    if num_frames is None:
+        if dtype == "uint8":
+            num_frames = 256
+        else:
+            num_frames = 1 << 16
+
+    # paddle linspace does not support uint8, int8, int16
+    #if dtype == "uint8":
+    #    base = paddle.linspace(0, 255, num_frames, dtype=dtype_)
+        #dtype_np = getattr(np, dtype)
+        #base_np = np.linspace(0, 255, num_frames, dtype_np)
+        #base = paddle.to_tensor(base_np, dtype=dtype_)
+    #elif dtype == "int8":
+    #    base = paddle.linspace(-128, 127, num_frames, dtype=dtype_)
+        #dtype_np = getattr(np, dtype)
+        #base_np = np.linspace(-128, 127, num_frames, dtype_np)
+        #base = paddle.to_tensor(base_np, dtype=dtype_)
+    if dtype == "float32":
+        base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
+    elif dtype == "float64":
+        base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
+    elif dtype == "int32":
+        base = paddle.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
+    #elif dtype == "int16":
+    #    base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_)
+        #dtype_np = getattr(np, dtype)
+        #base_np = np.linspace(-32768, 32767, num_frames, dtype_np)
+        #base = paddle.to_tensor(base_np, dtype=dtype_)
+    else:
+        raise NotImplementedError(f"Unsupported dtype {dtype}")
+    data = base.tile([num_channels, 1])  # every channel carries the same ramp
+    if not channels_first:
+        data = data.transpose([1, 0])
+    if normalize:
+        data = normalize_wav(data)
+    return data
+
+
+def load_wav(path: str, normalize=True, channels_first=True) -> Tuple[paddle.Tensor, int]:
+    """Load wav file without paddleaudio; returns (data, sample_rate)"""
+    sample_rate, data = scipy.io.wavfile.read(path)
+    data = paddle.to_tensor(data.copy())
+    if data.ndim == 1:  # mono: add a channel axis so the tensor is always 2-D
+        data = data.unsqueeze(1)
+    if normalize:
+        data = normalize_wav(data)
+    if channels_first:
+        data = data.transpose([1, 0])
+    return data, sample_rate
+
+
+def save_wav(path, data, sample_rate, channels_first=True):
+    """Save wav file without paddleaudio"""
+    if channels_first:  # scipy is handed the transposed tensor in this case
+        data = data.transpose([1, 0])
+    scipy.io.wavfile.write(path, sample_rate, data.numpy())
diff --git 
a/audio/tests/features/__init__.py b/audio/tests/features/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/audio/tests/features/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/audio/tests/features/base.py b/audio/tests/features/base.py new file mode 100644 index 000000000..d183b72ad --- /dev/null +++ b/audio/tests/features/base.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
import unittest
import urllib.request

import numpy as np
import paddle
from paddleaudio.backends import soundfile_load as load

wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'


class FeatTest(unittest.TestCase):
    """Base class for feature tests that share one reference waveform.

    Subclasses must override ``initParmas`` (the spelling is part of the
    interface subclasses already implement) to set their own parameters.
    ``setUp`` then loads the waveform into ``self.waveform`` / ``self.sr``
    and selects the paddle device.
    """

    def setUp(self):
        """Run the three initialisation hooks before every test."""
        self.initParmas()
        self.initWavInput()
        self.setUpDevice()

    def setUpDevice(self, device='cpu'):
        """Select the paddle device used by the test (``'cpu'`` by default)."""
        paddle.set_device(device)

    def initWavInput(self, url=wav_url):
        """Fetch the wav at ``url`` if needed and load it as float32.

        The file is cached in the current working directory under the
        basename of ``url``. After loading, ``self.waveform`` always has two
        dimensions: a 1-D result gets a leading axis of length 1.

        Args:
            url: Location of the wav file to download.
        """
        wav_path = os.path.basename(url)
        if not os.path.isfile(wav_path):
            # Download under a temporary name and rename afterwards, so an
            # interrupted transfer never leaves a truncated file that later
            # runs would mistake for a complete download.
            tmp_path = wav_path + '.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, wav_path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        self.waveform, self.sr = load(os.path.abspath(wav_path))
        # paddlespeech.s2t.transform.spectrogram only supports float32
        self.waveform = self.waveform.astype(np.float32)

        dim = len(self.waveform.shape)
        # A unittest assertion (unlike a bare ``assert``) still runs under -O.
        self.assertIn(dim, (1, 2))
        if dim == 1:
            self.waveform = np.expand_dims(self.waveform, 0)

    def initParmas(self):
        """Hook for subclasses to define their test parameters.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError
import unittest

import numpy as np
import paddle
import paddleaudio
import torch
import torchaudio

from base import FeatTest


class TestKaldi(FeatTest):
    """Check paddleaudio windows and Kaldi-compliance features against torch/torchaudio."""

    def initParmas(self):
        """Set the window length and the dtype name used by every test."""
        self.window_size = 1024
        self.dtype = 'float32'

    def test_window(self):
        """Compare hann, hamming and povey (hann ** 0.85) windows between torch and paddleaudio."""
        # Resolve the dtype objects by attribute lookup rather than eval():
        # same result, no string evaluation.
        torch_dtype = getattr(torch, self.dtype)
        paddle_dtype = getattr(paddle, self.dtype)
        get_window = paddleaudio.functional.window.get_window

        t_hann_window = torch.hann_window(
            self.window_size, periodic=False, dtype=torch_dtype)
        t_hamm_window = torch.hamming_window(
            self.window_size,
            periodic=False,
            alpha=0.54,
            beta=0.46,
            dtype=torch_dtype)
        t_povey_window = torch.hann_window(
            self.window_size, periodic=False, dtype=torch_dtype).pow(0.85)

        p_hann_window = get_window(
            'hann', self.window_size, fftbins=False, dtype=paddle_dtype)
        p_hamm_window = get_window(
            'hamming', self.window_size, fftbins=False, dtype=paddle_dtype)
        p_povey_window = get_window(
            'hann', self.window_size, fftbins=False,
            dtype=paddle_dtype).pow(0.85)

        np.testing.assert_array_almost_equal(t_hann_window, p_hann_window)
        np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window)
        np.testing.assert_array_almost_equal(t_povey_window, p_povey_window)

    def test_fbank(self):
        """Compare default-argument Kaldi fbank features to 4 decimals."""
        waveform = self.waveform.astype(self.dtype)
        ta_features = torchaudio.compliance.kaldi.fbank(
            torch.from_numpy(waveform))
        pa_features = paddleaudio.compliance.kaldi.fbank(
            paddle.to_tensor(waveform))
        np.testing.assert_array_almost_equal(
            ta_features, pa_features, decimal=4)

    def test_mfcc(self):
        """Compare default-argument Kaldi MFCC features to 4 decimals."""
        waveform = self.waveform.astype(self.dtype)
        ta_features = torchaudio.compliance.kaldi.mfcc(
            torch.from_numpy(waveform))
        pa_features = paddleaudio.compliance.kaldi.mfcc(
            paddle.to_tensor(waveform))
        np.testing.assert_array_almost_equal(
            ta_features, pa_features, decimal=4)


if __name__ == '__main__':
    unittest.main()
import unittest

import librosa
import numpy as np
import paddle
import paddleaudio
from paddleaudio.functional.window import get_window

from base import FeatTest


class TestLibrosa(FeatTest):
    """Check paddle / paddleaudio spectral features against librosa."""

    def initParmas(self):
        """Set the analysis parameters shared by every test."""
        self.n_fft = 512
        self.hop_length = 128
        self.n_mels = 40
        self.n_mfcc = 20
        self.fmin = 0.0
        self.window_str = 'hann'
        self.pad_mode = 'reflect'
        self.top_db = 80.0

    def _squeeze_waveform(self):
        """Make ``self.waveform`` 1-D when it is stored as (C, T).

        The librosa calls and the paddleaudio.compliance.librosa calls in
        these tests are all fed the 1-D waveform.
        """
        if len(self.waveform.shape) == 2:  # (C, T)
            self.waveform = self.waveform.squeeze(0)

    def test_stft(self):
        """Compare librosa.core.stft with paddle.signal.stft to 5 decimals."""
        self._squeeze_waveform()

        feature_librosa = librosa.core.stft(
            y=self.waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=None,
            window=self.window_str,
            center=True,
            dtype=None,
            pad_mode=self.pad_mode, )

        x = paddle.to_tensor(self.waveform).unsqueeze(0)  # Add batch dim.
        window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
        feature_paddle = paddle.signal.stft(
            x=x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=None,
            window=window,
            center=True,
            pad_mode=self.pad_mode,
            normalized=False,
            onesided=True, ).squeeze(0)

        np.testing.assert_array_almost_equal(
            feature_librosa, feature_paddle, decimal=5)

    def test_istft(self):
        """Invert one librosa STFT with both libraries and compare to 5 decimals."""
        self._squeeze_waveform()

        # Get stft result from librosa.
        stft_matrix = librosa.core.stft(
            y=self.waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=None,
            window=self.window_str,
            center=True,
            pad_mode=self.pad_mode, )

        feature_librosa = librosa.core.istft(
            stft_matrix=stft_matrix,
            hop_length=self.hop_length,
            win_length=None,
            window=self.window_str,
            center=True,
            dtype=None,
            length=None, )

        x = paddle.to_tensor(stft_matrix).unsqueeze(0)  # Add batch dim.
        window = get_window(
            self.window_str,
            self.n_fft,
            dtype=paddle.to_tensor(self.waveform).dtype)
        feature_paddle = paddle.signal.istft(
            x=x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=None,
            window=window,
            center=True,
            normalized=False,
            onesided=True,
            length=None,
            return_complex=False, ).squeeze(0)

        np.testing.assert_array_almost_equal(
            feature_librosa, feature_paddle, decimal=5)

    def test_mel(self):
        """Compare the mel filter-bank matrix from librosa with both paddleaudio APIs."""
        feature_librosa = librosa.filters.mel(
            sr=self.sr,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.fmin,
            fmax=None,
            htk=False,
            norm='slaney',
            dtype=self.waveform.dtype, )
        feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix(
            sr=self.sr,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.fmin,
            fmax=None,
            htk=False,
            norm='slaney',
            dtype=self.waveform.dtype, )
        # The tensor is built only to obtain the matching paddle dtype.
        x = paddle.to_tensor(self.waveform)
        feature_functional = paddleaudio.functional.compute_fbank_matrix(
            sr=self.sr,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            f_min=self.fmin,
            f_max=None,
            htk=False,
            norm='slaney',
            dtype=x.dtype, )

        np.testing.assert_array_almost_equal(feature_librosa,
                                             feature_compliance)
        np.testing.assert_array_almost_equal(feature_librosa,
                                             feature_functional)

    def test_melspect(self):
        """Compare mel spectrograms (no dB conversion) to 5 decimals."""
        self._squeeze_waveform()

        # librosa:
        feature_librosa = librosa.feature.melspectrogram(
            y=self.waveform,
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin)

        # paddleaudio.compliance.librosa:
        feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
            x=self.waveform,
            sr=self.sr,
            window_size=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin,
            to_db=False)

        # paddleaudio.features.layer
        x = paddle.to_tensor(
            self.waveform, dtype=paddle.float64).unsqueeze(0)  # Add batch dim.
        feature_extractor = paddleaudio.features.MelSpectrogram(
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=self.fmin,
            dtype=x.dtype)
        feature_layer = feature_extractor(x).squeeze(0).numpy()

        np.testing.assert_array_almost_equal(
            feature_librosa, feature_compliance, decimal=5)
        np.testing.assert_array_almost_equal(
            feature_librosa, feature_layer, decimal=5)

    def test_log_melspect(self):
        """Compare log-mel spectrograms from librosa, the compliance API and the layer."""
        self._squeeze_waveform()

        # librosa:
        feature_librosa = librosa.feature.melspectrogram(
            y=self.waveform,
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin)
        feature_librosa = librosa.power_to_db(feature_librosa, top_db=None)

        # paddleaudio.compliance.librosa:
        feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
            x=self.waveform,
            sr=self.sr,
            window_size=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin)

        # paddleaudio.features.layer
        x = paddle.to_tensor(
            self.waveform, dtype=paddle.float64).unsqueeze(0)  # Add batch dim.
        feature_extractor = paddleaudio.features.LogMelSpectrogram(
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=self.fmin,
            dtype=x.dtype)
        feature_layer = feature_extractor(x).squeeze(0).numpy()

        np.testing.assert_array_almost_equal(
            feature_librosa, feature_compliance, decimal=5)
        np.testing.assert_array_almost_equal(
            feature_librosa, feature_layer, decimal=4)

    def test_mfcc(self):
        """Compare MFCCs from librosa, the compliance API and the layer to 4 decimals."""
        self._squeeze_waveform()

        # librosa:
        feature_librosa = librosa.feature.mfcc(
            y=self.waveform,
            sr=self.sr,
            S=None,
            n_mfcc=self.n_mfcc,
            dct_type=2,
            norm='ortho',
            lifter=0,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin)

        # paddleaudio.compliance.librosa:
        feature_compliance = paddleaudio.compliance.librosa.mfcc(
            x=self.waveform,
            sr=self.sr,
            n_mfcc=self.n_mfcc,
            dct_type=2,
            norm='ortho',
            lifter=0,
            window_size=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin,
            top_db=self.top_db)

        # paddleaudio.features.layer
        x = paddle.to_tensor(
            self.waveform, dtype=paddle.float64).unsqueeze(0)  # Add batch dim.
        feature_extractor = paddleaudio.features.MFCC(
            sr=self.sr,
            n_mfcc=self.n_mfcc,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=self.fmin,
            top_db=self.top_db,
            dtype=x.dtype)
        feature_layer = feature_extractor(x).squeeze(0).numpy()

        np.testing.assert_array_almost_equal(
            feature_librosa, feature_compliance, decimal=4)
        np.testing.assert_array_almost_equal(
            feature_librosa, feature_layer, decimal=4)


if __name__ == '__main__':
    unittest.main()