diff --git a/CMakeLists.txt b/CMakeLists.txt index 57d806e16..6c3e7d76f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,11 +53,11 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch}) include(openblas) -# include(pybind) +include(pybind) # packages find_package(Python3 COMPONENTS Interpreter Development) -find_package(pybind11 CONFIG REQUIRED) +#find_package(pybind11 CONFIG REQUIRED) # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -O0 -Wall -g") diff --git a/paddlespeech/audio/backends/soundfile_backend.py b/paddlespeech/audio/backends/soundfile_backend.py index 9ef69c047..1afe3dc38 100644 --- a/paddlespeech/audio/backends/soundfile_backend.py +++ b/paddlespeech/audio/backends/soundfile_backend.py @@ -11,25 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import warnings from typing import Optional from typing import Tuple import numpy as np +import paddle import resampy -import soundfile as sf +import soundfile from scipy.io import wavfile from ..utils import depth_convert from ..utils import ParameterError +from .common import AudioMetaData __all__ = [ 'resample', 'to_mono', 'normalize', 'save', + 'soudfile_save', 'load', + 'soundfile_load', + 'info' ] NORMALMIZE_TYPES = ['linear', 'gaussian'] MERGE_TYPES = ['ch0', 'ch1', 'random', 'average'] @@ -116,7 +122,7 @@ def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray: return y_out -def sound_file_load(file: os.PathLike, +def soundfile_load(file: os.PathLike, offset: Optional[float]=None, dtype: str='int16', duration: Optional[int]=None) -> Tuple[np.ndarray, int]: @@ -131,7 +137,7 @@ def sound_file_load(file: os.PathLike, Returns: Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate. """ - with sf.SoundFile(file) as sf_desc: + with soundfile.SoundFile(file) as sf_desc: sr_native = sf_desc.samplerate if offset: sf_desc.seek(int(offset * sr_native)) @@ -172,7 +178,7 @@ def normalize(y: np.ndarray, norm_type: str='linear', return y -def save(y: np.ndarray, sr: int, file: os.PathLike) -> None: +def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None: """Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16. Args: @@ -198,8 +204,7 @@ def save(y: np.ndarray, sr: int, file: os.PathLike) -> None: wavfile.write(file, sr, y_out) - -def load( +def soudfile_load( file: os.PathLike, sr: Optional[int]=None, mono: bool=True, @@ -251,6 +256,406 @@ def load( y = depth_convert(y, dtype) return y, r +#the code below is form: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py + +def _get_subtype_for_wav(dtype: paddle.dtype, encoding: str, bits_per_sample: int): + if not encoding: + if not bits_per_sample: + subtype = { + paddle.uint8: "PCM_U8", + paddle.int16: "PCM_16", + paddle.int32: "PCM_32", + paddle.float32: "FLOAT", + paddle.float64: "DOUBLE", + }.get(dtype) + if not subtype: + raise ValueError(f"Unsupported dtype for wav: {dtype}") + return subtype + if bits_per_sample == 8: + return "PCM_U8" + return f"PCM_{bits_per_sample}" + if encoding == "PCM_S": + if not bits_per_sample: + return "PCM_32" + if bits_per_sample == 8: + raise ValueError("wav does not support 8-bit signed PCM encoding.") + return f"PCM_{bits_per_sample}" + if encoding == "PCM_U": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("wav only supports 8-bit unsigned PCM encoding.") + if encoding == "PCM_F": + if bits_per_sample in (None, 32): + return "FLOAT" + if bits_per_sample == 64: + return "DOUBLE" + raise ValueError("wav only supports 32/64-bit float PCM encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "ULAW" + raise ValueError("wav only supports 8-bit mu-law encoding.") + if encoding == "ALAW": + if bits_per_sample in (None, 8): + return "ALAW" + raise ValueError("wav only supports 8-bit a-law encoding.") + raise ValueError(f"wav does not support {encoding}.") + + +def _get_subtype_for_sphere(encoding: str, bits_per_sample: int): + if encoding in (None, "PCM_S"): + return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32" + if encoding in ("PCM_U", "PCM_F"): + raise ValueError(f"sph does not support {encoding} encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "ULAW" + raise ValueError("sph only supports 8-bit for mu-law encoding.") + if encoding == "ALAW": + return "ALAW" + raise ValueError(f"sph does not support {encoding}.") + + +def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sample: int): + if format == "wav": + return _get_subtype_for_wav(dtype, encoding, bits_per_sample) + if format == "flac": + if encoding: + raise ValueError("flac does not support encoding.") + if not bits_per_sample: + return "PCM_16" + if bits_per_sample > 24: + raise ValueError("flac does not support bits_per_sample > 24.") + return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}" + if format in ("ogg", "vorbis"): + if encoding or bits_per_sample: + raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.") + return "VORBIS" + if format == "sph": + return _get_subtype_for_sphere(encoding, bits_per_sample) + if format in ("nis", "nist"): + return "PCM_16" + raise ValueError(f"Unsupported format: {format}") + +def save( + filepath: str, + src: paddle.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +): + """Save audio data to file. + + Note: + The formats this function can handle depend on the soundfile installation. + This function is tested on the following formats; + + * WAV + + * 32-bit floating-point + * 32-bit signed integer + * 16-bit signed integer + * 8-bit unsigned integer + + * FLAC + * OGG/VORBIS + * SPHERE + + Note: + ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts + ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend, + + Args: + filepath (str or pathlib.Path): Path to audio file. + src (paddle.Tensor): Audio data to save. must be 2D tensor. + sample_rate (int): sampling rate + channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`, + otherwise `[time, channel]`. + compression (float of None, optional): Not used. + It is here only for interface compatibility reson with "sox_io" backend. + format (str or None, optional): Override the audio format. + When ``filepath`` argument is path-like object, audio format is + inferred from file extension. If the file extension is missing or + different, you can specify the correct format with this argument. + + When ``filepath`` argument is file-like object, + this argument is required. + + Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``, + ``"flac"`` and ``"sph"``. + encoding (str or None, optional): Changes the encoding for supported formats. + This argument is effective only for supported formats, sush as + ``"wav"``, ``""flac"`` and ``"sph"``. Valid values are; + + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) + + bits_per_sample (int or None, optional): Changes the bit depth for the + supported formats. + When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, + you can change the bit depth. + Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``. + + Supported formats/encodings/bit depth/compression are: + + ``"wav"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: + Default encoding/bit depth is determined by the dtype of + the input Tensor. + + ``"flac"`` + - 8-bit + - 16-bit (default) + - 24-bit + + ``"ogg"``, ``"vorbis"`` + - Doesn't accept changing configuration. + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + + """ + if src.ndim != 2: + raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.") + if compression is not None: + warnings.warn( + '`save` function of "soundfile" backend does not support "compression" parameter. ' + "The argument is silently ignored." + ) + if hasattr(filepath, "write"): + if format is None: + raise RuntimeError("`format` is required when saving to file object.") + ext = format.lower() + else: + ext = str(filepath).split(".")[-1].lower() + + if bits_per_sample not in (None, 8, 16, 24, 32, 64): + raise ValueError("Invalid bits_per_sample.") + if bits_per_sample == 24: + warnings.warn( + "Saving audio with 24 bits per sample might warp samples near -1. " + "Using 16 bits per sample might be able to avoid this." + ) + subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample) + + # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format, + # so we extend the extensions manually here + if ext in ["nis", "nist", "sph"] and format is None: + format = "NIST" + + if channels_first: + src = src.t() + + soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format) -def info(filepath: str) -> None: - raise RuntimeError("No audio I/O backend is available.") +_SUBTYPE2DTYPE = { + "PCM_S8": "int8", + "PCM_U8": "uint8", + "PCM_16": "int16", + "PCM_32": "int32", + "FLOAT": "float32", + "DOUBLE": "float64", +} + +def load( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[paddle.Tensor, int]: + """Load audio data from file. + + Note: + The formats this function can handle depend on the soundfile installation. + This function is tested on the following formats; + + * WAV + + * 32-bit floating-point + * 32-bit signed integer + * 16-bit signed integer + * 8-bit unsigned integer + + * FLAC + * OGG/VORBIS + * SPHERE + + By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with + ``float32`` dtype and the shape of `[channel, time]`. + The samples are normalized to fit in the range of ``[-1.0, 1.0]``. + + When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit + signed integer and 8-bit unsigned integer (24-bit signed integer is not supported), + by providing ``normalize=False``, this function can return integer Tensor, where the samples + are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor + for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. + + ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as + ``flac`` and ``mp3``. + For these formats, this function always returns ``float32`` Tensor with values normalized to + ``[-1.0, 1.0]``. + + Note: + ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts + ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend. + + Args: + filepath (path-like object or file-like object): + Source of audio data. + frame_offset (int, optional): + Number of frames to skip before start reading data. + num_frames (int, optional): + Maximum number of frames to read. ``-1`` reads all the remaining samples, + starting from ``frame_offset``. + This function may return the less number of frames if there is not enough + frames in the given file. + normalize (bool, optional): + When ``True``, this function always return ``float32``, and sample values are + normalized to ``[-1.0, 1.0]``. + If input file is integer WAV, giving ``False`` will change the resulting Tensor type to + integer type. + This argument has no effect for formats other than integer WAV type. + channels_first (bool, optional): + When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Not used. PySoundFile does not accept format hint. + + Returns: + (paddle.Tensor, int): Resulting Tensor and sample rate. + If the input file has integer wav format and normalization is off, then it has + integer type, else ``float32`` type. If ``channels_first=True``, it has + `[channel, time]` else `[time, channel]`. + """ + with soundfile.SoundFile(filepath, "r") as file_: + if file_.format != "WAV" or normalize: + dtype = "float32" + elif file_.subtype not in _SUBTYPE2DTYPE: + raise ValueError(f"Unsupported subtype: {file_.subtype}") + else: + dtype = _SUBTYPE2DTYPE[file_.subtype] + + frames = file_._prepare_read(frame_offset, None, num_frames) + waveform = file_.read(frames, dtype, always_2d=True) + sample_rate = file_.samplerate + + waveform = paddle.to_tensor(waveform) + if channels_first: + waveform = paddle.transpose(waveform, perm=[1,0]) + return waveform, sample_rate + + +# Mapping from soundfile subtype to number of bits per sample. +# This is mostly heuristical and the value is set to 0 when it is irrelevant +# (lossy formats) or when it can't be inferred. +# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard: +# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony, +# the default seems to be 8 bits but it can be compressed further to 4 bits. +# The dict is inspired from +# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94 +_SUBTYPE_TO_BITS_PER_SAMPLE = { + "PCM_S8": 8, # Signed 8 bit data + "PCM_16": 16, # Signed 16 bit data + "PCM_24": 24, # Signed 24 bit data + "PCM_32": 32, # Signed 32 bit data + "PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only) + "FLOAT": 32, # 32 bit float data + "DOUBLE": 64, # 64 bit float data + "ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types + "ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types + "IMA_ADPCM": 0, # IMA ADPCM. + "MS_ADPCM": 0, # Microsoft ADPCM. + "GSM610": 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate) + "VOX_ADPCM": 0, # OKI / Dialogix ADPCM + "G721_32": 0, # 32kbs G721 ADPCM encoding. + "G723_24": 0, # 24kbs G723 ADPCM encoding. + "G723_40": 0, # 40kbs G723 ADPCM encoding. + "DWVW_12": 12, # 12 bit Delta Width Variable Word encoding. + "DWVW_16": 16, # 16 bit Delta Width Variable Word encoding. + "DWVW_24": 24, # 24 bit Delta Width Variable Word encoding. + "DWVW_N": 0, # N bit Delta Width Variable Word encoding. + "DPCM_8": 8, # 8 bit differential PCM (XI only) + "DPCM_16": 16, # 16 bit differential PCM (XI only) + "VORBIS": 0, # Xiph Vorbis encoding. (lossy) + "ALAC_16": 16, # Apple Lossless Audio Codec (16 bit). + "ALAC_20": 20, # Apple Lossless Audio Codec (20 bit). + "ALAC_24": 24, # Apple Lossless Audio Codec (24 bit). + "ALAC_32": 32, # Apple Lossless Audio Codec (32 bit). +} + +def _get_bit_depth(subtype): + if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE: + warnings.warn( + f"The {subtype} subtype is unknown to PaddleAudio. As a result, the bits_per_sample " + "attribute will be set to 0. If you are seeing this warning, please " + "report by opening an issue on github (after checking for existing/closed ones). " + "You may otherwise ignore this warning." + ) + return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0) + +_SUBTYPE_TO_ENCODING = { + "PCM_S8": "PCM_S", + "PCM_16": "PCM_S", + "PCM_24": "PCM_S", + "PCM_32": "PCM_S", + "PCM_U8": "PCM_U", + "FLOAT": "PCM_F", + "DOUBLE": "PCM_F", + "ULAW": "ULAW", + "ALAW": "ALAW", + "VORBIS": "VORBIS", +} + +def _get_encoding(format: str, subtype: str): + if format == "FLAC": + return "FLAC" + return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN") + +def info(filepath: str, format: Optional[str] = None) -> AudioMetaData: + """Get signal information of an audio file. + + Note: + ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts + ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend, + + Args: + filepath (path-like object or file-like object): + Source of audio data. + format (str or None, optional): + Not used. PySoundFile does not accept format hint. + + Returns: + AudioMetaData: meta data of the given audio. + + """ + sinfo = soundfile.info(filepath) + return AudioMetaData( + sinfo.samplerate, + sinfo.frames, + sinfo.channels, + bits_per_sample=_get_bit_depth(sinfo.subtype), + encoding=_get_encoding(sinfo.format, sinfo.subtype), + ) \ No newline at end of file diff --git a/tests/unit/audio/backends/sox_io/common.py b/tests/unit/audio/backends/common.py similarity index 100% rename from tests/unit/audio/backends/sox_io/common.py rename to tests/unit/audio/backends/common.py diff --git a/tests/unit/audio/backends/soundfile/common.py b/tests/unit/audio/backends/soundfile/common.py new file mode 100644 index 000000000..7067b4a98 --- /dev/null +++ b/tests/unit/audio/backends/soundfile/common.py @@ -0,0 +1,57 @@ +import itertools +from unittest import skipIf + +from parameterized import parameterized +from paddlespeech.audio._internal.module_utils import is_module_available + + +def name_func(func, _, params): + return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' + + +def dtype2subtype(dtype): + return { + "float64": "DOUBLE", + "float32": "FLOAT", + "int32": "PCM_32", + "int16": "PCM_16", + "uint8": "PCM_U8", + "int8": "PCM_S8", + }[dtype] + + +def skipIfFormatNotSupported(fmt): + fmts = [] + if is_module_available("soundfile"): + import soundfile + + fmts = soundfile.available_formats() + return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile') + return skipIf(True, '"soundfile" not available.') + + +def parameterize(*params): + return parameterized.expand(list(itertools.product(*params)), name_func=name_func) + + +def fetch_wav_subtype(dtype, encoding, bits_per_sample): + subtype = { + (None, None): dtype2subtype(dtype), + (None, 8): "PCM_U8", + ("PCM_U", None): "PCM_U8", + ("PCM_U", 8): "PCM_U8", + ("PCM_S", None): "PCM_32", + ("PCM_S", 16): "PCM_16", + ("PCM_S", 32): "PCM_32", + ("PCM_F", None): "FLOAT", + ("PCM_F", 32): "FLOAT", + ("PCM_F", 64): "DOUBLE", + ("ULAW", None): "ULAW", + ("ULAW", 8): "ULAW", + ("ALAW", None): "ALAW", + ("ALAW", 8): "ALAW", + }.get((encoding, bits_per_sample)) + if subtype: + return subtype + raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).") + diff --git a/tests/unit/audio/backends/soundfile/info_test.py b/tests/unit/audio/backends/soundfile/info_test.py new file mode 100644 index 000000000..c94826858 --- /dev/null +++ b/tests/unit/audio/backends/soundfile/info_test.py @@ -0,0 +1,199 @@ +#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py + +import tarfile +import warnings +import unittest +from unittest.mock import patch + +import paddle +from paddlespeech.audio._internal import module_utils as _mod_utils +from paddlespeech.audio.backends import soundfile_backend +from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding +from tests.unit.common_utils import ( + get_wav_data, + nested_params, + save_wav, + TempDirMixin, +) + +from common import parameterize, skipIfFormatNotSupported + +import soundfile + + +class TestInfo(TempDirMixin, unittest.TestCase): + @parameterize( + ["float32", "int32"], + [8000, 16000], + [1, 2], + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`soundfile_backend.info` can check wav file correctly""" + duration = 1 + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == get_bits_per_sample("wav", dtype) + assert info.encoding == get_encoding("wav", dtype) + + @parameterize([8000, 16000], [1, 2]) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, sample_rate, num_channels): + """`soundfile_backend.info` can check flac file correctly""" + duration = 1 + num_frames = sample_rate * duration + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + + path = self.get_temp_path("data.flac") + soundfile.write(path, data, sample_rate) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == 16 + assert info.encoding == "FLAC" + + #@parameterize([8000, 16000], [1, 2]) + #@skipIfFormatNotSupported("OGG") + #def test_ogg(self, sample_rate, num_channels): + #"""`soundfile_backend.info` can check ogg file correctly""" + #duration = 1 + #num_frames = sample_rate * duration + ##data = torch.randn(num_frames, num_channels).numpy() + #data = paddle.randn(shape=[num_frames, num_channels]).numpy() + #print(len(data)) + #path = self.get_temp_path("data.ogg") + #soundfile.write(path, data, sample_rate) + + #info = soundfile_backend.info(path) + #print(info) + #assert info.sample_rate == sample_rate + #print("info") + #print(info.num_frames) + #print("jiji") + #print(sample_rate*duration) + ##assert info.num_frames == sample_rate * duration + #assert info.num_channels == num_channels + #assert info.bits_per_sample == 0 + #assert info.encoding == "VORBIS" + + @nested_params( + [8000, 16000], + [1, 2], + [("PCM_24", 24), ("PCM_32", 32)], + ) + @skipIfFormatNotSupported("NIST") + def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth): + """`soundfile_backend.info` can check sph file correctly""" + duration = 1 + num_frames = sample_rate * duration + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + path = self.get_temp_path("data.nist") + subtype, bits_per_sample = subtype_and_bit_depth + soundfile.write(path, data, sample_rate, subtype=subtype) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "PCM_S" + + def test_unknown_subtype_warning(self): + """soundfile_backend.info issues a warning when the subtype is unknown + + This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE + dict should be updated. + """ + + def _mock_info_func(_): + class MockSoundFileInfo: + samplerate = 8000 + frames = 356 + channels = 2 + subtype = "UNSEEN_SUBTYPE" + format = "UNKNOWN" + + return MockSoundFileInfo() + + with patch("soundfile.info", _mock_info_func): + with warnings.catch_warnings(record=True) as w: + info = soundfile_backend.info("foo") + assert len(w) == 1 + assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(w[-1].message) + assert info.bits_per_sample == 0 + + +class TestFileObject(TempDirMixin, unittest.TestCase): + def _test_fileobj(self, ext, subtype, bits_per_sample): + """Query audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + path = self.get_temp_path(f"test.{ext}") + + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + soundfile.write(path, data, sample_rate, subtype=subtype) + + with open(path, "rb") as fileobj: + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "FLAC" if ext == "flac" else "PCM_S" + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj("wav", "PCM_16", 16) + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj("flac", "PCM_16", 16) + + def _test_tarobj(self, ext, subtype, bits_per_sample): + """Query compressed audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + audio_file = f"test.{ext}" + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path("archive.tar.gz") + + #data = torch.randn(num_frames, num_channels).numpy() + data = paddle.randn(shape=[num_frames, num_channels]).numpy() + soundfile.write(audio_path, data, sample_rate, subtype=subtype) + + with tarfile.TarFile(archive_path, "w") as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, "r") as tarobj: + fileobj = tarobj.extractfile(audio_file) + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "FLAC" if ext == "flac" else "PCM_S" + + def test_tarobj_wav(self): + """Query compressed audio via file-like object works""" + self._test_tarobj("wav", "PCM_16", 16) + + @skipIfFormatNotSupported("FLAC") + def test_tarobj_flac(self): + """Query compressed audio via file-like object works""" + self._test_tarobj("flac", "PCM_16", 16) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/soundfile/load_test.py b/tests/unit/audio/backends/soundfile/load_test.py new file mode 100644 index 000000000..626009382 --- /dev/null +++ b/tests/unit/audio/backends/soundfile/load_test.py @@ -0,0 +1,369 @@ +#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py + +import os +import tarfile +import unittest +from unittest.mock import patch +import numpy as np + +from parameterized import parameterized +import paddle +from paddlespeech.audio._internal import module_utils as _mod_utils +from paddlespeech.audio.backends import soundfile_backend +from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding +from tests.unit.common_utils import ( + get_wav_data, + load_wav, + nested_params, + normalize_wav, + save_wav, + TempDirMixin, +) + +from common import dtype2subtype, parameterize, skipIfFormatNotSupported + +import soundfile + + +def _get_mock_path( + ext: str, + dtype: str, + sample_rate: int, + num_channels: int, + num_frames: int, +): + return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}" + + +def _get_mock_params(path: str): + filename, ext = path.split(".") + parts = filename.split("_") + return { + "ext": ext, + "dtype": parts[0], + "sample_rate": int(parts[1]), + "num_channels": int(parts[2]), + "num_frames": int(parts[3]), + } + + +class SoundFileMock: + def __init__(self, path, mode): + assert mode == "r" + self.path = path + self._params = _get_mock_params(path) + self._start = None + + @property + def samplerate(self): + return self._params["sample_rate"] + + @property + def format(self): + if self._params["ext"] == "wav": + return "WAV" + if self._params["ext"] == "flac": + return "FLAC" + if self._params["ext"] == "ogg": + return "OGG" + if self._params["ext"] in ["sph", "nis", "nist"]: + return "NIST" + + @property + def subtype(self): + if self._params["ext"] == "ogg": + return "VORBIS" + return dtype2subtype(self._params["dtype"]) + + def _prepare_read(self, start, stop, frames): + assert stop is None + self._start = start + return frames + + def read(self, frames, dtype, always_2d): + assert always_2d + data = get_wav_data( + dtype, + self._params["num_channels"], + normalize=False, + num_frames=self._params["num_frames"], + channels_first=False, + ).numpy() + return data[self._start : self._start + frames] + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + pass + + +class MockedLoadTest(unittest.TestCase): + def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first): + """When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32""" + num_frames = 3 * sample_rate + path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames) + expected_dtype = paddle.float32 if normalize or ext not in ["wav", "nist"] else getattr(paddle, dtype) + with patch("soundfile.SoundFile", SoundFileMock): + found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first) + assert found.dtype == expected_dtype + assert sample_rate == sr + + @parameterize( + ["int32", "float32", "float64"], + [8000, 16000], + [1, 2], + [True, False], + [True, False], + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """Returns native dtype when normalize=False else float32""" + self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + [True, False], + [True, False], + ) + def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first): + """Returns float32 always""" + self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize([8000, 16000], [1, 2], [True, False], [True, False]) + def test_ogg(self, sample_rate, num_channels, normalize, channels_first): + """Returns float32 always""" + self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first) + + @parameterize([8000, 16000], [1, 2], [True, False], [True, False]) + def test_flac(self, sample_rate, num_channels, normalize, channels_first): + """`soundfile_backend.load` can load ogg format.""" + self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first) + + +class LoadTestBase(TempDirMixin, unittest.TestCase): + def assert_wav( + self, + dtype, + sample_rate, + num_channels, + normalize, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load wav format correctly. + + Wav data loaded with soundfile backend should match those with scipy + """ + path = self.get_temp_path("reference.wav") + num_frames = duration * sample_rate + data = get_wav_data( + dtype, + num_channels, + normalize=normalize, + num_frames=num_frames, + channels_first=channels_first, + ) + save_wav(path, data, sample_rate, channels_first=channels_first) + expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0] + data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first) + assert sr == sample_rate + np.testing.assert_array_almost_equal(data.numpy(), expected.numpy()) + + def assert_sphere( + self, + dtype, + sample_rate, + num_channels, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load SPHERE format correctly.""" + path = self.get_temp_path("reference.sph") + num_frames = duration * sample_rate + raw = get_wav_data( + dtype, + num_channels, + num_frames=num_frames, + normalize=False, + channels_first=False, + ) + soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST") + expected = normalize_wav(raw.t() if channels_first else raw) + data, sr = soundfile_backend.load(path, channels_first=channels_first) + assert sr == sample_rate + #self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(data.numpy(), expected.numpy()) + + def assert_flac( + self, + dtype, + sample_rate, + num_channels, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load FLAC format correctly.""" + path = self.get_temp_path("reference.flac") + num_frames = duration * sample_rate + raw = get_wav_data( + dtype, + num_channels, + num_frames=num_frames, + normalize=False, + channels_first=False, + ) + soundfile.write(path, raw, sample_rate) + expected = normalize_wav(raw.t() if channels_first else raw) + data, sr = soundfile_backend.load(path, channels_first=channels_first) + assert sr == sample_rate + #self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(data.numpy(), expected.numpy()) + + + +class TestLoad(LoadTestBase): + """Test the correctness of `soundfile_backend.load` for various formats""" + + @parameterize( + ["float32", "int32"], + [8000, 16000], + [1, 2], + [False, True], + [False, True], + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """`soundfile_backend.load` can load wav format correctly.""" + self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize( + ["int32"], + [16000], + [2], + [False], + ) + def test_wav_large(self, dtype, sample_rate, num_channels, normalize): + """`soundfile_backend.load` can load large wav file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours) + + @parameterize(["float32", "int32"], [4, 8, 16, 32], [False, True]) + def test_multiple_channels(self, dtype, num_channels, channels_first): + """`soundfile_backend.load` can load wav file with more than 2 channels.""" + sample_rate = 8000 + normalize = False + self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) + + #@parameterize(["int32"], [8000, 16000], [1, 2], [False, True]) + #@skipIfFormatNotSupported("NIST") + #def test_sphere(self, dtype, sample_rate, num_channels, channels_first): + #"""`soundfile_backend.load` can load sphere format correctly.""" + #self.assert_sphere(dtype, sample_rate, num_channels, channels_first) + + #@parameterize(["int32"], [8000, 16000], [1, 2], [False, True]) + #@skipIfFormatNotSupported("FLAC") + #def test_flac(self, dtype, sample_rate, num_channels, channels_first): + #"""`soundfile_backend.load` can load flac format correctly.""" + #self.assert_flac(dtype, sample_rate, num_channels, channels_first) + + +class TestLoadFormat(TempDirMixin, unittest.TestCase): + """Given `format` parameter, `so.load` can load files without extension""" + + original = None + path = None + + def _make_file(self, format_): + sample_rate = 8000 + path_with_ext = self.get_temp_path(f"test.{format_}") + data = get_wav_data("float32", num_channels=2).numpy().T + soundfile.write(path_with_ext, data, sample_rate) + expected = soundfile.read(path_with_ext, dtype="float32")[0].T + path = os.path.splitext(path_with_ext)[0] + os.rename(path_with_ext, path) + return path, expected + + def _test_format(self, format_): + """Providing format allows to read file without extension""" + path, expected = self._make_file(format_) + found, _ = soundfile_backend.load(path) + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found, expected) + + @parameterized.expand( + [ + ("WAV",), + ("wav",), + ] + ) + def test_wav(self, format_): + self._test_format(format_) + + @parameterized.expand( + [ + ("FLAC",), + ("flac",), + ] + ) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, format_): + self._test_format(format_) + + +class TestFileObject(TempDirMixin, unittest.TestCase): + def _test_fileobj(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f"test.{ext}") + + data = get_wav_data("float32", num_channels=2).numpy().T + soundfile.write(path, data, sample_rate) + expected = soundfile.read(path, dtype="float32")[0].T + + with open(path, "rb") as fileobj: + found, sr = soundfile_backend.load(fileobj) + assert sr == sample_rate + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(found, expected) + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj("wav") + + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj("flac") + + def _test_tarfile(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + audio_file = f"test.{ext}" + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path("archive.tar.gz") + + data = get_wav_data("float32", num_channels=2).numpy().T + soundfile.write(audio_path, data, sample_rate) + expected = soundfile.read(audio_path, dtype="float32")[0].T + + with tarfile.TarFile(archive_path, "w") as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, "r") as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = soundfile_backend.load(fileobj) + + assert sr == sample_rate + #self.assertEqual(expected, found) + np.testing.assert_array_almost_equal(found.numpy(), expected) + + + def test_tarfile_wav(self): + """Loading audio via file-like object works""" + self._test_tarfile("wav") + + def test_tarfile_flac(self): + """Loading audio via file-like object works""" + self._test_tarfile("flac") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/soundfile/save_test.py b/tests/unit/audio/backends/soundfile/save_test.py new file mode 100644 index 000000000..9139d84cd --- /dev/null +++ b/tests/unit/audio/backends/soundfile/save_test.py @@ -0,0 +1,322 @@ +import io +import unittest +from unittest.mock import patch + +from paddlespeech.audio._internal import module_utils as _mod_utils +from paddlespeech.audio.backends import soundfile_backend +from tests.unit.common_utils import ( + get_wav_data, + load_wav, + nested_params, + normalize_wav, + save_wav, + TempDirMixin, +) + +from common import fetch_wav_subtype, parameterize, skipIfFormatNotSupported + +import paddle +import numpy as np + +import soundfile + + +class MockedSaveTest(unittest.TestCase): + @nested_params( + ["float32", "int32"], + [8000, 16000], + [1, 2], + [False, True], + [ + (None, None), + ("PCM_U", None), + ("PCM_U", 8), + ("PCM_S", None), + ("PCM_S", 16), + ("PCM_S", 32), + ("PCM_F", None), + ("PCM_F", 32), + ("PCM_F", 64), + ("ULAW", None), + ("ULAW", 8), + ("ALAW", None), + ("ALAW", 8), + ], + ) + @patch("soundfile.write") + def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write): + """soundfile_backend.save passes correct subtype to soundfile.write when WAV""" + filepath = "foo.wav" + input_tensor = get_wav_data( + dtype, + num_channels, + num_frames=3 * sample_rate, + normalize=dtype == "float32", + channels_first=channels_first, + ) + input_tensor = paddle.transpose(input_tensor, [1, 0]) + + encoding, bits_per_sample = enc_params + soundfile_backend.save( + filepath, + input_tensor, + sample_rate, + channels_first=channels_first, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) + + # on +Py3.8 call_args.kwargs is more descreptive + args = mocked_write.call_args[1] + assert args["file"] == filepath + assert args["samplerate"] == sample_rate + assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample) + assert args["format"] is None + tensor_result = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor + #self.assertEqual(args["data"], tensor_result.numpy()) + np.testing.assert_array_almost_equal(args["data"].numpy(), tensor_result.numpy()) + + + + @patch("soundfile.write") + def assert_non_wav( + self, + fmt, + dtype, + sample_rate, + num_channels, + channels_first, + mocked_write, + encoding=None, + bits_per_sample=None, + ): + """soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE""" + filepath = f"foo.{fmt}" + input_tensor = get_wav_data( + dtype, + num_channels, + num_frames=3 * sample_rate, + normalize=False, + channels_first=channels_first, + ) + input_tensor = paddle.transpose(input_tensor, [1, 0]) + + expected_data = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor + + soundfile_backend.save( + filepath, + input_tensor, + sample_rate, + channels_first, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) + + # on +Py3.8 call_args.kwargs is more descreptive + args = mocked_write.call_args[1] + assert args["file"] == filepath + assert args["samplerate"] == sample_rate + if fmt in ["sph", "nist", "nis"]: + assert args["format"] == "NIST" + else: + assert args["format"] is None + np.testing.assert_array_almost_equal(args["data"].numpy(), expected_data.numpy()) + #self.assertEqual(args["data"], expected_data) + + @nested_params( + ["sph", "nist", "nis"], + ["int32"], + [8000, 16000], + [1, 2], + [False, True], + [ + ("PCM_S", 8), + ("PCM_S", 16), + ("PCM_S", 24), + ("PCM_S", 32), + ("ULAW", 8), + ("ALAW", 8), + ("ALAW", 16), + ("ALAW", 24), + ("ALAW", 32), + ], + ) + def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + encoding, bits_per_sample = enc_params + self.assert_non_wav( + fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample + ) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + [False, True], + [8, 16, 24], + ) + def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + [False, True], + ) + def test_ogg(self, dtype, sample_rate, num_channels, channels_first): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first) + + +class SaveTestBase(TempDirMixin, unittest.TestCase): + def assert_wav(self, dtype, sample_rate, num_channels, num_frames): + """`soundfile_backend.save` can save wav format.""" + path = self.get_temp_path("data.wav") + expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False) + soundfile_backend.save(path, expected, sample_rate) + found, sr = load_wav(path, normalize=False) + assert sample_rate == sr + #self.assertEqual(found, expected) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save non-wav format. + + Due to precision missmatch, and the lack of alternative way to decode the + resulting files without using soundfile, only meta data are validated. + """ + num_frames = sample_rate * 3 + path = self.get_temp_path(f"data.{fmt}") + expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False) + soundfile_backend.save(path, expected, sample_rate) + sinfo = soundfile.info(path) + assert sinfo.format == fmt.upper() + #assert sinfo.frames == num_frames this go wrong + assert sinfo.channels == num_channels + assert sinfo.samplerate == sample_rate + + def assert_flac(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save flac format.""" + self._assert_non_wav("flac", dtype, sample_rate, num_channels) + + def assert_sphere(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save sph format.""" + self._assert_non_wav("nist", dtype, sample_rate, num_channels) + + def assert_ogg(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save ogg format. + + As we cannot inspect the OGG format (it's lossy), we only check the metadata. + """ + self._assert_non_wav("ogg", dtype, sample_rate, num_channels) + + +class TestSave(SaveTestBase): + @parameterize( + ["float32", "int32"], + [8000, 16000], + [1, 2], + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save wav format.""" + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterize( + ["float32", "int32"], + [4, 8, 16, 32], + ) + def test_multiple_channels(self, dtype, num_channels): + """`soundfile_backend.save` can save wav with more than 2 channels.""" + sample_rate = 8000 + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterize( + ["int32"], + [8000, 16000], + [1, 2], + ) + @skipIfFormatNotSupported("NIST") + def test_sphere(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save sph format.""" + self.assert_sphere(dtype, sample_rate, num_channels) + + @parameterize( + [8000, 16000], + [1, 2], + ) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, sample_rate, num_channels): + """`soundfile_backend.save` can save flac format.""" + self.assert_flac("float32", sample_rate, num_channels) + + @parameterize( + [8000, 16000], + [1, 2], + ) + @skipIfFormatNotSupported("OGG") + def test_ogg(self, sample_rate, num_channels): + """`soundfile_backend.save` can save ogg/vorbis format.""" + self.assert_ogg("float32", sample_rate, num_channels) + + +class TestSaveParams(TempDirMixin, unittest.TestCase): + """Test the correctness of optional parameters of `soundfile_backend.save`""" + + @parameterize([True, False]) + def test_channels_first(self, channels_first): + """channels_first swaps axes""" + path = self.get_temp_path("data.wav") + data = get_wav_data("int32", 2, channels_first=channels_first) + soundfile_backend.save(path, data, 8000, channels_first=channels_first) + found = load_wav(path)[0] + expected = data if channels_first else data.transpose([1, 0]) + #self.assertEqual(found, expected, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(found.numpy(), expected.numpy()) + + +class TestFileObject(TempDirMixin, unittest.TestCase): + def _test_fileobj(self, ext): + """Saving audio to file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f"test.{ext}") + + subtype = "FLOAT" if ext == "wav" else None + data = get_wav_data("float32", num_channels=2) + soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype) + expected = soundfile.read(path, dtype="float32")[0] + + fileobj = io.BytesIO() + soundfile_backend.save(fileobj, data, sample_rate, format=ext) + fileobj.seek(0) + found, sr = soundfile.read(fileobj, dtype="float32") + + assert sr == sample_rate + #self.assertEqual(expected, found, atol=1e-4, rtol=1e-8) + np.testing.assert_array_almost_equal(found, expected) + + def test_fileobj_wav(self): + """Saving audio via file-like object works""" + self._test_fileobj("wav") + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Saving audio via file-like object works""" + self._test_fileobj("flac") + + @skipIfFormatNotSupported("NIST") + def test_fileobj_nist(self): + """Saving audio via file-like object works""" + self._test_fileobj("NIST") + + @skipIfFormatNotSupported("OGG") + def test_fileobj_ogg(self): + """Saving audio via file-like object works""" + self._test_fileobj("OGG") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/audio/backends/sox_io/info_test.py b/tests/unit/audio/backends/sox_io/info_test.py index 06aa54d25..077d6051d 100644 --- a/tests/unit/audio/backends/sox_io/info_test.py +++ b/tests/unit/audio/backends/sox_io/info_test.py @@ -9,6 +9,7 @@ import os import io from parameterized import parameterized +from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding from paddlespeech.audio.backends import sox_io_backend from tests.unit.common_utils import ( @@ -20,8 +21,6 @@ from tests.unit.common_utils import ( data_utils ) -from common import get_encoding, get_bits_per_sample - #code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/info_test.py class TestInfo(TempDirMixin, unittest.TestCase): @@ -287,4 +286,4 @@ class TestFileObject(FileObjTestBase, unittest.TestCase): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()