You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/audio/tests/backends/soundfile/save_test.py

324 lines
11 KiB

import io
import unittest
from unittest.mock import patch
import numpy as np
import paddle
import soundfile
from common import fetch_wav_subtype
from common import parameterize
from common import skipIfFormatNotSupported
from common_utils import get_wav_data
from common_utils import load_wav
from common_utils import nested_params
from common_utils import TempDirMixin
from paddleaudio.backends import soundfile_backend
class MockedSaveTest(unittest.TestCase):
    """Check what `soundfile_backend.save` forwards to `soundfile.write`.

    `soundfile.write` is patched with a mock in every test, so nothing is
    written to disk; the tests inspect the keyword arguments of the call.
    """

    @nested_params(
        ["float32", "int32"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [
            (None, None),
            ("PCM_U", None),
            ("PCM_U", 8),
            ("PCM_S", None),
            ("PCM_S", 16),
            ("PCM_S", 32),
            ("PCM_F", None),
            ("PCM_F", 32),
            ("PCM_F", 64),
            ("ULAW", None),
            ("ULAW", 8),
            ("ALAW", None),
            ("ALAW", 8),
        ], )
    @patch("soundfile.write")
    def test_wav(self, dtype, sample_rate, num_channels, channels_first,
                 enc_params, mocked_write):
        """soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
        filepath = "foo.wav"
        input_tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=dtype == "float32",
            channels_first=channels_first, )
        # NOTE(review): this swaps the two axes of the tensor returned by
        # `get_wav_data`, so its layout no longer matches `channels_first`.
        # The expected data below is derived from this swapped tensor, so the
        # check stays self-consistent; confirm whether the swap is intended.
        input_tensor = paddle.transpose(input_tensor, [1, 0])

        encoding, bits_per_sample = enc_params
        soundfile_backend.save(
            filepath,
            input_tensor,
            sample_rate,
            channels_first=channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample, )

        # Fail with a clear message if save() did not call soundfile.write
        # exactly once; otherwise `call_args` could be None or hide a call.
        mocked_write.assert_called_once()
        # call_args[1] is the kwargs dict (`call_args.kwargs` needs Python 3.8+).
        args = mocked_write.call_args[1]
        assert args["file"] == filepath
        assert args["samplerate"] == sample_rate
        assert args["subtype"] == fetch_wav_subtype(dtype, encoding,
                                                    bits_per_sample)
        assert args["format"] is None
        # save() is expected to hand soundfile.write the tensor with its axes
        # swapped when channels_first is True, and unchanged otherwise.
        tensor_result = paddle.transpose(
            input_tensor, [1, 0]) if channels_first else input_tensor
        np.testing.assert_array_almost_equal(args["data"].numpy(),
                                             tensor_result.numpy())

    @patch("soundfile.write")
    def assert_non_wav(
            self,
            fmt,
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            mocked_write,
            encoding=None,
            bits_per_sample=None, ):
        """Check the soundfile.write call made by save() for a non-WAV extension.

        Verifies `file`, `samplerate`, `format` ("NIST" for sph/nist/nis,
        None otherwise) and `data`. The `subtype` argument is not checked.

        Args:
            fmt: extension of the target file name, e.g. "sph", "flac", "ogg".
            dtype: dtype name passed to `get_wav_data`.
            sample_rate: sample rate passed to save().
            num_channels: number of channels of the generated data.
            channels_first: value forwarded to save().
            mocked_write: mock of `soundfile.write`, appended by `@patch`
                after the positional arguments; callers never pass it, and
                must therefore pass `encoding` / `bits_per_sample` by keyword.
            encoding: optional encoding forwarded to save().
            bits_per_sample: optional bit depth forwarded to save().
        """
        filepath = f"foo.{fmt}"
        input_tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=False,
            channels_first=channels_first, )
        # NOTE(review): same axis swap as in test_wav; confirm it is intended.
        input_tensor = paddle.transpose(input_tensor, [1, 0])
        expected_data = paddle.transpose(
            input_tensor, [1, 0]) if channels_first else input_tensor
        soundfile_backend.save(
            filepath,
            input_tensor,
            sample_rate,
            channels_first=channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample, )

        mocked_write.assert_called_once()
        # call_args[1] is the kwargs dict (`call_args.kwargs` needs Python 3.8+).
        args = mocked_write.call_args[1]
        assert args["file"] == filepath
        assert args["samplerate"] == sample_rate
        if fmt in ["sph", "nist", "nis"]:
            assert args["format"] == "NIST"
        else:
            assert args["format"] is None
        np.testing.assert_array_almost_equal(args["data"].numpy(),
                                             expected_data.numpy())

    @nested_params(
        ["sph", "nist", "nis"],
        ["int32"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [
            ("PCM_S", 8),
            ("PCM_S", 16),
            ("PCM_S", 24),
            ("PCM_S", 32),
            ("ULAW", 8),
            ("ALAW", 8),
            ("ALAW", 16),
            ("ALAW", 24),
            ("ALAW", 32),
        ], )
    def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first,
                 enc_params):
        """soundfile_backend.save passes format "NIST" and the data to
        soundfile.write for SPHERE extensions"""
        encoding, bits_per_sample = enc_params
        self.assert_non_wav(
            fmt,
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample)

    @parameterize(
        ["int32"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [8, 16, 24], )
    def test_flac(self, dtype, sample_rate, num_channels, channels_first,
                  bits_per_sample):
        """soundfile_backend.save passes format=None and the data to
        soundfile.write for FLAC"""
        self.assert_non_wav(
            "flac",
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            bits_per_sample=bits_per_sample)

    @parameterize(
        ["int32"],
        [8000, 16000],
        [1, 2],
        [False, True], )
    def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
        """soundfile_backend.save passes format=None and the data to
        soundfile.write for OGG"""
        self.assert_non_wav("ogg", dtype, sample_rate, num_channels,
                            channels_first)
class SaveTestBase(TempDirMixin, unittest.TestCase):
    """Shared helpers that write real files with `soundfile_backend.save`."""

    def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
        """`soundfile_backend.save` can save wav format.

        The saved file is read back with `load_wav` and both the sample rate
        and the samples are compared with what was written.
        """
        wav_path = self.get_temp_path("data.wav")
        original = get_wav_data(
            dtype, num_channels, num_frames=num_frames, normalize=False)
        soundfile_backend.save(wav_path, original, sample_rate)
        loaded, loaded_rate = load_wav(wav_path, normalize=False)
        assert loaded_rate == sample_rate
        np.testing.assert_array_almost_equal(loaded.numpy(), original.numpy())

    def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save non-wav format.

        Because of possible precision mismatch, and because there is no way to
        decode these files here without soundfile itself, only the metadata
        reported by `soundfile.info` (format, channels, sample rate) is checked.
        """
        out_path = self.get_temp_path(f"data.{fmt}")
        data = get_wav_data(
            dtype, num_channels, num_frames=3 * sample_rate, normalize=False)
        soundfile_backend.save(out_path, data, sample_rate)
        meta = soundfile.info(out_path)
        assert fmt.upper() == meta.format
        # NOTE(review): a frame-count check against the number of frames
        # written was disabled upstream as failing; confirm why before
        # re-enabling it.
        assert num_channels == meta.channels
        assert sample_rate == meta.samplerate

    def assert_flac(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save flac format."""
        self._assert_non_wav("flac", dtype, sample_rate, num_channels)

    def assert_sphere(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save sph format (via a .nist file)."""
        self._assert_non_wav("nist", dtype, sample_rate, num_channels)

    def assert_ogg(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save ogg format.

        OGG is lossy and cannot be compared sample by sample, so only the
        metadata is checked.
        """
        self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
class TestSave(SaveTestBase):
    """Save real files in each supported format and validate the result."""

    @parameterize(
        ["float32", "int32"],
        [8000, 16000],
        [1, 2], )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save wav format."""
        self.assert_wav(
            dtype=dtype,
            sample_rate=sample_rate,
            num_channels=num_channels,
            num_frames=None)

    @parameterize(
        ["float32", "int32"],
        [4, 8, 16, 32], )
    def test_multiple_channels(self, dtype, num_channels):
        """`soundfile_backend.save` can save wav with more than 2 channels."""
        self.assert_wav(
            dtype=dtype,
            sample_rate=8000,
            num_channels=num_channels,
            num_frames=None)

    @parameterize(
        ["int32"],
        [8000, 16000],
        [1, 2], )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save sph format."""
        self.assert_sphere(
            dtype=dtype, sample_rate=sample_rate, num_channels=num_channels)

    @parameterize(
        [8000, 16000],
        [1, 2], )
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """`soundfile_backend.save` can save flac format."""
        self.assert_flac(
            dtype="float32", sample_rate=sample_rate, num_channels=num_channels)

    @parameterize(
        [8000, 16000],
        [1, 2], )
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """`soundfile_backend.save` can save ogg/vorbis format."""
        self.assert_ogg(
            dtype="float32", sample_rate=sample_rate, num_channels=num_channels)
class TestSaveParams(TempDirMixin, unittest.TestCase):
    """Test the correctness of optional parameters of `soundfile_backend.save`"""

    @parameterize([True, False])
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        out_path = self.get_temp_path("data.wav")
        src = get_wav_data("int32", 2, channels_first=channels_first)
        soundfile_backend.save(
            out_path, src, 8000, channels_first=channels_first)
        loaded = load_wav(out_path)[0]
        # The reference must be in the layout `load_wav` returns; the source
        # is transposed only when it was generated with channels_first=False.
        if channels_first:
            reference = src
        else:
            reference = src.transpose([1, 0])
        np.testing.assert_array_almost_equal(loaded.numpy(), reference.numpy())
class TestFileObject(TempDirMixin, unittest.TestCase):
    """Saving through `soundfile_backend.save` into an in-memory file object."""

    def _test_fileobj(self, ext):
        """Saving audio to file-like object works"""
        rate = 16000
        ref_path = self.get_temp_path(f"test.{ext}")
        if ext == "wav":
            subtype = "FLOAT"
        else:
            subtype = None
        waveform = get_wav_data("float32", num_channels=2)

        # Reference: the same samples written directly by soundfile, with
        # axes transposed (presumably get_wav_data yields channels first).
        soundfile.write(ref_path, waveform.numpy().T, rate, subtype=subtype)
        expected = soundfile.read(ref_path, dtype="float32")[0]

        # Candidate: the backend writing into a BytesIO, rewound for reading.
        buffer = io.BytesIO()
        soundfile_backend.save(buffer, waveform, rate, format=ext)
        buffer.seek(0)
        found, found_rate = soundfile.read(buffer, dtype="float32")

        assert found_rate == rate
        np.testing.assert_array_almost_equal(found, expected)

    def test_fileobj_wav(self):
        """Saving audio via file-like object works"""
        self._test_fileobj("wav")

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Saving audio via file-like object works"""
        self._test_fileobj("flac")

    @skipIfFormatNotSupported("NIST")
    def test_fileobj_nist(self):
        """Saving audio via file-like object works"""
        self._test_fileobj("NIST")

    @skipIfFormatNotSupported("OGG")
    def test_fileobj_ogg(self):
        """Saving audio via file-like object works"""
        self._test_fileobj("OGG")
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()