You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
200 lines
7.5 KiB
200 lines
7.5 KiB
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py
|
|
import tarfile
|
|
import unittest
|
|
import warnings
|
|
from unittest.mock import patch
|
|
|
|
import paddle
|
|
import soundfile
|
|
from common import get_bits_per_sample
|
|
from common import get_encoding
|
|
from common import parameterize
|
|
from common import skipIfFormatNotSupported
|
|
from common_utils import get_wav_data
|
|
from common_utils import nested_params
|
|
from common_utils import save_wav
|
|
from common_utils import TempDirMixin
|
|
from paddleaudio.backends import soundfile_backend
|
|
|
|
|
|
class TestInfo(TempDirMixin, unittest.TestCase):
|
|
@parameterize(
|
|
["float32", "int32"],
|
|
[8000, 16000],
|
|
[1, 2], )
|
|
def test_wav(self, dtype, sample_rate, num_channels):
|
|
"""`soundfile_backend.info` can check wav file correctly"""
|
|
duration = 1
|
|
path = self.get_temp_path("data.wav")
|
|
data = get_wav_data(
|
|
dtype,
|
|
num_channels,
|
|
normalize=False,
|
|
num_frames=duration * sample_rate)
|
|
save_wav(path, data, sample_rate)
|
|
info = soundfile_backend.info(path)
|
|
assert info.sample_rate == sample_rate
|
|
assert info.num_frames == sample_rate * duration
|
|
assert info.num_channels == num_channels
|
|
assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
|
|
assert info.encoding == get_encoding("wav", dtype)
|
|
|
|
@parameterize([8000, 16000], [1, 2])
|
|
@skipIfFormatNotSupported("FLAC")
|
|
def test_flac(self, sample_rate, num_channels):
|
|
"""`soundfile_backend.info` can check flac file correctly"""
|
|
duration = 1
|
|
num_frames = sample_rate * duration
|
|
#data = torch.randn(num_frames, num_channels).numpy()
|
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
|
|
|
path = self.get_temp_path("data.flac")
|
|
soundfile.write(path, data, sample_rate)
|
|
|
|
info = soundfile_backend.info(path)
|
|
assert info.sample_rate == sample_rate
|
|
assert info.num_frames == num_frames
|
|
assert info.num_channels == num_channels
|
|
assert info.bits_per_sample == 16
|
|
assert info.encoding == "FLAC"
|
|
|
|
#@parameterize([8000, 16000], [1, 2])
|
|
#@skipIfFormatNotSupported("OGG")
|
|
#def test_ogg(self, sample_rate, num_channels):
|
|
#"""`soundfile_backend.info` can check ogg file correctly"""
|
|
#duration = 1
|
|
#num_frames = sample_rate * duration
|
|
##data = torch.randn(num_frames, num_channels).numpy()
|
|
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
|
#print(len(data))
|
|
#path = self.get_temp_path("data.ogg")
|
|
#soundfile.write(path, data, sample_rate)
|
|
|
|
#info = soundfile_backend.info(path)
|
|
#print(info)
|
|
#assert info.sample_rate == sample_rate
|
|
#print("info")
|
|
#print(info.num_frames)
|
|
#print("jiji")
|
|
#print(sample_rate*duration)
|
|
##assert info.num_frames == sample_rate * duration
|
|
#assert info.num_channels == num_channels
|
|
#assert info.bits_per_sample == 0
|
|
#assert info.encoding == "VORBIS"
|
|
|
|
@nested_params(
|
|
[8000, 16000],
|
|
[1, 2],
|
|
[("PCM_24", 24), ("PCM_32", 32)], )
|
|
@skipIfFormatNotSupported("NIST")
|
|
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
|
|
"""`soundfile_backend.info` can check sph file correctly"""
|
|
duration = 1
|
|
num_frames = sample_rate * duration
|
|
#data = torch.randn(num_frames, num_channels).numpy()
|
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
|
path = self.get_temp_path("data.nist")
|
|
subtype, bits_per_sample = subtype_and_bit_depth
|
|
soundfile.write(path, data, sample_rate, subtype=subtype)
|
|
|
|
info = soundfile_backend.info(path)
|
|
assert info.sample_rate == sample_rate
|
|
assert info.num_frames == sample_rate * duration
|
|
assert info.num_channels == num_channels
|
|
assert info.bits_per_sample == bits_per_sample
|
|
assert info.encoding == "PCM_S"
|
|
|
|
def test_unknown_subtype_warning(self):
|
|
"""soundfile_backend.info issues a warning when the subtype is unknown
|
|
|
|
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
|
|
dict should be updated.
|
|
"""
|
|
|
|
def _mock_info_func(_):
|
|
class MockSoundFileInfo:
|
|
samplerate = 8000
|
|
frames = 356
|
|
channels = 2
|
|
subtype = "UNSEEN_SUBTYPE"
|
|
format = "UNKNOWN"
|
|
|
|
return MockSoundFileInfo()
|
|
|
|
with patch("soundfile.info", _mock_info_func):
|
|
with warnings.catch_warnings(record=True) as w:
|
|
info = soundfile_backend.info("foo")
|
|
assert len(w) == 1
|
|
assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(
|
|
w[-1].message)
|
|
assert info.bits_per_sample == 0
|
|
|
|
|
|
class TestFileObject(TempDirMixin, unittest.TestCase):
|
|
def _test_fileobj(self, ext, subtype, bits_per_sample):
|
|
"""Query audio via file-like object works"""
|
|
duration = 2
|
|
sample_rate = 16000
|
|
num_channels = 2
|
|
num_frames = sample_rate * duration
|
|
path = self.get_temp_path(f"test.{ext}")
|
|
|
|
#data = torch.randn(num_frames, num_channels).numpy()
|
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
|
soundfile.write(path, data, sample_rate, subtype=subtype)
|
|
|
|
with open(path, "rb") as fileobj:
|
|
info = soundfile_backend.info(fileobj)
|
|
assert info.sample_rate == sample_rate
|
|
assert info.num_frames == num_frames
|
|
assert info.num_channels == num_channels
|
|
assert info.bits_per_sample == bits_per_sample
|
|
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
|
|
|
|
def test_fileobj_wav(self):
|
|
"""Loading audio via file-like object works"""
|
|
self._test_fileobj("wav", "PCM_16", 16)
|
|
|
|
@skipIfFormatNotSupported("FLAC")
|
|
def test_fileobj_flac(self):
|
|
"""Loading audio via file-like object works"""
|
|
self._test_fileobj("flac", "PCM_16", 16)
|
|
|
|
def _test_tarobj(self, ext, subtype, bits_per_sample):
|
|
"""Query compressed audio via file-like object works"""
|
|
duration = 2
|
|
sample_rate = 16000
|
|
num_channels = 2
|
|
num_frames = sample_rate * duration
|
|
audio_file = f"test.{ext}"
|
|
audio_path = self.get_temp_path(audio_file)
|
|
archive_path = self.get_temp_path("archive.tar.gz")
|
|
|
|
#data = torch.randn(num_frames, num_channels).numpy()
|
|
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
|
|
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
|
|
|
|
with tarfile.TarFile(archive_path, "w") as tarobj:
|
|
tarobj.add(audio_path, arcname=audio_file)
|
|
with tarfile.TarFile(archive_path, "r") as tarobj:
|
|
fileobj = tarobj.extractfile(audio_file)
|
|
info = soundfile_backend.info(fileobj)
|
|
assert info.sample_rate == sample_rate
|
|
assert info.num_frames == num_frames
|
|
assert info.num_channels == num_channels
|
|
assert info.bits_per_sample == bits_per_sample
|
|
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
|
|
|
|
def test_tarobj_wav(self):
|
|
"""Query compressed audio via file-like object works"""
|
|
self._test_tarobj("wav", "PCM_16", 16)
|
|
|
|
@skipIfFormatNotSupported("FLAC")
|
|
def test_tarobj_flac(self):
|
|
"""Query compressed audio via file-like object works"""
|
|
self._test_tarobj("flac", "PCM_16", 16)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|