From 22a5344bcfbd87dd4b274b09caecd8cf71a4c6e3 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Fri, 12 Aug 2022 16:49:24 +0800 Subject: [PATCH] fix save && effect test --- paddlespeech/audio/backends/sox_io_backend.py | 4 +- paddlespeech/audio/sox_effects/sox_effects.py | 29 +- paddlespeech/audio/src/pybind/pybind.cpp | 6 +- .../audio/src/pybind/sox/effects_chain.cpp | 1 - paddlespeech/audio/utils/sox_utils.py | 2 +- tests/unit/assets/sox_effect_test_args.jsonl | 78 ++++ tests/unit/audio/backends/sox_io/save_test.py | 2 - .../unit/audio/backends/sox_io/smoke_test.py | 183 +++++++++ .../audio/backends/sox_io/sox_effect_test.py | 346 ++++++++++++++++++ tests/unit/common_utils/__init__.py | 9 +- tests/unit/common_utils/case_utils.py | 3 + .../unit/common_utils/parameterized_utils.py | 27 +- tests/unit/common_utils/wav_utils.py | 10 + 13 files changed, 670 insertions(+), 30 deletions(-) create mode 100644 tests/unit/assets/sox_effect_test_args.jsonl create mode 100644 tests/unit/audio/backends/sox_io/smoke_test.py create mode 100644 tests/unit/audio/backends/sox_io/sox_effect_test.py diff --git a/paddlespeech/audio/backends/sox_io_backend.py b/paddlespeech/audio/backends/sox_io_backend.py index beb6ddb9d..2037ad81d 100644 --- a/paddlespeech/audio/backends/sox_io_backend.py +++ b/paddlespeech/audio/backends/sox_io_backend.py @@ -88,9 +88,9 @@ def save(filepath: str, ) @_mod_utils.requires_sox() -def info(filepath: str, format: Optional[str]) -> None: +def info(filepath: str, format: Optional[str] = "") -> None: if hasattr(filepath, "read"): - sinfo = paddleaudio.get_info_fileojb(filepath, format) + sinfo = paddleaudio.get_info_fileobj(filepath, format) if sinfo is not None: return AudioMetaData(*sinfo) return _fallback_info_fileobj(filepath, format) diff --git a/paddlespeech/audio/sox_effects/sox_effects.py b/paddlespeech/audio/sox_effects/sox_effects.py index 1a3f3af29..17d2d95af 100644 --- a/paddlespeech/audio/sox_effects/sox_effects.py +++ b/paddlespeech/audio/sox_effects/sox_effects.py @@ -1,5 +1,7 @@ import os from typing import List, Optional, Tuple +import paddle +import numpy from paddlespeech.audio._internal import module_utils as _mod_utils from paddlespeech.audio.utils.sox_utils import list_effects @@ -52,11 +54,11 @@ def effect_names() -> List[str]: @_mod_utils.requires_sox() def apply_effects_tensor( - tensor: torch.Tensor, + tensor: paddle.Tensor, sample_rate: int, effects: List[List[str]], channels_first: bool = True, -) -> Tuple[torch.Tensor, int]: +) -> Tuple[paddle.Tensor, int]: """Apply sox effects to given Tensor .. devices:: CPU @@ -152,7 +154,11 @@ def apply_effects_tensor( >>> waveform, sample_rate = transform(waveform, input_sample_rate) >>> assert sample_rate == 8000 """ - return paddleaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first) + tensor_np = tensor.numpy() + ret = paddleaudio.sox_effects_apply_effects_tensor(tensor_np, sample_rate, effects, channels_first) + if ret is not None: + return (paddle.to_tensor(ret[0]), ret[1]) + raise RuntimeError("Failed to apply sox effect") @_mod_utils.requires_sox() @@ -162,7 +168,7 @@ def apply_effects_file( normalize: bool = True, channels_first: bool = True, format: Optional[str] = None, -) -> Tuple[torch.Tensor, int]: +) -> Tuple[paddle.Tensor, int]: """Apply sox effects to the audio file and load the resulting data as Tensor .. devices:: CPU @@ -270,14 +276,13 @@ def apply_effects_file( >>> for batch in loader: >>> pass """ - if not torch.jit.is_scripting(): - if hasattr(path, "read"): - ret = paddleaudio._paddleaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format) - if ret is None: - raise RuntimeError("Failed to load audio from {}".format(path)) - return ret - path = os.fspath(path) + if hasattr(path, "read"): + ret = paddleaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format) + if ret is None: + raise RuntimeError("Failed to load audio from {}".format(path)) + return (paddle.to_tensor(ret[0]), ret[1]) + path = os.fspath(path) ret = paddleaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format) if ret is not None: - return ret + return (paddle.to_tensor(ret[0]), ret[1]) raise RuntimeError("Failed to load audio from {}".format(path)) \ No newline at end of file diff --git a/paddlespeech/audio/src/pybind/pybind.cpp b/paddlespeech/audio/src/pybind/pybind.cpp index 24cf0eb18..b265a2ab1 100644 --- a/paddlespeech/audio/src/pybind/pybind.cpp +++ b/paddlespeech/audio/src/pybind/pybind.cpp @@ -65,9 +65,9 @@ PYBIND11_MODULE(_paddleaudio, m) { &paddleaudio::sox_utils::get_buffer_size); // effect - //m.def("apply_effects_fileobj", - // &paddleaudio::sox_effects::apply_effects_fileobj, - // "Decode audio data from file-like obj and apply effects."); + m.def("apply_effects_fileobj", + &paddleaudio::sox_effects::apply_effects_fileobj, + "Decode audio data from file-like obj and apply effects."); m.def("sox_effects_initialize_sox_effects", &paddleaudio::sox_effects::initialize_sox_effects); m.def( diff --git a/paddlespeech/audio/src/pybind/sox/effects_chain.cpp b/paddlespeech/audio/src/pybind/sox/effects_chain.cpp index 15fc6d26e..5e8f6ee71 100644 --- a/paddlespeech/audio/src/pybind/sox/effects_chain.cpp +++ b/paddlespeech/audio/src/pybind/sox/effects_chain.cpp @@ -59,7 +59,6 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { switch (tensor.dtype().num()) { //case c10::ScalarType::Float: { case 11: { - break; // Need to convert to 64-bit precision so that // values around INT32_MIN/MAX are handled correctly. for (int idx = 0; idx < chunk.size(); ++idx) { diff --git a/paddlespeech/audio/utils/sox_utils.py b/paddlespeech/audio/utils/sox_utils.py index fb19ff316..37696a5d9 100644 --- a/paddlespeech/audio/utils/sox_utils.py +++ b/paddlespeech/audio/utils/sox_utils.py @@ -31,7 +31,7 @@ def set_verbosity(verbosity: int): See Also: http://sox.sourceforge.net/sox.html """ - _paddleaudio.sox_utils_set_verbosity(verbosity) + _paddleaudio.sox_utils_set_verbosity(verbosity) @_mod_utils.requires_sox() diff --git a/tests/unit/assets/sox_effect_test_args.jsonl b/tests/unit/assets/sox_effect_test_args.jsonl new file mode 100644 index 000000000..b005515bb --- /dev/null +++ b/tests/unit/assets/sox_effect_test_args.jsonl @@ -0,0 +1,78 @@ +{"effects": [["allpass", "300", "10"]]} +{"effects": [["band", "300", "10"]]} +{"effects": [["bandpass", "300", "10"]]} +{"effects": [["bandreject", "300", "10"]]} +{"effects": [["bass", "-10"]]} +{"effects": [["biquad", "0.4", "0.2", "0.9", "0.7", "0.2", "0.6"]]} +{"effects": [["chorus", "0.7", "0.9", "55", "0.4", "0.25", "2", "-t"]]} +{"effects": [["chorus", "0.6", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "1.3", "-s"]]} +{"effects": [["chorus", "0.5", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "2.3", "-t", "40", "0.3", "0.3", "1.3", "-s"]]} +{"effects": [["channels", "1"]]} +{"effects": [["channels", "2"]]} +{"effects": [["channels", "3"]]} +{"effects": [["compand", "0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"]]} +{"effects": [["compand", ".1,.2", "-inf,-50.1,-inf,-50,-50", "0", "-90", ".1"]]} +{"effects": [["compand", ".1,.1", "-45.1,-45,-inf,0,-inf", "45", "-90", ".1"]]} +{"effects": [["contrast", "0"]]} +{"effects": [["contrast", "25"]]} +{"effects": [["contrast", "50"]]} +{"effects": [["contrast", "75"]]} +{"effects": [["contrast", "100"]]} +{"effects": [["dcshift", "1.0"]]} +{"effects": [["dcshift", "-1.0"]]} +{"effects": [["deemph"]], "input_sample_rate": 44100} +{"effects": [["dither", "-s"]]} +{"effects": [["dither", "-S"]]} +{"effects": [["divide"]]} +{"effects": [["downsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 4000} +{"effects": [["earwax"]], "input_sample_rate": 44100} +{"effects": [["echo", "0.8", "0.88", "60", "0.4"]]} +{"effects": [["echo", "0.8", "0.88", "6", "0.4"]]} +{"effects": [["echo", "0.8", "0.9", "1000", "0.3"]]} +{"effects": [["echo", "0.8", "0.9", "1000", "0.3", "1800", "0.25"]]} +{"effects": [["echos", "0.8", "0.7", "700", "0.25", "700", "0.3"]]} +{"effects": [["echos", "0.8", "0.7", "700", "0.25", "900", "0.3"]]} +{"effects": [["echos", "0.8", "0.7", "40", "0.25", "63", "0.3"]]} +{"effects": [["equalizer", "300", "10", "5"]]} +{"effects": [["fade", "q", "3"]]} +{"effects": [["fade", "h", "3"]]} +{"effects": [["fade", "t", "3"]]} +{"effects": [["fade", "l", "3"]]} +{"effects": [["fade", "p", "3"]]} +{"effects": [["fir", "0.0195", "-0.082", "0.234", "0.891", "-0.145", "0.043"]]} +{"effects": [["fir", "/sox_effect_test_fir_coeffs.txt"]]} +{"effects": [["flanger"]]} +{"effects": [["gain", "-l", "-6"]]} +{"effects": [["highpass", "-1", "300"]]} +{"effects": [["highpass", "-2", "300"]]} +{"effects": [["hilbert"]]} +{"effects": [["loudness"]]} +{"effects": [["lowpass", "-1", "300"]]} +{"effects": [["lowpass", "-2", "300"]]} +{"effects": [["mcompand", "0.005,0.1 -47,-40,-34,-34,-17,-33", "100", "0.003,0.05 -47,-40,-34,-34,-17,-33", "400", "0.000625,0.0125 -47,-40,-34,-34,-15,-33", "1600", "0.0001,0.025 -47,-40,-34,-34,-31,-31,-0,-30", "6400", "0,0.025 -38,-31,-28,-28,-0,-25"]], "input_sample_rate": 44100} +{"effects": [["oops"]]} +{"effects": [["overdrive"]]} +{"effects": [["pad"]]} +{"effects": [["phaser"]]} +{"effects": [["remix", "6", "7", "8", "0"]], "num_channels": 8} +{"effects": [["remix", "1-3,7", "3"]], "num_channels": 8} +{"effects": [["repeat"]]} +{"effects": [["reverb"]]} +{"effects": [["reverse"]]} +{"effects": [["riaa"]], "input_sample_rate": 44100} +{"effects": [["silence", "0"]]} +{"effects": [["speed", "1.3"]], "input_sample_rate": 4000, "output_sample_rate": 5200} +{"effects": [["speed", "0.7"]], "input_sample_rate": 4000, "output_sample_rate": 2800} +{"effects": [["stat"]]} +{"effects": [["stats"]]} +{"effects": [["stretch"]]} +{"effects": [["swap"]]} +{"effects": [["synth"]]} +{"effects": [["tempo", "0.9"]]} +{"effects": [["tempo", "1.1"]]} +{"effects": [["treble", "3"]]} +{"effects": [["tremolo", "300", "40"]]} +{"effects": [["tremolo", "300", "50"]]} +{"effects": [["trim", "0", "0.1"]]} +{"effects": [["upsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 16000} +{"effects": [["vol", "3"]]} diff --git a/tests/unit/audio/backends/sox_io/save_test.py b/tests/unit/audio/backends/sox_io/save_test.py index 269c502a3..b07af70f2 100644 --- a/tests/unit/audio/backends/sox_io/save_test.py +++ b/tests/unit/audio/backends/sox_io/save_test.py @@ -164,8 +164,6 @@ class TestSave(TestSaveBase, unittest.TestCase): [ ("float32",), ("int32",), - ("int16",), - ("uint8",), ], ) def test_save_wav_dtype(self, test_mode, params): diff --git a/tests/unit/audio/backends/sox_io/smoke_test.py b/tests/unit/audio/backends/sox_io/smoke_test.py new file mode 100644 index 000000000..1f191bc51 --- /dev/null +++ b/tests/unit/audio/backends/sox_io/smoke_test.py @@ -0,0 +1,183 @@ +import io +import itertools +import unittest + +from parameterized import parameterized +from paddlespeech.audio.backends import sox_io_backend +from tests.unit.common_utils import ( + get_wav_data, + TempDirMixin, + name_func +) + +class SmokeTest(TempDirMixin, unittest.TestCase): + """Run smoke test on various audio format + + The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit + abnormal behaviors. + + This test suite should be able to run without any additional tools (such as sox command), + however without such tools, the correctness of each function cannot be verified. + """ + + def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"): + duration = 1 + num_frames = sample_rate * duration + #path = self.get_temp_path(f"test.{ext}") + path = self.get_temp_path(f"test.{ext}") + original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames) + + # 1. run save + sox_io_backend.save(path, original, sample_rate, compression=compression) + # 2. run info + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_channels == num_channels + # 3. run load + loaded, sr = sox_io_backend.load(path, normalize=False) + assert sr == sample_rate + assert loaded.shape[0] == num_channels + + @parameterized.expand( + list( + itertools.product( + ["float32", "int32" ], + #["float32", "int32", "int16", "uint8"], + [8000, 16000], + [1, 2], + ) + ), + name_func=name_func, + ) + def test_wav(self, dtype, sample_rate, num_channels): + """Run smoke test on wav format""" + self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype) + + #@parameterized.expand( + #list( + #itertools.product( + #[8000, 16000], + #[1, 2], + #[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], + #) + #) + #) + #def test_mp3(self, sample_rate, num_channels, bit_rate): + #"""Run smoke test on mp3 format""" + #self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) + + #@parameterized.expand( + #list( + #itertools.product( + #[8000, 16000], + #[1, 2], + #[-1, 0, 1, 2, 3, 3.6, 5, 10], + #) + #) + #) + #def test_vorbis(self, sample_rate, num_channels, quality_level): + #"""Run smoke test on vorbis format""" + #self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level) + + @parameterized.expand( + list( + itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + ) + ), + name_func=name_func, + ) + def test_flac(self, sample_rate, num_channels, compression_level): + """Run smoke test on flac format""" + self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level) + + +class SmokeTestFileObj(unittest.TestCase): + """Run smoke test on various audio format + + The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit + abnormal behaviors. + + This test suite should be able to run without any additional tools (such as sox command), + however without such tools, the correctness of each function cannot be verified. + """ + + def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"): + duration = 1 + num_frames = sample_rate * duration + original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames) + + fileobj = io.BytesIO() + # 1. run save + sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext) + # 2. run info + fileobj.seek(0) + info = sox_io_backend.info(fileobj, format=ext) + assert info.sample_rate == sample_rate + assert info.num_channels == num_channels + # 3. run load + fileobj.seek(0) + loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext) + assert sr == sample_rate + assert loaded.shape[0] == num_channels + + @parameterized.expand( + list( + itertools.product( + ["float32", "int32"], + [8000, 16000], + [1, 2], + ) + ), + name_func=name_func, + ) + def test_wav(self, dtype, sample_rate, num_channels): + """Run smoke test on wav format""" + self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype) + + # not support yet + #@parameterized.expand( + #list( + #itertools.product( + #[8000, 16000], + #[1, 2], + #[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], + #) + #) + #) + #def test_mp3(self, sample_rate, num_channels, bit_rate): + #"""Run smoke test on mp3 format""" + #self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) + + #@parameterized.expand( + #list( + #itertools.product( + #[8000, 16000], + #[1, 2], + #[-1, 0, 1, 2, 3, 3.6, 5, 10], + #) + #) + #) + #def test_vorbis(self, sample_rate, num_channels, quality_level): + #"""Run smoke test on vorbis format""" + #self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level) + + @parameterized.expand( + list( + itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + ) + ), + name_func=name_func, + ) + def test_flac(self, sample_rate, num_channels, compression_level): + #"""Run smoke test on flac format""" + self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level) + +if __name__ == '__main__': + #test_func() + unittest.main() diff --git a/tests/unit/audio/backends/sox_io/sox_effect_test.py b/tests/unit/audio/backends/sox_io/sox_effect_test.py new file mode 100644 index 000000000..63c632ad1 --- /dev/null +++ b/tests/unit/audio/backends/sox_io/sox_effect_test.py @@ -0,0 +1,346 @@ +import io +import itertools +import tarfile +import unittest +from pathlib import Path +import numpy as np + +from parameterized import parameterized +from paddlespeech.audio import sox_effects +from paddlespeech.audio._internal import module_utils as _mod_utils +from tests.unit.common_utils import ( + get_sinusoid, + get_wav_data, + load_wav, + save_wav, + sox_utils, + TempDirMixin, + name_func, + load_effects_params +) + +if _mod_utils.is_module_available("requests"): + import requests + + +class TestSoxEffects(unittest.TestCase): + def test_init(self): + """Calling init_sox_effects multiple times does not crush""" + for _ in range(3): + sox_effects.init_sox_effects() + + +class TestSoxEffectsTensor(TempDirMixin, unittest.TestCase): + """Test suite for `apply_effects_tensor` function""" + + @parameterized.expand( + list(itertools.product(["float32", "int32"], [8000, 16000], [1, 2, 4, 8], [True, False])), + ) + def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): + """`apply_effects_tensor` without effects should return identical data as input""" + original = get_wav_data(dtype, num_channels, channels_first=channels_first) + expected = original.clone() + + found, output_sample_rate = sox_effects.apply_effects_tensor(expected, sample_rate, [], channels_first) + + assert (output_sample_rate == sample_rate) + # SoxEffect should not alter the input Tensor object + #self.assertEqual(original, expected) + np.testing.assert_array_almost_equal(original.numpy(), expected.numpy()) + + # SoxEffect should not return the same Tensor object + assert expected is not found + # Returned Tensor should equal to the input Tensor + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(expected.numpy(), found.numpy()) + + @parameterized.expand( + load_effects_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects(self, args): + """`apply_effects_tensor` should return identical data as sox command""" + effects = args["effects"] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + output_sr = args.get("output_sample_rate") + + input_path = self.get_temp_path("input.wav") + reference_path = self.get_temp_path("reference.wav") + + original = get_sinusoid(frequency=800, sample_rate=input_sr, n_channels=num_channels, dtype="float32") + save_wav(input_path, original, input_sr) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_tensor(original, input_sr, effects) + + assert sr == expected_sr + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(expected.numpy(), found.numpy()) + + +class TestSoxEffectsFile(TempDirMixin, unittest.TestCase): + """Test suite for `apply_effects_file` function""" + + @parameterized.expand( + list( + itertools.product( + ["float32", "int32"], + [8000, 16000], + [1, 2, 4, 8], + [False, True], + ) + ), + #name_func=name_func, + ) + def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): + """`apply_effects_file` without effects should return identical data as input""" + path = self.get_temp_path("input.wav") + expected = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(path, expected, sample_rate, channels_first=channels_first) + + found, output_sample_rate = sox_effects.apply_effects_file( + path, [], normalize=False, channels_first=channels_first + ) + + assert output_sample_rate == sample_rate + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(expected.numpy(), found.numpy()) + + @parameterized.expand( + load_effects_params("sox_effect_test_args.jsonl"), + #name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_str(self, args): + """`apply_effects_file` should return identical data as sox command""" + dtype = "int32" + channels_first = True + effects = args["effects"] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + output_sr = args.get("output_sample_rate") + + input_path = self.get_temp_path("input.wav") + reference_path = self.get_temp_path("reference.wav") + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file(input_path, effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(expected.numpy(), found.numpy()) + + + def test_apply_effects_path(self): + """`apply_effects_file` should return identical data as sox command when file path is given as a Path Object""" + dtype = "int32" + channels_first = True + effects = [["hilbert"]] + num_channels = 2 + input_sr = 8000 + output_sr = 8000 + + input_path = self.get_temp_path("input.wav") + reference_path = self.get_temp_path("reference.wav") + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + Path(input_path), effects, normalize=False, channels_first=channels_first + ) + + assert sr == expected_sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(expected.numpy(), found.numpy()) + + +class TestFileFormats(TempDirMixin, unittest.TestCase): + """`apply_effects_file` gives the same result as sox on various file formats""" + + @parameterized.expand( + list( + itertools.product( + ["float32", "int32"], + [8000, 16000], + [1, 2], + ) + ), + #name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}', + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`apply_effects_file` works on various wav format""" + channels_first = True + effects = [["band", "300", "10"]] + + input_path = self.get_temp_path("input.wav") + reference_path = self.get_temp_path("reference.wav") + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, sample_rate, channels_first=channels_first) + sox_utils.run_sox_effect(input_path, reference_path, effects) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file(input_path, effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + #not support now + #@parameterized.expand( + #list( + #itertools.product( + #[8000, 16000], + #[1, 2], + #) + #), + ##name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}', + #) + #def test_flac(self, sample_rate, num_channels): + #"""`apply_effects_file` works on various flac format""" + #channels_first = True + #effects = [["band", "300", "10"]] + + #input_path = self.get_temp_path("input.flac") + #reference_path = self.get_temp_path("reference.wav") + #sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + #sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + #expected, expected_sr = load_wav(reference_path) + #found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first) + #save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first) + + #assert sr == expected_sr + ##self.assertEqual(found, expected) + #np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + #@parameterized.expand( + #list( + #itertools.product( + #[8000, 16000], + #[1, 2], + #) + #), + ##name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}', + #) + #def test_vorbis(self, sample_rate, num_channels): + #"""`apply_effects_file` works on various vorbis format""" + #channels_first = True + #effects = [["band", "300", "10"]] + + #input_path = self.get_temp_path("input.vorbis") + #reference_path = self.get_temp_path("reference.wav") + #sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + #sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + #expected, expected_sr = load_wav(reference_path) + #found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first) + #save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first) + + #assert sr == expected_sr + ##self.assertEqual(found, expected) + #np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + +#@skipIfNoExec("sox") +#@skipIfNoSox +class TestFileObject(TempDirMixin, unittest.TestCase): + @parameterized.expand( + [ + ("wav", None), + ] + ) + def test_fileobj(self, ext, compression): + """Applying effects via file object works""" + sample_rate = 16000 + channels_first = True + effects = [["band", "300", "10"]] + input_path = self.get_temp_path(f"input.{ext}") + reference_path = self.get_temp_path("reference.wav") + + #sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression) + data = get_wav_data("int32", 2, channels_first=channels_first) + save_wav(input_path, data, sample_rate, channels_first=channels_first) + + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + expected, expected_sr = load_wav(reference_path) + + with open(input_path, "rb") as fileobj: + found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first) + save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first) + assert sr == expected_sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + @parameterized.expand( + [ + ("wav", None), + ] + ) + def test_bytesio(self, ext, compression): + """Applying effects via BytesIO object works""" + sample_rate = 16000 + channels_first = True + effects = [["band", "300", "10"]] + input_path = self.get_temp_path(f"input.{ext}") + reference_path = self.get_temp_path("reference.wav") + + #sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression) + data = get_wav_data("int32", 2, channels_first=channels_first) + save_wav(input_path, data, sample_rate, channels_first=channels_first) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + expected, expected_sr = load_wav(reference_path) + + with open(input_path, "rb") as file_: + fileobj = io.BytesIO(file_.read()) + found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first) + save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first) + assert sr == expected_sr + #self.assertEqual(found, expected) + print("found") + print(found) + print("expected") + print(expected) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + @parameterized.expand( + [ + ("wav", None), + ] + ) + def test_tarfile(self, ext, compression): + """Applying effects to compressed audio via file-like file works""" + sample_rate = 16000 + channels_first = True + effects = [["band", "300", "10"]] + audio_file = f"input.{ext}" + + input_path = self.get_temp_path(audio_file) + reference_path = self.get_temp_path("reference.wav") + archive_path = self.get_temp_path("archive.tar.gz") + data = get_wav_data("int32", 2, channels_first=channels_first) + save_wav(input_path, data, sample_rate, channels_first=channels_first) + + # sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + expected, expected_sr = load_wav(reference_path) + + with tarfile.TarFile(archive_path, "w") as tarobj: + tarobj.add(input_path, arcname=audio_file) + with tarfile.TarFile(archive_path, "r") as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first) + save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first) + assert sr == expected_sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/common_utils/__init__.py b/tests/unit/common_utils/__init__.py index 722a9789f..7bc718f38 100644 --- a/tests/unit/common_utils/__init__.py +++ b/tests/unit/common_utils/__init__.py @@ -1,7 +1,9 @@ from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav -from .parameterized_utils import load_params, nested_params +from .parameterized_utils import nested_params +from .data_utils import get_sinusoid, load_params, load_effects_params from .case_utils import ( - TempDirMixin + TempDirMixin, + name_func ) __all__ = [ @@ -11,4 +13,7 @@ __all__ = [ "normalize_wav", "load_params", "nested_params", + "get_sinusoid", + "name_func", + "load_effects_params" ] diff --git a/tests/unit/common_utils/case_utils.py b/tests/unit/common_utils/case_utils.py index cee2f29c8..6f4326f56 100644 --- a/tests/unit/common_utils/case_utils.py +++ b/tests/unit/common_utils/case_utils.py @@ -14,6 +14,9 @@ from paddlespeech.audio._internal.module_utils import ( is_sox_available, ) +def name_func(func, _, params): + return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' + class TempDirMixin: """Mixin to provide easy access to temp dir""" diff --git a/tests/unit/common_utils/parameterized_utils.py b/tests/unit/common_utils/parameterized_utils.py index 95af65c84..46cef3127 100644 --- a/tests/unit/common_utils/parameterized_utils.py +++ b/tests/unit/common_utils/parameterized_utils.py @@ -1,15 +1,28 @@ import json from itertools import product +import os from parameterized import param, parameterized -def get_asset_path(*paths): - """Return full path of a test asset""" - return os.path.join(_TEST_DIR_PATH, "assets", *paths) - -def load_params(*paths): - with open(get_asset_path(*paths), "r") as file: - return [param(json.loads(line)) for line in file] +#def get_asset_path(*paths): + #"""Return full path of a test asset""" + #return os.path.join(_TEST_DIR_PATH, "assets", *paths) + +#def load_params(*paths): + #with open(get_asset_path(*paths), "r") as file: + #return [param(json.loads(line)) for line in file] + +#def load_effects_params(*paths): + #params = [] + #with open(get_asset_path(*paths), "r") as file: + #for line in file: + #data = json.loads(line) + #for effect in data["effects"]: + #for i, arg in enumerate(effect): + #if arg.startswith(""): + #effect[i] = arg.replace("", get_asset_path()) + #params.append(param(data)) + #return params def _name_func(func, _, params): strs = [] diff --git a/tests/unit/common_utils/wav_utils.py b/tests/unit/common_utils/wav_utils.py index dbdd453e0..25d0b1971 100644 --- a/tests/unit/common_utils/wav_utils.py +++ b/tests/unit/common_utils/wav_utils.py @@ -2,6 +2,7 @@ from typing import Optional import scipy.io.wavfile import paddle +import numpy as np def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor: if tensor.dtype == paddle.float32: @@ -52,8 +53,14 @@ def get_wav_data( # paddle linspace not support uint8, int8, int16 #if dtype == "uint8": # base = paddle.linspace(0, 255, num_frames, dtype=dtype_) + #dtype_np = getattr(np, dtype) + #base_np = np.linspace(0, 255, num_frames, dtype_np) + #base = paddle.to_tensor(base_np, dtype=dtype_) #elif dtype == "int8": # base = paddle.linspace(-128, 127, num_frames, dtype=dtype_) + #dtype_np = getattr(np, dtype) + #base_np = np.linspace(-128, 127, num_frames, dtype_np) + #base = paddle.to_tensor(base_np, dtype=dtype_) if dtype == "float32": base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_) elif dtype == "float64": @@ -62,6 +69,9 @@ def get_wav_data( base = paddle.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) #elif dtype == "int16": # base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_) + #dtype_np = getattr(np, dtype) + #base_np = np.linspace(-32768, 32767, num_frames, dtype_np) + #base = paddle.to_tensor(base_np, dtype=dtype_) else: raise NotImplementedError(f"Unsupported dtype {dtype}") data = base.tile([num_channels, 1])