"""Tests for ``sox_io_backend.info`` on file paths and file-like objects."""
import io
import itertools
import os
import platform
import tarfile
import unittest
from contextlib import contextmanager

if platform.system() == "Windows":
    import sys
    import warnings

    warnings.warn("sox io not support in Windows, please skip test.")
    # Use `sys.exit()` instead of the `site`-injected `exit()` helper: the
    # latter is meant for the interactive shell and is missing under
    # `python -S`. Both raise `SystemExit` with no exit code.
    sys.exit()

# These imports intentionally come after the Windows guard above.
from parameterized import parameterized
from common import get_bits_per_sample, get_encoding
from paddleaudio.backends import sox_io_backend

from common_utils import (
    get_wav_data,
    save_wav,
    TempDirMixin,
    sox_utils,
)

# Code is adapted from:
# https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/info_test.py


class TestInfo(TempDirMixin, unittest.TestCase):
    """Check `sox_io_backend.info` against audio files addressed by path."""

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32"],
                [8000, 16000],
                [1, 2],
            )
        ),
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file correctly"""
        duration = 1
        path = self.get_temp_path("data.wav")
        data = get_wav_data(
            dtype,
            num_channels,
            normalize=False,
            num_frames=duration * sample_rate,
        )
        save_wav(path, data, sample_rate)
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
        assert info.encoding == get_encoding("wav", dtype)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32"],
                [8000, 16000],
                [4, 8, 16, 32],
            )
        ),
    )
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
        duration = 1
        path = self.get_temp_path("data.wav")
        data = get_wav_data(
            dtype,
            num_channels,
            normalize=False,
            num_frames=duration * sample_rate,
        )
        save_wav(path, data, sample_rate)
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)

    def test_ulaw(self):
        """`sox_io_backend.info` can check ulaw file correctly"""
        duration = 1
        num_channels = 1
        sample_rate = 8000
        path = self.get_temp_path("data.wav")
        sox_utils.gen_audio_file(
            path,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_depth=8,
            encoding="u-law",
            duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == 8
        assert info.encoding == "ULAW"

    def test_alaw(self):
        """`sox_io_backend.info` can check alaw file correctly"""
        duration = 1
        num_channels = 1
        sample_rate = 8000
        path = self.get_temp_path("data.wav")
        sox_utils.gen_audio_file(
            path,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_depth=8,
            encoding="a-law",
            duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == 8
        assert info.encoding == "ALAW"


# NOTE(review): the Opus test below is disabled. It references
# `data_utils.get_asset_path`, and `data_utils` is not imported in this file.
#
# class TestInfoOpus(unittest.TestCase):
#     @parameterized.expand(
#         list(
#             itertools.product(
#                 ["96k"],
#                 [1, 2],
#                 [0, 5, 10],
#             )
#         ),
#     )
#     def test_opus(self, bitrate, num_channels, compression_level):
#         """`sox_io_backend.info` can check opus file correctly"""
#         path = data_utils.get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
#         info = sox_io_backend.info(path)
#         assert info.sample_rate == 48000
#         assert info.num_frames == 32768
#         assert info.num_channels == num_channels
#         assert info.bits_per_sample == 0  # bit_per_sample is irrelevant for compressed formats
#         assert info.encoding == "OPUS"


class FileObjTestBase(TempDirMixin):
    """Shared helpers that generate audio fixtures through `sox_utils`."""

    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate ``test.<ext>`` in the temp directory and return its path.

        The duration handed to `sox_utils.gen_audio_file` is
        ``num_frames / sample_rate``. When ``comments`` is truthy it is first
        written to a comment file that is passed along as ``comment_file``.
        """
        path = self.get_temp_path(f"test.{ext}")
        bit_depth = sox_utils.get_bit_depth(dtype)
        duration = num_frames / sample_rate
        comment_file = self._gen_comment_file(comments) if comments else None

        sox_utils.gen_audio_file(
            path,
            sample_rate,
            num_channels=num_channels,
            encoding=sox_utils.get_encoding(dtype),
            bit_depth=bit_depth,
            duration=duration,
            comment_file=comment_file,
        )
        return path

    def _gen_comment_file(self, comments):
        """Write ``comments`` to ``comment.txt`` in the temp directory; return its path.

        ``writelines`` adds no separators, so entries are written back to back
        exactly as given.
        """
        comment_path = self.get_temp_path("comment.txt")
        with open(comment_path, "w") as file_:
            file_.writelines(comments)
        return comment_path


