commit
dc6f8ff10c
@ -0,0 +1,57 @@
|
|||||||
|
import itertools
|
||||||
|
from unittest import skipIf
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
from paddlespeech.audio._internal.module_utils import is_module_available
|
||||||
|
|
||||||
|
|
||||||
|
def name_func(func, _, params):
|
||||||
|
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
|
||||||
|
|
||||||
|
|
||||||
|
def dtype2subtype(dtype):
|
||||||
|
return {
|
||||||
|
"float64": "DOUBLE",
|
||||||
|
"float32": "FLOAT",
|
||||||
|
"int32": "PCM_32",
|
||||||
|
"int16": "PCM_16",
|
||||||
|
"uint8": "PCM_U8",
|
||||||
|
"int8": "PCM_S8",
|
||||||
|
}[dtype]
|
||||||
|
|
||||||
|
|
||||||
|
def skipIfFormatNotSupported(fmt):
|
||||||
|
fmts = []
|
||||||
|
if is_module_available("soundfile"):
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
fmts = soundfile.available_formats()
|
||||||
|
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile')
|
||||||
|
return skipIf(True, '"soundfile" not available.')
|
||||||
|
|
||||||
|
|
||||||
|
def parameterize(*params):
|
||||||
|
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
|
||||||
|
subtype = {
|
||||||
|
(None, None): dtype2subtype(dtype),
|
||||||
|
(None, 8): "PCM_U8",
|
||||||
|
("PCM_U", None): "PCM_U8",
|
||||||
|
("PCM_U", 8): "PCM_U8",
|
||||||
|
("PCM_S", None): "PCM_32",
|
||||||
|
("PCM_S", 16): "PCM_16",
|
||||||
|
("PCM_S", 32): "PCM_32",
|
||||||
|
("PCM_F", None): "FLOAT",
|
||||||
|
("PCM_F", 32): "FLOAT",
|
||||||
|
("PCM_F", 64): "DOUBLE",
|
||||||
|
("ULAW", None): "ULAW",
|
||||||
|
("ULAW", 8): "ULAW",
|
||||||
|
("ALAW", None): "ALAW",
|
||||||
|
("ALAW", 8): "ALAW",
|
||||||
|
}.get((encoding, bits_per_sample))
|
||||||
|
if subtype:
|
||||||
|
return subtype
|
||||||
|
raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
|
||||||
|
|
@ -0,0 +1,199 @@
|
|||||||
|
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py
|
||||||
|
|
||||||
|
import tarfile
|
||||||
|
import warnings
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddlespeech.audio._internal import module_utils as _mod_utils
|
||||||
|
from paddlespeech.audio.backends import soundfile_backend
|
||||||
|
from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding
|
||||||
|
from tests.unit.common_utils import (
|
||||||
|
get_wav_data,
|
||||||
|
nested_params,
|
||||||
|
save_wav,
|
||||||
|
TempDirMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
from common import parameterize, skipIfFormatNotSupported
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
|
||||||
|
class TestInfo(TempDirMixin, unittest.TestCase):
|
||||||
|
@parameterize(
|
||||||
|
["float32", "int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
def test_wav(self, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.info` can check wav file correctly"""
|
||||||
|
duration = 1
|
||||||
|
path = self.get_temp_path("data.wav")
|
||||||
|
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
|
||||||
|
save_wav(path, data, sample_rate)
|
||||||
|
info = soundfile_backend.info(path)
|
||||||
|
assert info.sample_rate == sample_rate
|
||||||
|
assert info.num_frames == sample_rate * duration
|
||||||
|
assert info.num_channels == num_channels
|
||||||
|
assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
|
||||||
|
assert info.encoding == get_encoding("wav", dtype)
|
||||||
|
|
||||||
|
@parameterize([8000, 16000], [1, 2])
|
||||||
|
@skipIfFormatNotSupported("FLAC")
|
||||||
|
def test_flac(self, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.info` can check flac file correctly"""
|
||||||
|
duration = 1
|
||||||
|
num_frames = sample_rate * duration
|
||||||
|
#data = torch.randn(num_frames, num_channels).numpy()
|
||||||
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
||||||
|
|
||||||
|
path = self.get_temp_path("data.flac")
|
||||||
|
soundfile.write(path, data, sample_rate)
|
||||||
|
|
||||||
|
info = soundfile_backend.info(path)
|
||||||
|
assert info.sample_rate == sample_rate
|
||||||
|
assert info.num_frames == num_frames
|
||||||
|
assert info.num_channels == num_channels
|
||||||
|
assert info.bits_per_sample == 16
|
||||||
|
assert info.encoding == "FLAC"
|
||||||
|
|
||||||
|
#@parameterize([8000, 16000], [1, 2])
|
||||||
|
#@skipIfFormatNotSupported("OGG")
|
||||||
|
#def test_ogg(self, sample_rate, num_channels):
|
||||||
|
#"""`soundfile_backend.info` can check ogg file correctly"""
|
||||||
|
#duration = 1
|
||||||
|
#num_frames = sample_rate * duration
|
||||||
|
##data = torch.randn(num_frames, num_channels).numpy()
|
||||||
|
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
||||||
|
#print(len(data))
|
||||||
|
#path = self.get_temp_path("data.ogg")
|
||||||
|
#soundfile.write(path, data, sample_rate)
|
||||||
|
|
||||||
|
#info = soundfile_backend.info(path)
|
||||||
|
#print(info)
|
||||||
|
#assert info.sample_rate == sample_rate
|
||||||
|
#print("info")
|
||||||
|
#print(info.num_frames)
|
||||||
|
#print("jiji")
|
||||||
|
#print(sample_rate*duration)
|
||||||
|
##assert info.num_frames == sample_rate * duration
|
||||||
|
#assert info.num_channels == num_channels
|
||||||
|
#assert info.bits_per_sample == 0
|
||||||
|
#assert info.encoding == "VORBIS"
|
||||||
|
|
||||||
|
@nested_params(
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[("PCM_24", 24), ("PCM_32", 32)],
|
||||||
|
)
|
||||||
|
@skipIfFormatNotSupported("NIST")
|
||||||
|
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
|
||||||
|
"""`soundfile_backend.info` can check sph file correctly"""
|
||||||
|
duration = 1
|
||||||
|
num_frames = sample_rate * duration
|
||||||
|
#data = torch.randn(num_frames, num_channels).numpy()
|
||||||
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
||||||
|
path = self.get_temp_path("data.nist")
|
||||||
|
subtype, bits_per_sample = subtype_and_bit_depth
|
||||||
|
soundfile.write(path, data, sample_rate, subtype=subtype)
|
||||||
|
|
||||||
|
info = soundfile_backend.info(path)
|
||||||
|
assert info.sample_rate == sample_rate
|
||||||
|
assert info.num_frames == sample_rate * duration
|
||||||
|
assert info.num_channels == num_channels
|
||||||
|
assert info.bits_per_sample == bits_per_sample
|
||||||
|
assert info.encoding == "PCM_S"
|
||||||
|
|
||||||
|
def test_unknown_subtype_warning(self):
|
||||||
|
"""soundfile_backend.info issues a warning when the subtype is unknown
|
||||||
|
|
||||||
|
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
|
||||||
|
dict should be updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _mock_info_func(_):
|
||||||
|
class MockSoundFileInfo:
|
||||||
|
samplerate = 8000
|
||||||
|
frames = 356
|
||||||
|
channels = 2
|
||||||
|
subtype = "UNSEEN_SUBTYPE"
|
||||||
|
format = "UNKNOWN"
|
||||||
|
|
||||||
|
return MockSoundFileInfo()
|
||||||
|
|
||||||
|
with patch("soundfile.info", _mock_info_func):
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
info = soundfile_backend.info("foo")
|
||||||
|
assert len(w) == 1
|
||||||
|
assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(w[-1].message)
|
||||||
|
assert info.bits_per_sample == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileObject(TempDirMixin, unittest.TestCase):
|
||||||
|
def _test_fileobj(self, ext, subtype, bits_per_sample):
|
||||||
|
"""Query audio via file-like object works"""
|
||||||
|
duration = 2
|
||||||
|
sample_rate = 16000
|
||||||
|
num_channels = 2
|
||||||
|
num_frames = sample_rate * duration
|
||||||
|
path = self.get_temp_path(f"test.{ext}")
|
||||||
|
|
||||||
|
#data = torch.randn(num_frames, num_channels).numpy()
|
||||||
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
||||||
|
soundfile.write(path, data, sample_rate, subtype=subtype)
|
||||||
|
|
||||||
|
with open(path, "rb") as fileobj:
|
||||||
|
info = soundfile_backend.info(fileobj)
|
||||||
|
assert info.sample_rate == sample_rate
|
||||||
|
assert info.num_frames == num_frames
|
||||||
|
assert info.num_channels == num_channels
|
||||||
|
assert info.bits_per_sample == bits_per_sample
|
||||||
|
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
|
||||||
|
|
||||||
|
def test_fileobj_wav(self):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
self._test_fileobj("wav", "PCM_16", 16)
|
||||||
|
|
||||||
|
@skipIfFormatNotSupported("FLAC")
|
||||||
|
def test_fileobj_flac(self):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
self._test_fileobj("flac", "PCM_16", 16)
|
||||||
|
|
||||||
|
def _test_tarobj(self, ext, subtype, bits_per_sample):
|
||||||
|
"""Query compressed audio via file-like object works"""
|
||||||
|
duration = 2
|
||||||
|
sample_rate = 16000
|
||||||
|
num_channels = 2
|
||||||
|
num_frames = sample_rate * duration
|
||||||
|
audio_file = f"test.{ext}"
|
||||||
|
audio_path = self.get_temp_path(audio_file)
|
||||||
|
archive_path = self.get_temp_path("archive.tar.gz")
|
||||||
|
|
||||||
|
#data = torch.randn(num_frames, num_channels).numpy()
|
||||||
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
||||||
|
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
|
||||||
|
|
||||||
|
with tarfile.TarFile(archive_path, "w") as tarobj:
|
||||||
|
tarobj.add(audio_path, arcname=audio_file)
|
||||||
|
with tarfile.TarFile(archive_path, "r") as tarobj:
|
||||||
|
fileobj = tarobj.extractfile(audio_file)
|
||||||
|
info = soundfile_backend.info(fileobj)
|
||||||
|
assert info.sample_rate == sample_rate
|
||||||
|
assert info.num_frames == num_frames
|
||||||
|
assert info.num_channels == num_channels
|
||||||
|
assert info.bits_per_sample == bits_per_sample
|
||||||
|
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
|
||||||
|
|
||||||
|
def test_tarobj_wav(self):
|
||||||
|
"""Query compressed audio via file-like object works"""
|
||||||
|
self._test_tarobj("wav", "PCM_16", 16)
|
||||||
|
|
||||||
|
@skipIfFormatNotSupported("FLAC")
|
||||||
|
def test_tarobj_flac(self):
|
||||||
|
"""Query compressed audio via file-like object works"""
|
||||||
|
self._test_tarobj("flac", "PCM_16", 16)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -0,0 +1,369 @@
|
|||||||
|
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tarfile
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
import paddle
|
||||||
|
from paddlespeech.audio._internal import module_utils as _mod_utils
|
||||||
|
from paddlespeech.audio.backends import soundfile_backend
|
||||||
|
from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding
|
||||||
|
from tests.unit.common_utils import (
|
||||||
|
get_wav_data,
|
||||||
|
load_wav,
|
||||||
|
nested_params,
|
||||||
|
normalize_wav,
|
||||||
|
save_wav,
|
||||||
|
TempDirMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
from common import dtype2subtype, parameterize, skipIfFormatNotSupported
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mock_path(
|
||||||
|
ext: str,
|
||||||
|
dtype: str,
|
||||||
|
sample_rate: int,
|
||||||
|
num_channels: int,
|
||||||
|
num_frames: int,
|
||||||
|
):
|
||||||
|
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mock_params(path: str):
|
||||||
|
filename, ext = path.split(".")
|
||||||
|
parts = filename.split("_")
|
||||||
|
return {
|
||||||
|
"ext": ext,
|
||||||
|
"dtype": parts[0],
|
||||||
|
"sample_rate": int(parts[1]),
|
||||||
|
"num_channels": int(parts[2]),
|
||||||
|
"num_frames": int(parts[3]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SoundFileMock:
|
||||||
|
def __init__(self, path, mode):
|
||||||
|
assert mode == "r"
|
||||||
|
self.path = path
|
||||||
|
self._params = _get_mock_params(path)
|
||||||
|
self._start = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def samplerate(self):
|
||||||
|
return self._params["sample_rate"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self):
|
||||||
|
if self._params["ext"] == "wav":
|
||||||
|
return "WAV"
|
||||||
|
if self._params["ext"] == "flac":
|
||||||
|
return "FLAC"
|
||||||
|
if self._params["ext"] == "ogg":
|
||||||
|
return "OGG"
|
||||||
|
if self._params["ext"] in ["sph", "nis", "nist"]:
|
||||||
|
return "NIST"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subtype(self):
|
||||||
|
if self._params["ext"] == "ogg":
|
||||||
|
return "VORBIS"
|
||||||
|
return dtype2subtype(self._params["dtype"])
|
||||||
|
|
||||||
|
def _prepare_read(self, start, stop, frames):
|
||||||
|
assert stop is None
|
||||||
|
self._start = start
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def read(self, frames, dtype, always_2d):
|
||||||
|
assert always_2d
|
||||||
|
data = get_wav_data(
|
||||||
|
dtype,
|
||||||
|
self._params["num_channels"],
|
||||||
|
normalize=False,
|
||||||
|
num_frames=self._params["num_frames"],
|
||||||
|
channels_first=False,
|
||||||
|
).numpy()
|
||||||
|
return data[self._start : self._start + frames]
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockedLoadTest(unittest.TestCase):
|
||||||
|
def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
|
||||||
|
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
|
||||||
|
num_frames = 3 * sample_rate
|
||||||
|
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
|
||||||
|
expected_dtype = paddle.float32 if normalize or ext not in ["wav", "nist"] else getattr(paddle, dtype)
|
||||||
|
with patch("soundfile.SoundFile", SoundFileMock):
|
||||||
|
found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
|
||||||
|
assert found.dtype == expected_dtype
|
||||||
|
assert sample_rate == sr
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["int32", "float32", "float64"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[True, False],
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
|
||||||
|
"""Returns native dtype when normalize=False else float32"""
|
||||||
|
self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[True, False],
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
|
||||||
|
"""Returns float32 always"""
|
||||||
|
self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)
|
||||||
|
|
||||||
|
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
|
||||||
|
def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
|
||||||
|
"""Returns float32 always"""
|
||||||
|
self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)
|
||||||
|
|
||||||
|
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
|
||||||
|
def test_flac(self, sample_rate, num_channels, normalize, channels_first):
|
||||||
|
"""`soundfile_backend.load` can load ogg format."""
|
||||||
|
self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTestBase(TempDirMixin, unittest.TestCase):
|
||||||
|
def assert_wav(
|
||||||
|
self,
|
||||||
|
dtype,
|
||||||
|
sample_rate,
|
||||||
|
num_channels,
|
||||||
|
normalize,
|
||||||
|
channels_first=True,
|
||||||
|
duration=1,
|
||||||
|
):
|
||||||
|
"""`soundfile_backend.load` can load wav format correctly.
|
||||||
|
|
||||||
|
Wav data loaded with soundfile backend should match those with scipy
|
||||||
|
"""
|
||||||
|
path = self.get_temp_path("reference.wav")
|
||||||
|
num_frames = duration * sample_rate
|
||||||
|
data = get_wav_data(
|
||||||
|
dtype,
|
||||||
|
num_channels,
|
||||||
|
normalize=normalize,
|
||||||
|
num_frames=num_frames,
|
||||||
|
channels_first=channels_first,
|
||||||
|
)
|
||||||
|
save_wav(path, data, sample_rate, channels_first=channels_first)
|
||||||
|
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
|
||||||
|
data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
|
||||||
|
assert sr == sample_rate
|
||||||
|
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
|
||||||
|
|
||||||
|
def assert_sphere(
|
||||||
|
self,
|
||||||
|
dtype,
|
||||||
|
sample_rate,
|
||||||
|
num_channels,
|
||||||
|
channels_first=True,
|
||||||
|
duration=1,
|
||||||
|
):
|
||||||
|
"""`soundfile_backend.load` can load SPHERE format correctly."""
|
||||||
|
path = self.get_temp_path("reference.sph")
|
||||||
|
num_frames = duration * sample_rate
|
||||||
|
raw = get_wav_data(
|
||||||
|
dtype,
|
||||||
|
num_channels,
|
||||||
|
num_frames=num_frames,
|
||||||
|
normalize=False,
|
||||||
|
channels_first=False,
|
||||||
|
)
|
||||||
|
soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
|
||||||
|
expected = normalize_wav(raw.t() if channels_first else raw)
|
||||||
|
data, sr = soundfile_backend.load(path, channels_first=channels_first)
|
||||||
|
assert sr == sample_rate
|
||||||
|
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
|
||||||
|
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
|
||||||
|
|
||||||
|
def assert_flac(
|
||||||
|
self,
|
||||||
|
dtype,
|
||||||
|
sample_rate,
|
||||||
|
num_channels,
|
||||||
|
channels_first=True,
|
||||||
|
duration=1,
|
||||||
|
):
|
||||||
|
"""`soundfile_backend.load` can load FLAC format correctly."""
|
||||||
|
path = self.get_temp_path("reference.flac")
|
||||||
|
num_frames = duration * sample_rate
|
||||||
|
raw = get_wav_data(
|
||||||
|
dtype,
|
||||||
|
num_channels,
|
||||||
|
num_frames=num_frames,
|
||||||
|
normalize=False,
|
||||||
|
channels_first=False,
|
||||||
|
)
|
||||||
|
soundfile.write(path, raw, sample_rate)
|
||||||
|
expected = normalize_wav(raw.t() if channels_first else raw)
|
||||||
|
data, sr = soundfile_backend.load(path, channels_first=channels_first)
|
||||||
|
assert sr == sample_rate
|
||||||
|
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
|
||||||
|
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoad(LoadTestBase):
|
||||||
|
"""Test the correctness of `soundfile_backend.load` for various formats"""
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["float32", "int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[False, True],
|
||||||
|
[False, True],
|
||||||
|
)
|
||||||
|
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
|
||||||
|
"""`soundfile_backend.load` can load wav format correctly."""
|
||||||
|
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["int32"],
|
||||||
|
[16000],
|
||||||
|
[2],
|
||||||
|
[False],
|
||||||
|
)
|
||||||
|
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
|
||||||
|
"""`soundfile_backend.load` can load large wav file correctly."""
|
||||||
|
two_hours = 2 * 60 * 60
|
||||||
|
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours)
|
||||||
|
|
||||||
|
@parameterize(["float32", "int32"], [4, 8, 16, 32], [False, True])
|
||||||
|
def test_multiple_channels(self, dtype, num_channels, channels_first):
|
||||||
|
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
|
||||||
|
sample_rate = 8000
|
||||||
|
normalize = False
|
||||||
|
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
|
||||||
|
|
||||||
|
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
|
||||||
|
#@skipIfFormatNotSupported("NIST")
|
||||||
|
#def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
|
||||||
|
#"""`soundfile_backend.load` can load sphere format correctly."""
|
||||||
|
#self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
|
||||||
|
|
||||||
|
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
|
||||||
|
#@skipIfFormatNotSupported("FLAC")
|
||||||
|
#def test_flac(self, dtype, sample_rate, num_channels, channels_first):
|
||||||
|
#"""`soundfile_backend.load` can load flac format correctly."""
|
||||||
|
#self.assert_flac(dtype, sample_rate, num_channels, channels_first)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadFormat(TempDirMixin, unittest.TestCase):
|
||||||
|
"""Given `format` parameter, `so.load` can load files without extension"""
|
||||||
|
|
||||||
|
original = None
|
||||||
|
path = None
|
||||||
|
|
||||||
|
def _make_file(self, format_):
|
||||||
|
sample_rate = 8000
|
||||||
|
path_with_ext = self.get_temp_path(f"test.{format_}")
|
||||||
|
data = get_wav_data("float32", num_channels=2).numpy().T
|
||||||
|
soundfile.write(path_with_ext, data, sample_rate)
|
||||||
|
expected = soundfile.read(path_with_ext, dtype="float32")[0].T
|
||||||
|
path = os.path.splitext(path_with_ext)[0]
|
||||||
|
os.rename(path_with_ext, path)
|
||||||
|
return path, expected
|
||||||
|
|
||||||
|
def _test_format(self, format_):
|
||||||
|
"""Providing format allows to read file without extension"""
|
||||||
|
path, expected = self._make_file(format_)
|
||||||
|
found, _ = soundfile_backend.load(path)
|
||||||
|
#self.assertEqual(found, expected)
|
||||||
|
np.testing.assert_array_almost_equal(found, expected)
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
("WAV",),
|
||||||
|
("wav",),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_wav(self, format_):
|
||||||
|
self._test_format(format_)
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
("FLAC",),
|
||||||
|
("flac",),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@skipIfFormatNotSupported("FLAC")
|
||||||
|
def test_flac(self, format_):
|
||||||
|
self._test_format(format_)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileObject(TempDirMixin, unittest.TestCase):
|
||||||
|
def _test_fileobj(self, ext):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
sample_rate = 16000
|
||||||
|
path = self.get_temp_path(f"test.{ext}")
|
||||||
|
|
||||||
|
data = get_wav_data("float32", num_channels=2).numpy().T
|
||||||
|
soundfile.write(path, data, sample_rate)
|
||||||
|
expected = soundfile.read(path, dtype="float32")[0].T
|
||||||
|
|
||||||
|
with open(path, "rb") as fileobj:
|
||||||
|
found, sr = soundfile_backend.load(fileobj)
|
||||||
|
assert sr == sample_rate
|
||||||
|
#self.assertEqual(expected, found)
|
||||||
|
np.testing.assert_array_almost_equal(found, expected)
|
||||||
|
|
||||||
|
def test_fileobj_wav(self):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
self._test_fileobj("wav")
|
||||||
|
|
||||||
|
def test_fileobj_flac(self):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
self._test_fileobj("flac")
|
||||||
|
|
||||||
|
def _test_tarfile(self, ext):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
sample_rate = 16000
|
||||||
|
audio_file = f"test.{ext}"
|
||||||
|
audio_path = self.get_temp_path(audio_file)
|
||||||
|
archive_path = self.get_temp_path("archive.tar.gz")
|
||||||
|
|
||||||
|
data = get_wav_data("float32", num_channels=2).numpy().T
|
||||||
|
soundfile.write(audio_path, data, sample_rate)
|
||||||
|
expected = soundfile.read(audio_path, dtype="float32")[0].T
|
||||||
|
|
||||||
|
with tarfile.TarFile(archive_path, "w") as tarobj:
|
||||||
|
tarobj.add(audio_path, arcname=audio_file)
|
||||||
|
with tarfile.TarFile(archive_path, "r") as tarobj:
|
||||||
|
fileobj = tarobj.extractfile(audio_file)
|
||||||
|
found, sr = soundfile_backend.load(fileobj)
|
||||||
|
|
||||||
|
assert sr == sample_rate
|
||||||
|
#self.assertEqual(expected, found)
|
||||||
|
np.testing.assert_array_almost_equal(found.numpy(), expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tarfile_wav(self):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
self._test_tarfile("wav")
|
||||||
|
|
||||||
|
def test_tarfile_flac(self):
|
||||||
|
"""Loading audio via file-like object works"""
|
||||||
|
self._test_tarfile("flac")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -0,0 +1,322 @@
|
|||||||
|
import io
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from paddlespeech.audio._internal import module_utils as _mod_utils
|
||||||
|
from paddlespeech.audio.backends import soundfile_backend
|
||||||
|
from tests.unit.common_utils import (
|
||||||
|
get_wav_data,
|
||||||
|
load_wav,
|
||||||
|
nested_params,
|
||||||
|
normalize_wav,
|
||||||
|
save_wav,
|
||||||
|
TempDirMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
from common import fetch_wav_subtype, parameterize, skipIfFormatNotSupported
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
|
||||||
|
class MockedSaveTest(unittest.TestCase):
|
||||||
|
@nested_params(
|
||||||
|
["float32", "int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[False, True],
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
("PCM_U", None),
|
||||||
|
("PCM_U", 8),
|
||||||
|
("PCM_S", None),
|
||||||
|
("PCM_S", 16),
|
||||||
|
("PCM_S", 32),
|
||||||
|
("PCM_F", None),
|
||||||
|
("PCM_F", 32),
|
||||||
|
("PCM_F", 64),
|
||||||
|
("ULAW", None),
|
||||||
|
("ULAW", 8),
|
||||||
|
("ALAW", None),
|
||||||
|
("ALAW", 8),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("soundfile.write")
|
||||||
|
def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
|
||||||
|
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
|
||||||
|
filepath = "foo.wav"
|
||||||
|
input_tensor = get_wav_data(
|
||||||
|
dtype,
|
||||||
|
num_channels,
|
||||||
|
num_frames=3 * sample_rate,
|
||||||
|
normalize=dtype == "float32",
|
||||||
|
channels_first=channels_first,
|
||||||
|
)
|
||||||
|
input_tensor = paddle.transpose(input_tensor, [1, 0])
|
||||||
|
|
||||||
|
encoding, bits_per_sample = enc_params
|
||||||
|
soundfile_backend.save(
|
||||||
|
filepath,
|
||||||
|
input_tensor,
|
||||||
|
sample_rate,
|
||||||
|
channels_first=channels_first,
|
||||||
|
encoding=encoding,
|
||||||
|
bits_per_sample=bits_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
# on +Py3.8 call_args.kwargs is more descreptive
|
||||||
|
args = mocked_write.call_args[1]
|
||||||
|
assert args["file"] == filepath
|
||||||
|
assert args["samplerate"] == sample_rate
|
||||||
|
assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
|
||||||
|
assert args["format"] is None
|
||||||
|
tensor_result = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor
|
||||||
|
#self.assertEqual(args["data"], tensor_result.numpy())
|
||||||
|
np.testing.assert_array_almost_equal(args["data"].numpy(), tensor_result.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@patch("soundfile.write")
|
||||||
|
def assert_non_wav(
|
||||||
|
self,
|
||||||
|
fmt,
|
||||||
|
dtype,
|
||||||
|
sample_rate,
|
||||||
|
num_channels,
|
||||||
|
channels_first,
|
||||||
|
mocked_write,
|
||||||
|
encoding=None,
|
||||||
|
bits_per_sample=None,
|
||||||
|
):
|
||||||
|
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
|
||||||
|
filepath = f"foo.{fmt}"
|
||||||
|
input_tensor = get_wav_data(
|
||||||
|
dtype,
|
||||||
|
num_channels,
|
||||||
|
num_frames=3 * sample_rate,
|
||||||
|
normalize=False,
|
||||||
|
channels_first=channels_first,
|
||||||
|
)
|
||||||
|
input_tensor = paddle.transpose(input_tensor, [1, 0])
|
||||||
|
|
||||||
|
expected_data = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor
|
||||||
|
|
||||||
|
soundfile_backend.save(
|
||||||
|
filepath,
|
||||||
|
input_tensor,
|
||||||
|
sample_rate,
|
||||||
|
channels_first,
|
||||||
|
encoding=encoding,
|
||||||
|
bits_per_sample=bits_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
# on +Py3.8 call_args.kwargs is more descreptive
|
||||||
|
args = mocked_write.call_args[1]
|
||||||
|
assert args["file"] == filepath
|
||||||
|
assert args["samplerate"] == sample_rate
|
||||||
|
if fmt in ["sph", "nist", "nis"]:
|
||||||
|
assert args["format"] == "NIST"
|
||||||
|
else:
|
||||||
|
assert args["format"] is None
|
||||||
|
np.testing.assert_array_almost_equal(args["data"].numpy(), expected_data.numpy())
|
||||||
|
#self.assertEqual(args["data"], expected_data)
|
||||||
|
|
||||||
|
@nested_params(
|
||||||
|
["sph", "nist", "nis"],
|
||||||
|
["int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[False, True],
|
||||||
|
[
|
||||||
|
("PCM_S", 8),
|
||||||
|
("PCM_S", 16),
|
||||||
|
("PCM_S", 24),
|
||||||
|
("PCM_S", 32),
|
||||||
|
("ULAW", 8),
|
||||||
|
("ALAW", 8),
|
||||||
|
("ALAW", 16),
|
||||||
|
("ALAW", 24),
|
||||||
|
("ALAW", 32),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
|
||||||
|
"""soundfile_backend.save passes default format and subtype (None-s) to
|
||||||
|
soundfile.write when not WAV"""
|
||||||
|
encoding, bits_per_sample = enc_params
|
||||||
|
self.assert_non_wav(
|
||||||
|
fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[False, True],
|
||||||
|
[8, 16, 24],
|
||||||
|
)
|
||||||
|
def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
|
||||||
|
"""soundfile_backend.save passes default format and subtype (None-s) to
|
||||||
|
soundfile.write when not WAV"""
|
||||||
|
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
[False, True],
|
||||||
|
)
|
||||||
|
def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
|
||||||
|
"""soundfile_backend.save passes default format and subtype (None-s) to
|
||||||
|
soundfile.write when not WAV"""
|
||||||
|
self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveTestBase(TempDirMixin, unittest.TestCase):
|
||||||
|
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
|
||||||
|
"""`soundfile_backend.save` can save wav format."""
|
||||||
|
path = self.get_temp_path("data.wav")
|
||||||
|
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
|
||||||
|
soundfile_backend.save(path, expected, sample_rate)
|
||||||
|
found, sr = load_wav(path, normalize=False)
|
||||||
|
assert sample_rate == sr
|
||||||
|
#self.assertEqual(found, expected)
|
||||||
|
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
|
||||||
|
|
||||||
|
def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save non-wav format.
|
||||||
|
|
||||||
|
Due to precision missmatch, and the lack of alternative way to decode the
|
||||||
|
resulting files without using soundfile, only meta data are validated.
|
||||||
|
"""
|
||||||
|
num_frames = sample_rate * 3
|
||||||
|
path = self.get_temp_path(f"data.{fmt}")
|
||||||
|
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
|
||||||
|
soundfile_backend.save(path, expected, sample_rate)
|
||||||
|
sinfo = soundfile.info(path)
|
||||||
|
assert sinfo.format == fmt.upper()
|
||||||
|
#assert sinfo.frames == num_frames this go wrong
|
||||||
|
assert sinfo.channels == num_channels
|
||||||
|
assert sinfo.samplerate == sample_rate
|
||||||
|
|
||||||
|
def assert_flac(self, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save flac format."""
|
||||||
|
self._assert_non_wav("flac", dtype, sample_rate, num_channels)
|
||||||
|
|
||||||
|
def assert_sphere(self, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save sph format."""
|
||||||
|
self._assert_non_wav("nist", dtype, sample_rate, num_channels)
|
||||||
|
|
||||||
|
def assert_ogg(self, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save ogg format.
|
||||||
|
|
||||||
|
As we cannot inspect the OGG format (it's lossy), we only check the metadata.
|
||||||
|
"""
|
||||||
|
self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSave(SaveTestBase):
|
||||||
|
@parameterize(
|
||||||
|
["float32", "int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
def test_wav(self, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save wav format."""
|
||||||
|
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["float32", "int32"],
|
||||||
|
[4, 8, 16, 32],
|
||||||
|
)
|
||||||
|
def test_multiple_channels(self, dtype, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save wav with more than 2 channels."""
|
||||||
|
sample_rate = 8000
|
||||||
|
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
["int32"],
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@skipIfFormatNotSupported("NIST")
|
||||||
|
def test_sphere(self, dtype, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save sph format."""
|
||||||
|
self.assert_sphere(dtype, sample_rate, num_channels)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@skipIfFormatNotSupported("FLAC")
|
||||||
|
def test_flac(self, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save flac format."""
|
||||||
|
self.assert_flac("float32", sample_rate, num_channels)
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
[8000, 16000],
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@skipIfFormatNotSupported("OGG")
|
||||||
|
def test_ogg(self, sample_rate, num_channels):
|
||||||
|
"""`soundfile_backend.save` can save ogg/vorbis format."""
|
||||||
|
self.assert_ogg("float32", sample_rate, num_channels)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveParams(TempDirMixin, unittest.TestCase):
|
||||||
|
"""Test the correctness of optional parameters of `soundfile_backend.save`"""
|
||||||
|
|
||||||
|
@parameterize([True, False])
|
||||||
|
def test_channels_first(self, channels_first):
|
||||||
|
"""channels_first swaps axes"""
|
||||||
|
path = self.get_temp_path("data.wav")
|
||||||
|
data = get_wav_data("int32", 2, channels_first=channels_first)
|
||||||
|
soundfile_backend.save(path, data, 8000, channels_first=channels_first)
|
||||||
|
found = load_wav(path)[0]
|
||||||
|
expected = data if channels_first else data.transpose([1, 0])
|
||||||
|
#self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)
|
||||||
|
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileObject(TempDirMixin, unittest.TestCase):
|
||||||
|
def _test_fileobj(self, ext):
|
||||||
|
"""Saving audio to file-like object works"""
|
||||||
|
sample_rate = 16000
|
||||||
|
path = self.get_temp_path(f"test.{ext}")
|
||||||
|
|
||||||
|
subtype = "FLOAT" if ext == "wav" else None
|
||||||
|
data = get_wav_data("float32", num_channels=2)
|
||||||
|
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
|
||||||
|
expected = soundfile.read(path, dtype="float32")[0]
|
||||||
|
|
||||||
|
fileobj = io.BytesIO()
|
||||||
|
soundfile_backend.save(fileobj, data, sample_rate, format=ext)
|
||||||
|
fileobj.seek(0)
|
||||||
|
found, sr = soundfile.read(fileobj, dtype="float32")
|
||||||
|
|
||||||
|
assert sr == sample_rate
|
||||||
|
#self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
|
||||||
|
np.testing.assert_array_almost_equal(found, expected)
|
||||||
|
|
||||||
|
def test_fileobj_wav(self):
|
||||||
|
"""Saving audio via file-like object works"""
|
||||||
|
self._test_fileobj("wav")
|
||||||
|
|
||||||
|
@skipIfFormatNotSupported("FLAC")
|
||||||
|
def test_fileobj_flac(self):
|
||||||
|
"""Saving audio via file-like object works"""
|
||||||
|
self._test_fileobj("flac")
|
||||||
|
|
||||||
|
@skipIfFormatNotSupported("NIST")
|
||||||
|
def test_fileobj_nist(self):
|
||||||
|
"""Saving audio via file-like object works"""
|
||||||
|
self._test_fileobj("NIST")
|
||||||
|
|
||||||
|
@skipIfFormatNotSupported("OGG")
|
||||||
|
def test_fileobj_ogg(self):
|
||||||
|
"""Saving audio via file-like object works"""
|
||||||
|
self._test_fileobj("OGG")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue