diff --git a/paddlespeech/audio/backends/sox_io_backend.py b/paddlespeech/audio/backends/sox_io_backend.py index c75894181..beb6ddb9d 100644 --- a/paddlespeech/audio/backends/sox_io_backend.py +++ b/paddlespeech/audio/backends/sox_io_backend.py @@ -1,11 +1,11 @@ from pathlib import Path from typing import Callable -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union +import paddle from paddle import Tensor from .common import AudioMetaData +import os from paddlespeech.audio._internal import module_utils as _mod_utils from paddlespeech.audio import _paddleaudio as paddleaudio @@ -48,31 +48,53 @@ def load( normalize: bool = True, channels_first: bool = True, format: Optional[str]=None, ) -> Tuple[Tensor, int]: + if hasattr(filepath, "read"): + ret = paddleaudio.load_audio_fileobj( + filepath, frame_offset, num_frames, normalize, channels_first, format + ) + if ret is not None: + audio_tensor = paddle.to_tensor(ret[0]) + return (audio_tensor, ret[1]) + return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format) + filepath = os.fspath(filepath) ret = paddleaudio.sox_io_load_audio_file( filepath, frame_offset, num_frames, normalize, channels_first, format ) if ret is not None: - return ret + audio_tensor = paddle.to_tensor(ret[0]) + return (audio_tensor, ret[1]) return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format) @_mod_utils.requires_sox() -def save(filepath: str, - frame_offset: int = 0, - num_frames: int = -1, - normalize: bool = True, - channels_first: bool = True, - format: Optional[str] = None) -> Tuple[Tensor, int]: - ret = paddleaudio.sox_io_load_audio_file( - filepath, frame_offset, num_frames, normalize, channels_first, format +def save(filepath: str, + src: Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +): + src_arr = src.numpy() + if hasattr(filepath, "write"): + paddleaudio.save_audio_fileobj( + filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample + ) + return + filepath = os.fspath(filepath) + paddleaudio.sox_io_save_audio_file( + filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample ) - if ret is not None: - return ret - return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format) - @_mod_utils.requires_sox() def info(filepath: str, format: Optional[str]) -> None: + if hasattr(filepath, "read"): + sinfo = paddleaudio.get_info_fileojb(filepath, format) + if sinfo is not None: + return AudioMetaData(*sinfo) + return _fallback_info_fileobj(filepath, format) + filepath = os.fspath(filepath) sinfo = paddleaudio.get_info_file(filepath, format) if sinfo is not None: return AudioMetaData(*sinfo) diff --git a/paddlespeech/audio/src/pybind/pybind.cpp b/paddlespeech/audio/src/pybind/pybind.cpp index 776e43a7e..24cf0eb18 100644 --- a/paddlespeech/audio/src/pybind/pybind.cpp +++ b/paddlespeech/audio/src/pybind/pybind.cpp @@ -21,7 +21,7 @@ PYBIND11_MODULE(_paddleaudio, m) { &paddleaudio::sox_io::get_info_file, "Get metadata of audio file."); // support obj later - /*m.def("get_info_fileobj", + m.def("get_info_fileobj", &paddleaudio::sox_io::get_info_fileobj, "Get metadata of audio in file object."); m.def("load_audio_fileobj", @@ -30,7 +30,7 @@ PYBIND11_MODULE(_paddleaudio, m) { m.def("save_audio_fileobj", &paddleaudio::sox_io::save_audio_fileobj, "Save audio to file obj."); - */ + // sox io m.def("sox_io_get_info", &paddleaudio::sox_io::get_info_file); m.def( diff --git a/paddlespeech/audio/src/pybind/sox/effects_chain.cpp b/paddlespeech/audio/src/pybind/sox/effects_chain.cpp index 4ad90da36..15fc6d26e 100644 --- a/paddlespeech/audio/src/pybind/sox/effects_chain.cpp +++ b/paddlespeech/audio/src/pybind/sox/effects_chain.cpp @@ -1,5 +1,6 @@ #include - +#include +#include #include "paddlespeech/audio/src/pybind/sox/effects_chain.h" #include "paddlespeech/audio/src/pybind/sox/utils.h" @@ -42,6 +43,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { if (index + *osamp > num_samples) { *osamp = num_samples - index; } + // Ensure that it's a multiple of the number of channels *osamp -= *osamp % num_channels; @@ -49,52 +51,80 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { // refacor this module, chunk auto i_frame = index / num_channels; auto num_frames = *osamp / num_channels; - py::array chunk(tensor.dtype(), {num_frames*num_channels}); + + std::vector chunk(num_frames*num_channels); py::buffer_info ori_info = tensor.request(); - py::buffer_info info = chunk.request(); - char* ori_start_ptr = (char*)ori_info.ptr + index * chunk.itemsize() / sizeof(char); - std::memcpy(info.ptr, ori_start_ptr, chunk.nbytes()); - - py::dtype chunk_type = py::dtype("i"); // dtype int32 - py::array new_chunk = py::array(chunk_type, chunk.shape()); - py::buffer_info new_info = new_chunk.request(); - void* ptr = (void*) info.ptr; - int* new_ptr = (int*) new_info.ptr; + void* ptr = ori_info.ptr; // Convert to sox_sample_t (int32_t) - switch (chunk.dtype().num()) { + switch (tensor.dtype().num()) { //case c10::ScalarType::Float: { case 11: { + break; // Need to convert to 64-bit precision so that // values around INT32_MIN/MAX are handled correctly. - float* ptr_f = (float*)ptr; for (int idx = 0; idx < chunk.size(); ++idx) { - double elem = *ptr_f * 2147483648.; + int frame_idx = (idx + index) / num_channels; + int channels_idx = (idx + index) % num_channels; + double elem = 0; + if (priv->channels_first) { + elem = *(float*)tensor.data(channels_idx, frame_idx); + } else { + elem = *(float*)tensor.data(frame_idx, channels_idx); + } + elem = elem * 2147483648.; // *new_ptr = std::clamp(elem, INT32_MIN, INT32_MAX); if (elem > INT32_MAX) { - *new_ptr = INT32_MAX; + chunk[idx] = INT32_MAX; } else if (elem < INT32_MIN) { - *new_ptr = INT32_MIN; - } else { *new_ptr = elem; } + chunk[idx] = INT32_MIN; + } else { + chunk[idx] = elem; + } } break; } //case c10::ScalarType::Int: { case 5: { + for (int idx = 0; idx < chunk.size(); ++idx) { + int frame_idx = (idx + index) / num_channels; + int channels_idx = (idx + index) % num_channels; + int elem = 0; + if (priv->channels_first) { + elem = *(int*)tensor.data(channels_idx, frame_idx); + } else { + elem = *(int*)tensor.data(frame_idx, channels_idx); + } + chunk[idx] = elem; + } break; } // case short case 3: { - int16_t* ptr_s = (int16_t*) ptr; for (int idx = 0; idx < chunk.size(); ++idx) { - *new_ptr = *ptr_s * 65536; + int frame_idx = (idx + index) / num_channels; + int channels_idx = (idx + index) % num_channels; + int16_t elem = 0; + if (priv->channels_first) { + elem = *(int16_t*)tensor.data(channels_idx, frame_idx); + } else { + elem = *(int16_t*)tensor.data(frame_idx, channels_idx); + } + chunk[idx] = elem * 65536; } break; } // case byte case 1: { - int8_t* ptr_b = (int8_t*) ptr; for (int idx = 0; idx < chunk.size(); ++idx) { - *new_ptr = (*ptr_b - 128) * 16777216; + int frame_idx = (idx + index) / num_channels; + int channels_idx = (idx + index) % num_channels; + int8_t elem = 0; + if (priv->channels_first) { + elem = *(int8_t*)tensor.data(channels_idx, frame_idx); + } else { + elem = *(int8_t*)tensor.data(frame_idx, channels_idx); + } + chunk[idx] = (elem - 128) * 16777216; } break; } @@ -102,7 +132,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { throw std::runtime_error("Unexpected dtype."); } // Write to buffer - memcpy(obuf, (int*)new_info.ptr, *osamp * 4); + memcpy(obuf, chunk.data(), *osamp * 4); priv->index += *osamp; return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; } diff --git a/tests/unit/audio/backends/sox_io/save_test.py b/tests/unit/audio/backends/sox_io/save_test.py index ae18a29ef..269c502a3 100644 --- a/tests/unit/audio/backends/sox_io/save_test.py +++ b/tests/unit/audio/backends/sox_io/save_test.py @@ -1,34 +1,177 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import io +import os import unittest import numpy as np import paddle - +from parameterized import parameterized from paddlespeech.audio.backends import sox_io_backend -class TestInfo(unittest.TestCase): - - def test_wav(self, dtype, sample_rate, num_channels, sample_size): - """check wav file correctly """ - path = 'testdata/test.wav' - info = sox_io_backend.get_info_file(path) - assert info.sample_rate == sample_rate - assert info.num_frames == sample_size # duration*sample_rate - assert info.num_channels == num_channels - assert info.bits_per_sample == get_bit_depth(dtype) - assert info.encoding == get_encoding('wav', dtype) - +from tests.unit.common_utils import ( + get_wav_data, + load_wav, + save_wav, + nested_params, + TempDirMixin, + sox_utils +) + +#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/save_test.py + +def _get_sox_encoding(encoding): + encodings = { + "PCM_F": "floating-point", + "PCM_S": "signed-integer", + "PCM_U": "unsigned-integer", + "ULAW": "u-law", + "ALAW": "a-law", + } + return encodings.get(encoding) + +class TestSaveBase(TempDirMixin): + def assert_save_consistency( + self, + format: str, + *, + compression: float = None, + encoding: str = None, + bits_per_sample: int = None, + sample_rate: float = 8000, + num_channels: int = 2, + num_frames: float = 3 * 8000, + src_dtype: str = "int32", + test_mode: str = "path", + ): + """`save` function produces file that is comparable with `sox` command + + To compare that the file produced by `save` function agains the file produced by + the equivalent `sox` command, we need to load both files. + But there are many formats that cannot be opened with common Python modules (like + SciPy). + So we use `sox` command to prepare the original data and convert the saved files + into a format that SciPy can read (PCM wav). + The following diagram illustrates this process. The difference is 2.1. and 3.1. + + This assumes that + - loading data with SciPy preserves the data well. + - converting the resulting files into WAV format with `sox` preserve the data well. + + x + | 1. Generate source wav file with SciPy + | + v + -------------- wav ---------------- + | | + | 2.1. load with scipy | 3.1. Convert to the target + | then save it into the target | format depth with sox + | format with torchaudio | + v v + target format target format + | | + | 2.2. Convert to wav with sox | 3.2. Convert to wav with sox + | | + v v + wav wav + | | + | 2.3. load with scipy | 3.3. load with scipy + | | + v v + tensor -------> compare <--------- tensor + + """ + cmp_encoding = "floating-point" + cmp_bit_depth = 32 + + src_path = self.get_temp_path("1.source.wav") + tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}") + tst_path = self.get_temp_path("2.2.result.wav") + sox_path = self.get_temp_path(f"3.1.sox.{format}") + ref_path = self.get_temp_path("3.2.ref.wav") + + # 1. Generate original wav + data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames) + save_wav(src_path, data, sample_rate) + + # 2.1. Convert the original wav to target format with torchaudio + data = load_wav(src_path, normalize=False)[0] + if test_mode == "path": + sox_io_backend.save( + tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample + ) + elif test_mode == "fileobj": + with open(tgt_path, "bw") as file_: + sox_io_backend.save( + file_, + data, + sample_rate, + format=format, + compression=compression, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) + elif test_mode == "bytesio": + file_ = io.BytesIO() + sox_io_backend.save( + file_, + data, + sample_rate, + format=format, + compression=compression, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) + file_.seek(0) + with open(tgt_path, "bw") as f: + f.write(file_.read()) + else: + raise ValueError(f"Unexpected test mode: {test_mode}") + # 2.2. Convert the target format to wav with sox + sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + # 2.3. Load with SciPy + found = load_wav(tst_path, normalize=False)[0] + + # 3.1. Convert the original wav to target format with sox + sox_encoding = _get_sox_encoding(encoding) + sox_utils.convert_audio_file( + src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample + ) + # 3.2. Convert the target format to wav with sox + sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + # 3.3. Load with SciPy + expected = load_wav(ref_path, normalize=False)[0] + + np.testing.assert_array_almost_equal(found, expected) + +class TestSave(TestSaveBase, unittest.TestCase): + @nested_params( + ["path",], + [ + ("PCM_U", 8), + ("PCM_S", 16), + ("PCM_S", 32), + ("PCM_F", 32), + ("PCM_F", 64), + ("ULAW", 8), + ("ALAW", 8), + ], + ) + def test_save_wav(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", ], + [ + ("float32",), + ("int32",), + ("int16",), + ("uint8",), + ], + ) + def test_save_wav_dtype(self, test_mode, params): + (dtype,) = params + self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode) + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/unit/common_utils/__init__.py b/tests/unit/common_utils/__init__.py index dae409f3c..722a9789f 100644 --- a/tests/unit/common_utils/__init__.py +++ b/tests/unit/common_utils/__init__.py @@ -1,8 +1,14 @@ from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav +from .parameterized_utils import load_params, nested_params +from .case_utils import ( + TempDirMixin +) __all__ = [ "get_wav_data", "load_wav", "save_wav", - "normalize_wav" + "normalize_wav", + "load_params", + "nested_params", ] diff --git a/tests/unit/common_utils/case_utils.py b/tests/unit/common_utils/case_utils.py new file mode 100644 index 000000000..cee2f29c8 --- /dev/null +++ b/tests/unit/common_utils/case_utils.py @@ -0,0 +1,56 @@ +import functools +import os.path +import shutil +import subprocess +import sys +import tempfile +import time +import unittest + +import paddle +from paddlespeech.audio._internal.module_utils import ( + is_kaldi_available, + is_module_available, + is_sox_available, +) + +class TempDirMixin: + """Mixin to provide easy access to temp dir""" + + temp_dir_ = None + + @classmethod + def get_base_temp_dir(cls): + # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. + # this is handy for debugging. + key = "TORCHAUDIO_TEST_TEMP_DIR" + if key in os.environ: + return os.environ[key] + if cls.temp_dir_ is None: + cls.temp_dir_ = tempfile.TemporaryDirectory() + return cls.temp_dir_.name + + @classmethod + def tearDownClass(cls): + if cls.temp_dir_ is not None: + try: + cls.temp_dir_.cleanup() + cls.temp_dir_ = None + except PermissionError: + # On Windows there is a know issue with `shutil.rmtree`, + # which fails intermittenly. + # + # https://github.com/python/cpython/issues/74168 + # + # We observed this on CircleCI, where Windows job raises + # PermissionError. + # + # Following the above thread, we ignore it. + pass + super().tearDownClass() + + def get_temp_path(self, *paths): + temp_dir = os.path.join(self.get_base_temp_dir(), self.id()) + path = os.path.join(temp_dir, *paths) + os.makedirs(os.path.dirname(path), exist_ok=True) + return path diff --git a/tests/unit/common_utils/parameterized_utils.py b/tests/unit/common_utils/parameterized_utils.py new file mode 100644 index 000000000..95af65c84 --- /dev/null +++ b/tests/unit/common_utils/parameterized_utils.py @@ -0,0 +1,50 @@ +import json +from itertools import product + +from parameterized import param, parameterized + +def get_asset_path(*paths): + """Return full path of a test asset""" + return os.path.join(_TEST_DIR_PATH, "assets", *paths) + +def load_params(*paths): + with open(get_asset_path(*paths), "r") as file: + return [param(json.loads(line)) for line in file] + +def _name_func(func, _, params): + strs = [] + for arg in params.args: + if isinstance(arg, tuple): + strs.append("_".join(str(a) for a in arg)) + else: + strs.append(str(arg)) + # sanitize the test name + name = "_".join(strs) + return parameterized.to_safe_name(f"{func.__name__}_{name}") + + +def nested_params(*params_set, name_func=_name_func): + """Generate the cartesian product of the given list of parameters. + + Args: + params_set (list of parameters): Parameters. When using ``parameterized.param`` class, + all the parameters have to be specified with the class, only using kwargs. + """ + flatten = [p for params in params_set for p in params] + + # Parameters to be nested are given as list of plain objects + if all(not isinstance(p, param) for p in flatten): + args = list(product(*params_set)) + return parameterized.expand(args, name_func=_name_func) + + # Parameters to be nested are given as list of `parameterized.param` + if not all(isinstance(p, param) for p in flatten): + raise TypeError("When using ``parameterized.param``, " "all the parameters have to be of the ``param`` type.") + if any(p.args for p in flatten): + raise ValueError( + "When using ``parameterized.param``, " "all the parameters have to be provided as keyword argument." + ) + args = [param()] + for params in params_set: + args = [param(**x.kwargs, **y.kwargs) for x in args for y in params] + return parameterized.expand(args) diff --git a/tests/unit/common_utils/sox_utils.py b/tests/unit/common_utils/sox_utils.py new file mode 100644 index 000000000..6ceae081e --- /dev/null +++ b/tests/unit/common_utils/sox_utils.py @@ -0,0 +1,116 @@ +import subprocess +import sys +import warnings + + +def get_encoding(dtype): + encodings = { + "float32": "floating-point", + "int32": "signed-integer", + "int16": "signed-integer", + "uint8": "unsigned-integer", + } + return encodings[dtype] + + +def get_bit_depth(dtype): + bit_depths = { + "float32": 32, + "int32": 32, + "int16": 16, + "uint8": 8, + } + return bit_depths[dtype] + + +def gen_audio_file( + path, + sample_rate, + num_channels, + *, + encoding=None, + bit_depth=None, + compression=None, + attenuation=None, + duration=1, + comment_file=None, +): + """Generate synthetic audio file with `sox` command.""" + if path.endswith(".wav"): + warnings.warn("Use get_wav_data and save_wav to generate wav file for accurate result.") + command = [ + "sox", + "-V3", # verbose + "--no-dither", # disable automatic dithering + "-R", + # -R is supposed to be repeatable, though the implementation looks suspicious + # and not setting the seed to a fixed value. + # https://fossies.org/dox/sox-14.4.2/sox_8c_source.html + # search "sox_globals.repeatable" + ] + if bit_depth is not None: + command += ["--bits", str(bit_depth)] + command += [ + "--rate", + str(sample_rate), + "--null", # no input + "--channels", + str(num_channels), + ] + if compression is not None: + command += ["--compression", str(compression)] + if bit_depth is not None: + command += ["--bits", str(bit_depth)] + if encoding is not None: + command += ["--encoding", str(encoding)] + if comment_file is not None: + command += ["--comment-file", str(comment_file)] + command += [ + str(path), + "synth", + str(duration), # synthesizes for the given duration [sec] + "sawtooth", + "1", + # saw tooth covers the both ends of value range, which is a good property for test. + # similar to linspace(-1., 1.) + # this introduces bigger boundary effect than sine when converted to mp3 + ] + if attenuation is not None: + command += ["vol", f"-{attenuation}dB"] + print(" ".join(command), file=sys.stderr) + subprocess.run(command, check=True) + + +def convert_audio_file(src_path, dst_path, *, encoding=None, bit_depth=None, compression=None): + """Convert audio file with `sox` command.""" + command = ["sox", "-V3", "--no-dither", "-R", str(src_path)] + if encoding is not None: + command += ["--encoding", str(encoding)] + if bit_depth is not None: + command += ["--bits", str(bit_depth)] + if compression is not None: + command += ["--compression", str(compression)] + command += [dst_path] + print(" ".join(command), file=sys.stderr) + subprocess.run(command, check=True) + + +def _flattern(effects): + if not effects: + return effects + if isinstance(effects[0], str): + return effects + return [item for sublist in effects for item in sublist] + + +def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None): + """Run sox effects""" + effect = _flattern(effect) + command = ["sox", "-V", "--no-dither", input_file] + if output_bitdepth: + command += ["--bits", str(output_bitdepth)] + command += [output_file] + effect + if output_sample_rate: + command += ["rate", str(output_sample_rate)] + print(" ".join(command)) + subprocess.run(command, check=True)