class Unseekable:
    """File-object wrapper that only forwards ``read`` to the wrapped object.

    Not referenced elsewhere in the visible part of this file.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, n):
        """Read up to ``n`` bytes from the wrapped file object."""
        return self.fileobj.read(n)


class TestFileObject(FileObjTestBase, unittest.TestCase):
    """Check `sox_io_backend.info` against file-like objects."""

    @staticmethod
    def _format_hint(ext):
        """Return the explicit format passed to `info`: ``ext`` for mp3, else ``None``."""
        return ext if ext in ["mp3"] else None

    def _assert_info(self, sinfo, ext, dtype, sample_rate, num_channels, num_frames):
        """Assert that ``sinfo`` matches the parameters the fixture was generated with.

        For ``mp3`` and ``vorbis`` the expected frame count is 0 -- presumably
        the count is not available for these formats when read from a file
        object; TODO confirm against the backend.
        """
        bits_per_sample = get_bits_per_sample(ext, dtype)
        expected_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        assert sinfo.num_frames == expected_frames
        assert sinfo.bits_per_sample == bits_per_sample
        assert sinfo.encoding == get_encoding(ext, dtype)

    def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate a fixture and query it through a regular binary file handle."""
        path = self._gen_file(
            ext, dtype, sample_rate, num_channels, num_frames, comments=comments
        )
        format_ = self._format_hint(ext)
        with open(path, "rb") as fileobj:
            return sox_io_backend.info(fileobj, format_)

    def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a fixture, load it fully into memory and query the `io.BytesIO`."""
        path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        format_ = self._format_hint(ext)
        with open(path, "rb") as file_:
            fileobj = io.BytesIO(file_.read())
        return sox_io_backend.info(fileobj, format_)

    def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a fixture, pack it into a tar archive and query the extracted member."""
        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        audio_file = os.path.basename(audio_path)
        # Despite the ".tar.gz" name, `tarfile.TarFile` in plain "w" mode
        # writes an uncompressed archive; it is read back the same way below.
        archive_path = self.get_temp_path("archive.tar.gz")
        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        format_ = self._format_hint(ext)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            # Query while the archive is still open.
            return sox_io_backend.info(fileobj, format_)

    @contextmanager
    def _set_buffer_size(self, buffer_size):
        """Temporarily switch the buffer size, restoring the previous value on exit.

        NOTE(review): `get_buffer_size` / `set_buffer_size` are not imported
        anywhere in the visible part of this file, and nothing visible calls
        this helper -- confirm where they should come from before using it.
        """
        # Read the current value before entering `try`, so a failure here is
        # reported as-is and not masked by an unbound local in `finally`.
        original_buffer_size = get_buffer_size()
        try:
            set_buffer_size(buffer_size)
            yield
        finally:
            set_buffer_size(original_buffer_size)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
        ]
    )
    def test_fileobj(self, ext, dtype):
        """Querying audio via file object works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_info(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
        ]
    )
    def test_bytesio(self, ext, dtype):
        """Querying audio via BytesIO object works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_info(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
        ]
    )
    def test_bytesio_tiny(self, ext, dtype):
        """Querying audio via BytesIO object works for tiny data (4 frames)"""
        sample_rate = 8000
        num_frames = 4
        num_channels = 2
        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_info(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
            ("flac", "float32"),
            ("vorbis", "float32"),
            ("amb", "int16"),
        ]
    )
    def test_tarfile(self, ext, dtype):
        """Querying audio via a file object extracted from a tar archive works"""
        sample_rate = 16000
        # Integer frame count, consistent with the sibling tests; the derived
        # duration (num_frames / sample_rate) is still 3.0.
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_info(sinfo, ext, dtype, sample_rate, num_channels, num_frames)


if __name__ == '__main__':
    unittest.main()