import io
import platform
import unittest
from typing import Optional

if platform.system() == "Windows":
    # The sox backend is unavailable on Windows, so skip this module before
    # importing it. Raising SkipTest at import time is reported as a skip by
    # both unittest discovery and pytest. The previous `exit()` aborted
    # collection with SystemExit, and `exit()` is a `site` helper that is
    # not guaranteed to exist.
    raise unittest.SkipTest("sox io not support in Windows, please skip test.")

import numpy as np
from paddleaudio.backends import sox_io_backend

from common_utils import (get_wav_data, load_wav, save_wav, nested_params,
                          TempDirMixin, sox_utils)

# Code is adapted from:
# https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/save_test.py

# Maps the encoding names passed to `sox_io_backend.save` to the names passed
# to the `sox` command via `sox_utils.convert_audio_file`.
_SOX_ENCODINGS = {
    "PCM_F": "floating-point",
    "PCM_S": "signed-integer",
    "PCM_U": "unsigned-integer",
    "ULAW": "u-law",
    "ALAW": "a-law",
}


def _get_sox_encoding(encoding):
    """Return the `sox` command-line name for a backend ``encoding``.

    Returns None when ``encoding`` is None or not present in the mapping.
    """
    return _SOX_ENCODINGS.get(encoding)


class TestSaveBase(TempDirMixin):
    """Mixin that checks `sox_io_backend.save` against the `sox` command."""

    def _save_with_paddleaudio(self, path, data, sample_rate, format,
                               test_mode, **save_kwargs):
        """Write ``data`` to ``path`` using ``sox_io_backend.save``.

        ``test_mode`` selects what is handed to ``save``:

        - ``"path"``: the file path itself (``format`` is not passed).
        - ``"fileobj"``: an open binary file object for ``path``.
        - ``"bytesio"``: an in-memory buffer, whose contents are then
          written to ``path``.

        ``save_kwargs`` (compression / encoding / bits_per_sample) are
        forwarded unchanged to ``save`` in every mode.

        Raises:
            ValueError: if ``test_mode`` is not one of the values above.
        """
        if test_mode == "path":
            sox_io_backend.save(path, data, sample_rate, **save_kwargs)
        elif test_mode == "fileobj":
            with open(path, "bw") as file_:
                sox_io_backend.save(
                    file_, data, sample_rate, format=format, **save_kwargs)
        elif test_mode == "bytesio":
            buffer = io.BytesIO()
            sox_io_backend.save(
                buffer, data, sample_rate, format=format, **save_kwargs)
            with open(path, "bw") as f:
                f.write(buffer.getvalue())
        else:
            raise ValueError(f"Unexpected test mode: {test_mode}")

    def assert_save_consistency(
            self,
            format: str,
            *,
            compression: Optional[float]=None,
            encoding: Optional[str]=None,
            bits_per_sample: Optional[int]=None,
            sample_rate: int=8000,
            num_channels: int=2,
            num_frames: int=3 * 8000,
            src_dtype: str="int32",
            test_mode: str="path", ):
        """Check that `save` produces a file comparable with the `sox` command.

        To compare the file produced by the `save` function against the file
        produced by the equivalent `sox` command, we need to load both files.
        But many formats cannot be opened with common Python modules (like
        SciPy). So we use the `sox` command to prepare the original data and
        to convert the saved files into a format that can be loaded (PCM wav).
        The following diagram illustrates this process. The only difference
        between the two paths is step 2.1. versus step 3.1.

        This assumes that
         - loading wav data with `load_wav` preserves the data well.
         - converting the resulting files into WAV format with `sox`
           preserves the data well.

                          x
                          | 1. Generate source wav file
                          |
                          v
          -------------- wav ----------------
         |                                   |
         | 2.1. load the wav, then save it   | 3.1. Convert to the target
         |   into the target format with     |      format / encoding /
         |   paddleaudio                     |      bit depth with sox
         v                                   v
        target format                       target format
         |                                   |
         | 2.2. Convert to wav with sox      | 3.2. Convert to wav with sox
         |                                   |
         v                                   v
        wav                                 wav
         |                                   |
         | 2.3. load the wav                 | 3.3. load the wav
         |                                   |
         v                                   v
        found -------> compare <-------- expected

        Args:
            format: Target file extension. It is used in the temporary file
                names and passed as ``format`` to `save` in the "fileobj" and
                "bytesio" modes.
            compression: Forwarded to `save` and to the sox conversion.
            encoding: Backend encoding name (e.g. "PCM_S"). It is forwarded
                to `save` and translated with `_get_sox_encoding` for sox.
            bits_per_sample: Forwarded to `save` and as ``bit_depth`` to sox.
            sample_rate: Sample rate of the generated source wav.
            num_channels: Channel count of the generated source wav.
            num_frames: Frame count of the generated source wav.
            src_dtype: Dtype passed to `get_wav_data` for the source wav.
            test_mode: How `save` receives its destination: "path",
                "fileobj" or "bytesio".

        Raises:
            ValueError: if ``test_mode`` is not a supported mode.
        """
        # Both result files are re-encoded to the same 32-bit float wav so the
        # comparison is made on equal footing.
        cmp_encoding = "floating-point"
        cmp_bit_depth = 32

        src_path = self.get_temp_path("1.source.wav")
        tgt_path = self.get_temp_path(f"2.1.paddleaudio.{format}")
        tst_path = self.get_temp_path("2.2.result.wav")
        sox_path = self.get_temp_path(f"3.1.sox.{format}")
        ref_path = self.get_temp_path("3.2.ref.wav")

        # 1. Generate original wav
        data = get_wav_data(
            src_dtype, num_channels, normalize=False, num_frames=num_frames)
        save_wav(src_path, data, sample_rate)

        # 2.1. Convert the original wav to target format with paddleaudio
        data = load_wav(src_path, normalize=False)[0]
        self._save_with_paddleaudio(
            tgt_path,
            data,
            sample_rate,
            format,
            test_mode,
            compression=compression,
            encoding=encoding,
            bits_per_sample=bits_per_sample, )
        # 2.2. Convert the target format to wav with sox
        sox_utils.convert_audio_file(
            tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
        # 2.3. Load the result wav
        found = load_wav(tst_path, normalize=False)[0]

        # 3.1. Convert the original wav to target format with sox
        sox_encoding = _get_sox_encoding(encoding)
        sox_utils.convert_audio_file(
            src_path,
            sox_path,
            compression=compression,
            encoding=sox_encoding,
            bit_depth=bits_per_sample)
        # 3.2. Convert the target format to wav with sox
        sox_utils.convert_audio_file(
            sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
        # 3.3. Load the reference wav
        expected = load_wav(ref_path, normalize=False)[0]

        np.testing.assert_array_almost_equal(found, expected)


class TestSave(TestSaveBase, unittest.TestCase):
    """Concrete WAV save tests built on `TestSaveBase`."""

    @nested_params(
        [
            "path",
        ],
        [
            ("PCM_U", 8),
            ("PCM_S", 16),
            ("PCM_S", 32),
            ("PCM_F", 32),
            ("PCM_F", 64),
            ("ULAW", 8),
            ("ALAW", 8),
        ], )
    def test_save_wav(self, test_mode, enc_params):
        """Saving WAV with each (encoding, bits_per_sample) pair matches sox."""
        encoding, bits_per_sample = enc_params
        self.assert_save_consistency(
            "wav",
            encoding=encoding,
            bits_per_sample=bits_per_sample,
            test_mode=test_mode)

    @nested_params(
        [
            "path",
        ],
        [
            ("float32", ),
            ("int32", ),
        ], )
    def test_save_wav_dtype(self, test_mode, params):
        """Saving WAV from sources of different dtypes matches sox."""
        (dtype, ) = params
        self.assert_save_consistency(
            "wav", src_dtype=dtype, test_mode=test_mode)


if __name__ == '__main__':
    unittest.main()