From 59d82c0c65566477f6adc2c324a8ee0ce59cf853 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Tue, 2 Aug 2022 21:44:23 +0800 Subject: [PATCH] add test_load.py --- paddlespeech/audio/_extension.py | 1 - paddlespeech/audio/backends/sox_io_backend.py | 2 +- paddlespeech/audio/src/pybind/sox/utils.cpp | 44 +++++++-- tests/unit/audio/backends/sox_io/common.py | 32 +++++++ tests/unit/audio/backends/sox_io/info_test.py | 34 +++++++ tests/unit/audio/backends/sox_io/load_test.py | 47 ++++++++++ tests/unit/audio/backends/sox_io/save_test.py | 34 +++++++ tests/unit/audio/backends/sox_io/testdata | 1 + tests/unit/common_utils/__init__.py | 8 ++ tests/unit/common_utils/wav_utils.py | 92 +++++++++++++++++++ 10 files changed, 287 insertions(+), 8 deletions(-) create mode 100644 tests/unit/audio/backends/sox_io/common.py create mode 100644 tests/unit/audio/backends/sox_io/info_test.py create mode 100644 tests/unit/audio/backends/sox_io/load_test.py create mode 100644 tests/unit/audio/backends/sox_io/save_test.py create mode 120000 tests/unit/audio/backends/sox_io/testdata create mode 100644 tests/unit/common_utils/__init__.py create mode 100644 tests/unit/common_utils/wav_utils.py diff --git a/paddlespeech/audio/_extension.py b/paddlespeech/audio/_extension.py index 000fae131..ac82c06e5 100644 --- a/paddlespeech/audio/_extension.py +++ b/paddlespeech/audio/_extension.py @@ -103,7 +103,6 @@ def _load_lib(lib: str) -> bool: If a dependency is missing, then users have to install it. """ path = _get_lib_path(lib) - warnings.warn("lib path is :" + str(path)) if not path.exists(): warnings.warn("lib path is not exists:" + str(path)) return False diff --git a/paddlespeech/audio/backends/sox_io_backend.py b/paddlespeech/audio/backends/sox_io_backend.py index 750d4de1a..c75894181 100644 --- a/paddlespeech/audio/backends/sox_io_backend.py +++ b/paddlespeech/audio/backends/sox_io_backend.py @@ -8,7 +8,7 @@ from paddle import Tensor from .common import AudioMetaData from paddlespeech.audio._internal import module_utils as _mod_utils -from paddlespeech.aduio import _paddleaudio as paddleaudio +from paddlespeech.audio import _paddleaudio as paddleaudio #https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py diff --git a/paddlespeech/audio/src/pybind/sox/utils.cpp b/paddlespeech/audio/src/pybind/sox/utils.cpp index a930f8cdd..5c78bc116 100644 --- a/paddlespeech/audio/src/pybind/sox/utils.cpp +++ b/paddlespeech/audio/src/pybind/sox/utils.cpp @@ -178,36 +178,68 @@ py::array convert_to_tensor( const py::dtype dtype, const bool normalize, const bool channels_first) { + // todo refector later(SGoat) py::array t; uint64_t dummy = 0; SOX_SAMPLE_LOCALS; + int32_t num_rows = num_samples / num_channels; if (normalize || dtype.char_() == 'f') { - t = py::array(dtype, {num_samples / num_channels, num_channels}); + t = py::array(dtype, {num_rows, num_channels}); auto ptr = (float*)t.mutable_data(0, 0); for (int32_t i = 0; i < num_samples; ++i) { ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy); } + if (channels_first) { + py::array t2 = py::array(dtype, {num_channels, num_rows}); + for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) { + for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx) + *(float*)t2.mutable_data(row_idx, col_idx) = *(float*)t.data(col_idx, row_idx); + } + return t2; + } } else if (dtype.char_() == 'i') { - //t = torch::from_blob( - // buffer, {num_samples / num_channels, num_channels}, torch::kInt32) - // .clone(); - t = py::array(dtype, {num_samples / num_channels, num_channels}); + t = py::array(dtype, {num_rows, num_channels}); auto ptr = (int*)t.mutable_data(0, 0); for (int32_t i = 0; i < num_samples; ++i) { ptr[i] = buffer[i]; } + if (channels_first) { + py::array t2 = py::array(dtype, {num_channels, num_rows}); + for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) { + for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx) + *(int*)t2.mutable_data(row_idx, col_idx) = *(int*)t.data(col_idx, row_idx); + } + return t2; + } } else if (dtype.char_() == 'h') { // int16 - t = py::array(dtype, {num_samples / num_channels, num_channels}); + t = py::array(dtype, {num_rows, num_channels}); auto ptr = (int16_t*)t.mutable_data(0, 0); for (int32_t i = 0; i < num_samples; ++i) { ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy); } + if (channels_first) { + py::array t2 = py::array(dtype, {num_channels, num_rows}); + for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) { + for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx) + *(int16_t*)t2.mutable_data(row_idx, col_idx) = *(int16_t*)t.data(col_idx, row_idx); + } + return t2; + } } else if (dtype.char_() == 'b') { //t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8); + t = py::array(dtype, {num_rows, num_channels}); auto ptr = (uint8_t*)t.mutable_data(0,0); for (int32_t i = 0; i < num_samples; ++i) { ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy); } + if (channels_first) { + py::array t2 = py::array(dtype, {num_channels, num_rows}); + for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) { + for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx) + *(uint8_t*)t2.mutable_data(row_idx, col_idx) = *(uint8_t*)t.data(col_idx, row_idx); + } + return t2; + } } else { throw std::runtime_error("Unsupported dtype."); } diff --git a/tests/unit/audio/backends/sox_io/common.py b/tests/unit/audio/backends/sox_io/common.py new file mode 100644 index 000000000..79b922a91 --- /dev/null +++ b/tests/unit/audio/backends/sox_io/common.py @@ -0,0 +1,32 @@ + +def get_encoding(ext, dtype): + exts = { + "mp3", + "flac", + "vorbis", + } + encodings = { + "float32": "PCM_F", + "int32": "PCM_S", + "int16": "PCM_S", + "uint8": "PCM_U", + } + return ext.upper() if ext in exts else encodings[dtype] + + +def get_bit_depth(dtype): + bit_depths = { + "float32": 32, + "int32": 32, + "int16": 16, + "uint8": 8, + } + return bit_depths[dtype] + +def get_bits_per_sample(ext, dtype): + bits_per_samples = { + "flac": 24, + "mp3": 0, + "vorbis": 0, + } + return bits_per_samples.get(ext, get_bit_depth(dtype)) diff --git a/tests/unit/audio/backends/sox_io/info_test.py b/tests/unit/audio/backends/sox_io/info_test.py new file mode 100644 index 000000000..ae18a29ef --- /dev/null +++ b/tests/unit/audio/backends/sox_io/info_test.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle + +from paddlespeech.audio.backends import sox_io_backend + +class TestInfo(unittest.TestCase): + + def test_wav(self, dtype, sample_rate, num_channels, sample_size): + """check wav file correctly """ + path = 'testdata/test.wav' + info = sox_io_backend.get_info_file(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_size # duration*sample_rate + assert info.num_channels == num_channels + assert info.bits_per_sample == get_bit_depth(dtype) + assert info.encoding == get_encoding('wav', dtype) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/sox_io/load_test.py b/tests/unit/audio/backends/sox_io/load_test.py new file mode 100644 index 000000000..8e141750b --- /dev/null +++ b/tests/unit/audio/backends/sox_io/load_test.py @@ -0,0 +1,47 @@ +import unittest +import itertools + +from parameterized import parameterized +import numpy as np +from paddlespeech.audio._internal import module_utils as _mod_utils +from paddlespeech.audio.backends import sox_io_backend + +from tests.unit.common_utils import ( + get_wav_data, + load_wav, + save_wav, +) + +#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/load_test.py + +class TestLoad(unittest.TestCase): + + def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): + """`sox_io_backend.load` can load wav format correctly. + + Wav data loaded with sox_io backend should match those with scipy + """ + path = 'testdata/reference.wav' + data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + expected = load_wav(path, normalize=normalize)[0] + data, sr = sox_io_backend.load(path, normalize=normalize) + assert sr == sample_rate + np.testing.assert_array_almost_equal(data, expected, decimal=4) + + @parameterized.expand( + list( + itertools.product( + ["float64", "float32", "int32",], + [8000, 16000], + [1, 2], + [False, True], + ) + ), + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load wav format correctly.""" + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/sox_io/save_test.py b/tests/unit/audio/backends/sox_io/save_test.py new file mode 100644 index 000000000..ae18a29ef --- /dev/null +++ b/tests/unit/audio/backends/sox_io/save_test.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle + +from paddlespeech.audio.backends import sox_io_backend + +class TestInfo(unittest.TestCase): + + def test_wav(self, dtype, sample_rate, num_channels, sample_size): + """check wav file correctly """ + path = 'testdata/test.wav' + info = sox_io_backend.get_info_file(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_size # duration*sample_rate + assert info.num_channels == num_channels + assert info.bits_per_sample == get_bit_depth(dtype) + assert info.encoding == get_encoding('wav', dtype) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/sox_io/testdata b/tests/unit/audio/backends/sox_io/testdata new file mode 120000 index 000000000..485a3dd63 --- /dev/null +++ b/tests/unit/audio/backends/sox_io/testdata @@ -0,0 +1 @@ +../../features/testdata \ No newline at end of file diff --git a/tests/unit/common_utils/__init__.py b/tests/unit/common_utils/__init__.py new file mode 100644 index 000000000..dae409f3c --- /dev/null +++ b/tests/unit/common_utils/__init__.py @@ -0,0 +1,8 @@ +from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav + +__all__ = [ + "get_wav_data", + "load_wav", + "save_wav", + "normalize_wav" +] diff --git a/tests/unit/common_utils/wav_utils.py b/tests/unit/common_utils/wav_utils.py new file mode 100644 index 000000000..dbdd453e0 --- /dev/null +++ b/tests/unit/common_utils/wav_utils.py @@ -0,0 +1,92 @@ +from typing import Optional + +import scipy.io.wavfile +import paddle + +def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor: + if tensor.dtype == paddle.float32: + pass + elif tensor.dtype == paddle.int32: + tensor = paddle.cast(tensor, paddle.float32) + tensor[tensor > 0] /= 2147483647.0 + tensor[tensor < 0] /= 2147483648.0 + elif tensor.dtype == paddle.int16: + tensor = paddle.cast(tensor, paddle.float32) + tensor[tensor > 0] /= 32767.0 + tensor[tensor < 0] /= 32768.0 + elif tensor.dtype == paddle.uint8: + tensor = paddle.cast(tensor, paddle.float32) - 128 + tensor[tensor > 0] /= 127.0 + tensor[tensor < 0] /= 128.0 + return tensor + + +def get_wav_data( + dtype: str, + num_channels: int, + *, + num_frames: Optional[int] = None, + normalize: bool = True, + channels_first: bool = True, +): + """Generate linear signal of the given dtype and num_channels + + Data range is + [-1.0, 1.0] for float32, + [-2147483648, 2147483647] for int32 + [-32768, 32767] for int16 + [0, 255] for uint8 + + num_frames allow to change the linear interpolation parameter. + Default values are 256 for uint8, else 1 << 16. + 1 << 16 as default is so that int16 value range is completely covered. + """ + dtype_ = getattr(paddle, dtype) + + if num_frames is None: + if dtype == "uint8": + num_frames = 256 + else: + num_frames = 1 << 16 + + # paddle linspace not support uint8, int8, int16 + #if dtype == "uint8": + # base = paddle.linspace(0, 255, num_frames, dtype=dtype_) + #elif dtype == "int8": + # base = paddle.linspace(-128, 127, num_frames, dtype=dtype_) + if dtype == "float32": + base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_) + elif dtype == "float64": + base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_) + elif dtype == "int32": + base = paddle.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) + #elif dtype == "int16": + # base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_) + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + data = base.tile([num_channels, 1]) + if not channels_first: + data = data.transpose([1, 0]) + if normalize: + data = normalize_wav(data) + return data + + +def load_wav(path: str, normalize=True, channels_first=True) -> paddle.Tensor: + """Load wav file without paddleaudio""" + sample_rate, data = scipy.io.wavfile.read(path) + data = paddle.to_tensor(data.copy()) + if data.ndim == 1: + data = data.unsqueeze(1) + if normalize: + data = normalize_wav(data) + if channels_first: + data = data.transpose([1, 0]) + return data, sample_rate + + +def save_wav(path, data, sample_rate, channels_first=True): + """Save wav file without paddleaudio""" + if channels_first: + data = data.transpose([1, 0]) + scipy.io.wavfile.write(path, sample_rate, data.numpy())