From a55592f7c2199be2fb9e20af0a6254a5fcf32342 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Mon, 15 Aug 2022 19:57:29 +0800 Subject: [PATCH] modify info test --- paddlespeech/audio/backends/sox_io_backend.py | 2 +- paddlespeech/audio/src/pybind/sox/io.cpp | 10 +- paddlespeech/audio/src/pybind/sox/io.h | 6 +- tests/unit/audio/backends/sox_io/info_test.py | 300 ++++++++++++++++-- tests/unit/audio/backends/sox_io/testdata | 1 - 5 files changed, 289 insertions(+), 30 deletions(-) delete mode 120000 tests/unit/audio/backends/sox_io/testdata diff --git a/paddlespeech/audio/backends/sox_io_backend.py b/paddlespeech/audio/backends/sox_io_backend.py index 2037ad81d..fff9e2069 100644 --- a/paddlespeech/audio/backends/sox_io_backend.py +++ b/paddlespeech/audio/backends/sox_io_backend.py @@ -88,7 +88,7 @@ def save(filepath: str, ) @_mod_utils.requires_sox() -def info(filepath: str, format: Optional[str] = "") -> None: +def info(filepath: str, format: Optional[str] = None,) -> AudioMetaData: if hasattr(filepath, "read"): sinfo = paddleaudio.get_info_fileobj(filepath, format) if sinfo is not None: diff --git a/paddlespeech/audio/src/pybind/sox/io.cpp b/paddlespeech/audio/src/pybind/sox/io.cpp index 78b8af991..60f9222ab 100644 --- a/paddlespeech/audio/src/pybind/sox/io.cpp +++ b/paddlespeech/audio/src/pybind/sox/io.cpp @@ -13,13 +13,14 @@ using namespace paddleaudio::sox_utils; namespace paddleaudio { namespace sox_io { -auto get_info_file(const std::string &path, const std::string &format) +auto get_info_file(const std::string &path, + const tl::optional &format) -> std::tuple { SoxFormat sf( sox_open_read(path.data(), /*signal=*/nullptr, /*encoding=*/nullptr, - /*filetype=*/format.empty() ? nullptr : format.data())); + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); validate_input_file(sf, path); @@ -61,7 +62,8 @@ std::vector> get_effects( return effects; } -auto get_info_fileobj(py::object fileobj, const std::string &format) +auto get_info_fileobj(py::object fileobj, + const tl::optional &format) -> std::tuple { const auto capacity = [&]() { const auto bufsiz = get_buffer_size(); @@ -80,7 +82,7 @@ auto get_info_fileobj(py::object fileobj, const std::string &format) buf_size, /*signal=*/nullptr, /*encoding=*/nullptr, - /*filetype=*/format.empty() ? nullptr : format.data())); + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); // In case of streamed data, length can be 0 validate_input_memfile(sf); diff --git a/paddlespeech/audio/src/pybind/sox/io.h b/paddlespeech/audio/src/pybind/sox/io.h index 94ce18f22..3734bcb34 100644 --- a/paddlespeech/audio/src/pybind/sox/io.h +++ b/paddlespeech/audio/src/pybind/sox/io.h @@ -10,10 +10,12 @@ namespace py = pybind11; namespace paddleaudio { namespace sox_io { -auto get_info_file(const std::string &path, const std::string &format) +auto get_info_file(const std::string &path, + const tl::optional &format) -> std::tuple; -auto get_info_fileobj(py::object fileobj, const std::string &format) +auto get_info_fileobj(py::object fileobj, + const tl::optional &format) -> std::tuple; tl::optional> load_audio_fileobj( diff --git a/tests/unit/audio/backends/sox_io/info_test.py b/tests/unit/audio/backends/sox_io/info_test.py index ae18a29ef..06aa54d25 100644 --- a/tests/unit/audio/backends/sox_io/info_test.py +++ b/tests/unit/audio/backends/sox_io/info_test.py @@ -1,34 +1,290 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import unittest +import itertools +import tarfile +from contextlib import contextmanager import numpy as np import paddle +import os +import io +from parameterized import parameterized from paddlespeech.audio.backends import sox_io_backend -class TestInfo(unittest.TestCase): +from tests.unit.common_utils import ( + get_wav_data, + load_wav, + save_wav, + TempDirMixin, + sox_utils, + data_utils +) - def test_wav(self, dtype, sample_rate, num_channels, sample_size): - """check wav file correctly """ - path = 'testdata/test.wav' - info = sox_io_backend.get_info_file(path) +from common import get_encoding, get_bits_per_sample + +#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/info_test.py + +class TestInfo(TempDirMixin, unittest.TestCase): + @parameterized.expand( + list( + itertools.product( + ["float32", "int32",], + [8000, 16000], + [1, 2], + ) + ), + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` can check wav file correctly""" + duration = 1 + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) + assert info.encoding == get_encoding("wav", dtype) + + @parameterized.expand( + list( + itertools.product( + ["float32", "int32"], + [8000, 16000], + [4, 8, 16, 32], + ) + ), + ) + def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` can check wav file with channels more than 2 correctly""" + duration = 1 + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) + + def test_ulaw(self): + """`sox_io_backend.info` can check ulaw file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path("data.wav") + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="u-law", duration=duration + ) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 8 + assert info.encoding == "ULAW" + + def test_alaw(self): + """`sox_io_backend.info` can check alaw file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path("data.wav") + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="a-law", duration=duration + ) + info = sox_io_backend.info(path) assert info.sample_rate == sample_rate - assert info.num_frames == sample_size # duration*sample_rate + assert info.num_frames == sample_rate * duration assert info.num_channels == num_channels - assert info.bits_per_sample == get_bit_depth(dtype) - assert info.encoding == get_encoding('wav', dtype) - + assert info.bits_per_sample == 8 + assert info.encoding == "ALAW" + +#class TestInfoOpus(unittest.TestCase): + #@parameterized.expand( + #list( + #itertools.product( + #["96k"], + #[1, 2], + #[0, 5, 10], + #) + #), + #) + #def test_opus(self, bitrate, num_channels, compression_level): + #"""`sox_io_backend.info` can check opus file correcty""" + #path = data_utils.get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus") + #info = sox_io_backend.info(path) + #assert info.sample_rate == 48000 + #assert info.num_frames == 32768 + #assert info.num_channels == num_channels + #assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats + #assert info.encoding == "OPUS" + +class FileObjTestBase(TempDirMixin): + def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): + path = self.get_temp_path(f"test.{ext}") + bit_depth = sox_utils.get_bit_depth(dtype) + duration = num_frames / sample_rate + comment_file = self._gen_comment_file(comments) if comments else None + + sox_utils.gen_audio_file( + path, + sample_rate, + num_channels=num_channels, + encoding=sox_utils.get_encoding(dtype), + bit_depth=bit_depth, + duration=duration, + comment_file=comment_file, + ) + return path + + def _gen_comment_file(self, comments): + comment_path = self.get_temp_path("comment.txt") + with open(comment_path, "w") as file_: + file_.writelines(comments) + return comment_path + +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + +class TestFileObject(FileObjTestBase, unittest.TestCase): + def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): + path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) + format_ = ext if ext in ["mp3"] else None + with open(path, "rb") as fileobj: + return sox_io_backend.info(fileobj, format_) + + def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames): + path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) + format_ = ext if ext in ["mp3"] else None + with open(path, "rb") as file_: + fileobj = io.BytesIO(file_.read()) + return sox_io_backend.info(fileobj, format_) + + def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames): + audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) + audio_file = os.path.basename(audio_path) + archive_path = self.get_temp_path("archive.tar.gz") + with tarfile.TarFile(archive_path, "w") as tarobj: + tarobj.add(audio_path, arcname=audio_file) + format_ = ext if ext in ["mp3"] else None + with tarfile.TarFile(archive_path, "r") as tarobj: + fileobj = tarobj.extractfile(audio_file) + return sox_io_backend.info(fileobj, format_) + + @contextmanager + def _set_buffer_size(self, buffer_size): + try: + original_buffer_size = get_buffer_size() + set_buffer_size(buffer_size) + yield + finally: + set_buffer_size(original_buffer_size) + + @parameterized.expand( + [ + ("wav", "float32"), + ("wav", "int32"), + ("wav", "int16"), + ("wav", "uint8"), + ] + ) + def test_fileobj(self, ext, dtype): + """Querying audio via file object works""" + sample_rate = 16000 + num_frames = 3 * sample_rate + num_channels = 2 + sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand( + [ + ("wav", "float32"), + ("wav", "int32"), + ("wav", "int16"), + ("wav", "uint8"), + ] + ) + def test_bytesio(self, ext, dtype): + """Querying audio via ByteIO object works for small data""" + sample_rate = 16000 + num_frames = 3 * sample_rate + num_channels = 2 + sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand( + [ + ("wav", "float32"), + ("wav", "int32"), + ("wav", "int16"), + ("wav", "uint8"), + ] + ) + def test_bytesio_tiny(self, ext, dtype): + """Querying audio via ByteIO object works for small data""" + sample_rate = 8000 + num_frames = 4 + num_channels = 2 + sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand( + [ + ("wav", "float32"), + ("wav", "int32"), + ("wav", "int16"), + ("wav", "uint8"), + ("flac", "float32"), + ("vorbis", "float32"), + ("amb", "int16"), + ] + ) + def test_tarfile(self, ext, dtype): + """Querying compressed audio via file-like object works""" + sample_rate = 16000 + num_frames = 3.0 * sample_rate + num_channels = 2 + sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ["vorbis"] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/sox_io/testdata b/tests/unit/audio/backends/sox_io/testdata deleted file mode 120000 index 485a3dd63..000000000 --- a/tests/unit/audio/backends/sox_io/testdata +++ /dev/null @@ -1 +0,0 @@ -../../features/testdata \ No newline at end of file