Merge pull request #2257 from SmileGoat/add_pitch2

[audio]add soundfile backend
pull/2382/head
Hui Zhang 3 years ago committed by GitHub
commit dc6f8ff10c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -53,11 +53,11 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})
include(openblas)
# include(pybind)
include(pybind)
# packages
find_package(Python3 COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
#find_package(pybind11 CONFIG REQUIRED)
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -O0 -Wall -g")

@ -11,25 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import resampy
import soundfile as sf
import soundfile
from scipy.io import wavfile
from ..utils import depth_convert
from ..utils import ParameterError
from .common import AudioMetaData
__all__ = [
'resample',
'to_mono',
'normalize',
'save',
'soudfile_save',
'load',
'soundfile_load',
'info'
]
NORMALMIZE_TYPES = ['linear', 'gaussian']
MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
@ -116,7 +122,7 @@ def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
return y_out
def sound_file_load(file: os.PathLike,
def soundfile_load(file: os.PathLike,
offset: Optional[float]=None,
dtype: str='int16',
duration: Optional[int]=None) -> Tuple[np.ndarray, int]:
@ -131,7 +137,7 @@ def sound_file_load(file: os.PathLike,
Returns:
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
"""
with sf.SoundFile(file) as sf_desc:
with soundfile.SoundFile(file) as sf_desc:
sr_native = sf_desc.samplerate
if offset:
sf_desc.seek(int(offset * sr_native))
@ -172,7 +178,7 @@ def normalize(y: np.ndarray, norm_type: str='linear',
return y
def save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
"""Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16.
Args:
@ -198,8 +204,7 @@ def save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
wavfile.write(file, sr, y_out)
def load(
def soudfile_load(
file: os.PathLike,
sr: Optional[int]=None,
mono: bool=True,
@ -251,6 +256,406 @@ def load(
y = depth_convert(y, dtype)
return y, r
#the code below is form: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py
def _get_subtype_for_wav(dtype: paddle.dtype, encoding: str, bits_per_sample: int):
if not encoding:
if not bits_per_sample:
subtype = {
paddle.uint8: "PCM_U8",
paddle.int16: "PCM_16",
paddle.int32: "PCM_32",
paddle.float32: "FLOAT",
paddle.float64: "DOUBLE",
}.get(dtype)
if not subtype:
raise ValueError(f"Unsupported dtype for wav: {dtype}")
return subtype
if bits_per_sample == 8:
return "PCM_U8"
return f"PCM_{bits_per_sample}"
if encoding == "PCM_S":
if not bits_per_sample:
return "PCM_32"
if bits_per_sample == 8:
raise ValueError("wav does not support 8-bit signed PCM encoding.")
return f"PCM_{bits_per_sample}"
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "PCM_U8"
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
if encoding == "PCM_F":
if bits_per_sample in (None, 32):
return "FLOAT"
if bits_per_sample == 64:
return "DOUBLE"
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("wav only supports 8-bit mu-law encoding.")
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "ALAW"
raise ValueError("wav only supports 8-bit a-law encoding.")
raise ValueError(f"wav does not support {encoding}.")
def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
if encoding in (None, "PCM_S"):
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
if encoding in ("PCM_U", "PCM_F"):
raise ValueError(f"sph does not support {encoding} encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("sph only supports 8-bit for mu-law encoding.")
if encoding == "ALAW":
return "ALAW"
raise ValueError(f"sph does not support {encoding}.")
def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sample: int):
if format == "wav":
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
if encoding:
raise ValueError("flac does not support encoding.")
if not bits_per_sample:
return "PCM_16"
if bits_per_sample > 24:
raise ValueError("flac does not support bits_per_sample > 24.")
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
return "VORBIS"
if format == "sph":
return _get_subtype_for_sphere(encoding, bits_per_sample)
if format in ("nis", "nist"):
return "PCM_16"
raise ValueError(f"Unsupported format: {format}")
def save(
filepath: str,
src: paddle.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.
Note:
The formats this function can handle depend on the soundfile installation.
This function is tested on the following formats;
* WAV
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* FLAC
* OGG/VORBIS
* SPHERE
Note:
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
Args:
filepath (str or pathlib.Path): Path to audio file.
src (paddle.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
otherwise `[time, channel]`.
compression (float of None, optional): Not used.
It is here only for interface compatibility reson with "sox_io" backend.
format (str or None, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.
When ``filepath`` argument is file-like object,
this argument is required.
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str or None, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, sush as
``"wav"``, ``""flac"`` and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
bits_per_sample (int or None, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
Supported formats/encodings/bit depth/compression are:
``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note:
Default encoding/bit depth is determined by the dtype of
the input Tensor.
``"flac"``
- 8-bit
- 16-bit (default)
- 24-bit
``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
if compression is not None:
warnings.warn(
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
)
if hasattr(filepath, "write"):
if format is None:
raise RuntimeError("`format` is required when saving to file object.")
ext = format.lower()
else:
ext = str(filepath).split(".")[-1].lower()
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
raise ValueError("Invalid bits_per_sample.")
if bits_per_sample == 24:
warnings.warn(
"Saving audio with 24 bits per sample might warp samples near -1. "
"Using 16 bits per sample might be able to avoid this."
)
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
if ext in ["nis", "nist", "sph"] and format is None:
format = "NIST"
if channels_first:
src = src.t()
soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)
def info(filepath: str) -> None:
raise RuntimeError("No audio I/O backend is available.")
_SUBTYPE2DTYPE = {
"PCM_S8": "int8",
"PCM_U8": "uint8",
"PCM_16": "int16",
"PCM_32": "int32",
"FLOAT": "float32",
"DOUBLE": "float64",
}
def load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[paddle.Tensor, int]:
"""Load audio data from file.
Note:
The formats this function can handle depend on the soundfile installation.
This function is tested on the following formats;
* WAV
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* FLAC
* OGG/VORBIS
* SPHERE
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
``float32`` dtype and the shape of `[channel, time]`.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
by providing ``normalize=False``, this function can return integer Tensor, where the samples
are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
``flac`` and ``mp3``.
For these formats, this function always returns ``float32`` Tensor with values normalized to
``[-1.0, 1.0]``.
Note:
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend.
Args:
filepath (path-like object or file-like object):
Source of audio data.
frame_offset (int, optional):
Number of frames to skip before start reading data.
num_frames (int, optional):
Maximum number of frames to read. ``-1`` reads all the remaining samples,
starting from ``frame_offset``.
This function may return the less number of frames if there is not enough
frames in the given file.
normalize (bool, optional):
When ``True``, this function always return ``float32``, and sample values are
normalized to ``[-1.0, 1.0]``.
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
integer type.
This argument has no effect for formats other than integer WAV type.
channels_first (bool, optional):
When True, the returned Tensor has dimension `[channel, time]`.
Otherwise, the returned Tensor's dimension is `[time, channel]`.
format (str or None, optional):
Not used. PySoundFile does not accept format hint.
Returns:
(paddle.Tensor, int): Resulting Tensor and sample rate.
If the input file has integer wav format and normalization is off, then it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
`[channel, time]` else `[time, channel]`.
"""
with soundfile.SoundFile(filepath, "r") as file_:
if file_.format != "WAV" or normalize:
dtype = "float32"
elif file_.subtype not in _SUBTYPE2DTYPE:
raise ValueError(f"Unsupported subtype: {file_.subtype}")
else:
dtype = _SUBTYPE2DTYPE[file_.subtype]
frames = file_._prepare_read(frame_offset, None, num_frames)
waveform = file_.read(frames, dtype, always_2d=True)
sample_rate = file_.samplerate
waveform = paddle.to_tensor(waveform)
if channels_first:
waveform = paddle.transpose(waveform, perm=[1,0])
return waveform, sample_rate
# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristical and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired from
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
_SUBTYPE_TO_BITS_PER_SAMPLE = {
"PCM_S8": 8, # Signed 8 bit data
"PCM_16": 16, # Signed 16 bit data
"PCM_24": 24, # Signed 24 bit data
"PCM_32": 32, # Signed 32 bit data
"PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
"FLOAT": 32, # 32 bit float data
"DOUBLE": 64, # 64 bit float data
"ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"IMA_ADPCM": 0, # IMA ADPCM.
"MS_ADPCM": 0, # Microsoft ADPCM.
"GSM610": 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
"VOX_ADPCM": 0, # OKI / Dialogix ADPCM
"G721_32": 0, # 32kbs G721 ADPCM encoding.
"G723_24": 0, # 24kbs G723 ADPCM encoding.
"G723_40": 0, # 40kbs G723 ADPCM encoding.
"DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
"DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
"DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
"DWVW_N": 0, # N bit Delta Width Variable Word encoding.
"DPCM_8": 8, # 8 bit differential PCM (XI only)
"DPCM_16": 16, # 16 bit differential PCM (XI only)
"VORBIS": 0, # Xiph Vorbis encoding. (lossy)
"ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
"ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
"ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
"ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
}
def _get_bit_depth(subtype):
if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
warnings.warn(
f"The {subtype} subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
_SUBTYPE_TO_ENCODING = {
"PCM_S8": "PCM_S",
"PCM_16": "PCM_S",
"PCM_24": "PCM_S",
"PCM_32": "PCM_S",
"PCM_U8": "PCM_U",
"FLOAT": "PCM_F",
"DOUBLE": "PCM_F",
"ULAW": "ULAW",
"ALAW": "ALAW",
"VORBIS": "VORBIS",
}
def _get_encoding(format: str, subtype: str):
if format == "FLAC":
return "FLAC"
return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Note:
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
Args:
filepath (path-like object or file-like object):
Source of audio data.
format (str or None, optional):
Not used. PySoundFile does not accept format hint.
Returns:
AudioMetaData: meta data of the given audio.
"""
sinfo = soundfile.info(filepath)
return AudioMetaData(
sinfo.samplerate,
sinfo.frames,
sinfo.channels,
bits_per_sample=_get_bit_depth(sinfo.subtype),
encoding=_get_encoding(sinfo.format, sinfo.subtype),
)

@ -0,0 +1,57 @@
import itertools
from unittest import skipIf
from parameterized import parameterized
from paddlespeech.audio._internal.module_utils import is_module_available
def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
def dtype2subtype(dtype):
return {
"float64": "DOUBLE",
"float32": "FLOAT",
"int32": "PCM_32",
"int16": "PCM_16",
"uint8": "PCM_U8",
"int8": "PCM_S8",
}[dtype]
def skipIfFormatNotSupported(fmt):
fmts = []
if is_module_available("soundfile"):
import soundfile
fmts = soundfile.available_formats()
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile')
return skipIf(True, '"soundfile" not available.')
def parameterize(*params):
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
subtype = {
(None, None): dtype2subtype(dtype),
(None, 8): "PCM_U8",
("PCM_U", None): "PCM_U8",
("PCM_U", 8): "PCM_U8",
("PCM_S", None): "PCM_32",
("PCM_S", 16): "PCM_16",
("PCM_S", 32): "PCM_32",
("PCM_F", None): "FLOAT",
("PCM_F", 32): "FLOAT",
("PCM_F", 64): "DOUBLE",
("ULAW", None): "ULAW",
("ULAW", 8): "ULAW",
("ALAW", None): "ALAW",
("ALAW", 8): "ALAW",
}.get((encoding, bits_per_sample))
if subtype:
return subtype
raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")

@ -0,0 +1,199 @@
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py
import tarfile
import warnings
import unittest
from unittest.mock import patch
import paddle
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio.backends import soundfile_backend
from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding
from tests.unit.common_utils import (
get_wav_data,
nested_params,
save_wav,
TempDirMixin,
)
from common import parameterize, skipIfFormatNotSupported
import soundfile
class TestInfo(TempDirMixin, unittest.TestCase):
@parameterize(
["float32", "int32"],
[8000, 16000],
[1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`soundfile_backend.info` can check flac file correctly"""
duration = 1
num_frames = sample_rate * duration
#data = torch.randn(num_frames, num_channels).numpy()
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
path = self.get_temp_path("data.flac")
soundfile.write(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == 16
assert info.encoding == "FLAC"
#@parameterize([8000, 16000], [1, 2])
#@skipIfFormatNotSupported("OGG")
#def test_ogg(self, sample_rate, num_channels):
#"""`soundfile_backend.info` can check ogg file correctly"""
#duration = 1
#num_frames = sample_rate * duration
##data = torch.randn(num_frames, num_channels).numpy()
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
#print(len(data))
#path = self.get_temp_path("data.ogg")
#soundfile.write(path, data, sample_rate)
#info = soundfile_backend.info(path)
#print(info)
#assert info.sample_rate == sample_rate
#print("info")
#print(info.num_frames)
#print("jiji")
#print(sample_rate*duration)
##assert info.num_frames == sample_rate * duration
#assert info.num_channels == num_channels
#assert info.bits_per_sample == 0
#assert info.encoding == "VORBIS"
@nested_params(
[8000, 16000],
[1, 2],
[("PCM_24", 24), ("PCM_32", 32)],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`soundfile_backend.info` can check sph file correctly"""
duration = 1
num_frames = sample_rate * duration
#data = torch.randn(num_frames, num_channels).numpy()
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
path = self.get_temp_path("data.nist")
subtype, bits_per_sample = subtype_and_bit_depth
soundfile.write(path, data, sample_rate, subtype=subtype)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
def test_unknown_subtype_warning(self):
"""soundfile_backend.info issues a warning when the subtype is unknown
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated.
"""
def _mock_info_func(_):
class MockSoundFileInfo:
samplerate = 8000
frames = 356
channels = 2
subtype = "UNSEEN_SUBTYPE"
format = "UNKNOWN"
return MockSoundFileInfo()
with patch("soundfile.info", _mock_info_func):
with warnings.catch_warnings(record=True) as w:
info = soundfile_backend.info("foo")
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(w[-1].message)
assert info.bits_per_sample == 0
class TestFileObject(TempDirMixin, unittest.TestCase):
def _test_fileobj(self, ext, subtype, bits_per_sample):
"""Query audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
path = self.get_temp_path(f"test.{ext}")
#data = torch.randn(num_frames, num_channels).numpy()
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
soundfile.write(path, data, sample_rate, subtype=subtype)
with open(path, "rb") as fileobj:
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj("flac", "PCM_16", 16)
def _test_tarobj(self, ext, subtype, bits_per_sample):
"""Query compressed audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path("archive.tar.gz")
#data = torch.randn(num_frames, num_channels).numpy()
data = paddle.randn(shape=[num_frames, num_channels]).numpy()
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_tarobj_wav(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC")
def test_tarobj_flac(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj("flac", "PCM_16", 16)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,369 @@
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py
import os
import tarfile
import unittest
from unittest.mock import patch
import numpy as np
from parameterized import parameterized
import paddle
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio.backends import soundfile_backend
from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding
from tests.unit.common_utils import (
get_wav_data,
load_wav,
nested_params,
normalize_wav,
save_wav,
TempDirMixin,
)
from common import dtype2subtype, parameterize, skipIfFormatNotSupported
import soundfile
def _get_mock_path(
ext: str,
dtype: str,
sample_rate: int,
num_channels: int,
num_frames: int,
):
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
def _get_mock_params(path: str):
filename, ext = path.split(".")
parts = filename.split("_")
return {
"ext": ext,
"dtype": parts[0],
"sample_rate": int(parts[1]),
"num_channels": int(parts[2]),
"num_frames": int(parts[3]),
}
class SoundFileMock:
def __init__(self, path, mode):
assert mode == "r"
self.path = path
self._params = _get_mock_params(path)
self._start = None
@property
def samplerate(self):
return self._params["sample_rate"]
@property
def format(self):
if self._params["ext"] == "wav":
return "WAV"
if self._params["ext"] == "flac":
return "FLAC"
if self._params["ext"] == "ogg":
return "OGG"
if self._params["ext"] in ["sph", "nis", "nist"]:
return "NIST"
@property
def subtype(self):
if self._params["ext"] == "ogg":
return "VORBIS"
return dtype2subtype(self._params["dtype"])
def _prepare_read(self, start, stop, frames):
assert stop is None
self._start = start
return frames
def read(self, frames, dtype, always_2d):
assert always_2d
data = get_wav_data(
dtype,
self._params["num_channels"],
normalize=False,
num_frames=self._params["num_frames"],
channels_first=False,
).numpy()
return data[self._start : self._start + frames]
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
pass
class MockedLoadTest(unittest.TestCase):
def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames = 3 * sample_rate
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
expected_dtype = paddle.float32 if normalize or ext not in ["wav", "nist"] else getattr(paddle, dtype)
with patch("soundfile.SoundFile", SoundFileMock):
found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
assert found.dtype == expected_dtype
assert sample_rate == sr
@parameterize(
["int32", "float32", "float64"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns native dtype when normalize=False else float32"""
self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_flac(self, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load ogg format."""
self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
class LoadTestBase(TempDirMixin, unittest.TestCase):
def assert_wav(
self,
dtype,
sample_rate,
num_channels,
normalize,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load wav format correctly.
Wav data loaded with soundfile backend should match those with scipy
"""
path = self.get_temp_path("reference.wav")
num_frames = duration * sample_rate
data = get_wav_data(
dtype,
num_channels,
normalize=normalize,
num_frames=num_frames,
channels_first=channels_first,
)
save_wav(path, data, sample_rate, channels_first=channels_first)
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
assert sr == sample_rate
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
def assert_sphere(
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load SPHERE format correctly."""
path = self.get_temp_path("reference.sph")
num_frames = duration * sample_rate
raw = get_wav_data(
dtype,
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
def assert_flac(
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load FLAC format correctly."""
path = self.get_temp_path("reference.flac")
num_frames = duration * sample_rate
raw = get_wav_data(
dtype,
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(path, raw, sample_rate)
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
class TestLoad(LoadTestBase):
"""Test the correctness of `soundfile_backend.load` for various formats"""
@parameterize(
["float32", "int32"],
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int32"],
[16000],
[2],
[False],
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`soundfile_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours)
@parameterize(["float32", "int32"], [4, 8, 16, 32], [False, True])
def test_multiple_channels(self, dtype, num_channels, channels_first):
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@skipIfFormatNotSupported("NIST")
#def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
#"""`soundfile_backend.load` can load sphere format correctly."""
#self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@skipIfFormatNotSupported("FLAC")
#def test_flac(self, dtype, sample_rate, num_channels, channels_first):
#"""`soundfile_backend.load` can load flac format correctly."""
#self.assert_flac(dtype, sample_rate, num_channels, channels_first)
class TestLoadFormat(TempDirMixin, unittest.TestCase):
"""Given `format` parameter, `so.load` can load files without extension"""
original = None
path = None
def _make_file(self, format_):
sample_rate = 8000
path_with_ext = self.get_temp_path(f"test.{format_}")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path_with_ext, data, sample_rate)
expected = soundfile.read(path_with_ext, dtype="float32")[0].T
path = os.path.splitext(path_with_ext)[0]
os.rename(path_with_ext, path)
return path, expected
def _test_format(self, format_):
"""Providing format allows to read file without extension"""
path, expected = self._make_file(format_)
found, _ = soundfile_backend.load(path)
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(found, expected)
@parameterized.expand(
[
("WAV",),
("wav",),
]
)
def test_wav(self, format_):
self._test_format(format_)
@parameterized.expand(
[
("FLAC",),
("flac",),
]
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)
class TestFileObject(TempDirMixin, unittest.TestCase):
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f"test.{ext}")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype="float32")[0].T
with open(path, "rb") as fileobj:
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
#self.assertEqual(expected, found)
np.testing.assert_array_almost_equal(found, expected)
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj("wav")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj("flac")
def _test_tarfile(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path("archive.tar.gz")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype="float32")[0].T
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
#self.assertEqual(expected, found)
np.testing.assert_array_almost_equal(found.numpy(), expected)
def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile("wav")
def test_tarfile_flac(self):
"""Loading audio via file-like object works"""
self._test_tarfile("flac")
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,322 @@
import io
import unittest
from unittest.mock import patch
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio.backends import soundfile_backend
from tests.unit.common_utils import (
get_wav_data,
load_wav,
nested_params,
normalize_wav,
save_wav,
TempDirMixin,
)
from common import fetch_wav_subtype, parameterize, skipIfFormatNotSupported
import paddle
import numpy as np
import soundfile
class MockedSaveTest(unittest.TestCase):
@nested_params(
["float32", "int32"],
[8000, 16000],
[1, 2],
[False, True],
[
(None, None),
("PCM_U", None),
("PCM_U", 8),
("PCM_S", None),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", None),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", None),
("ULAW", 8),
("ALAW", None),
("ALAW", 8),
],
)
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=dtype == "float32",
channels_first=channels_first,
)
input_tensor = paddle.transpose(input_tensor, [1, 0])
encoding, bits_per_sample = enc_params
soundfile_backend.save(
filepath,
input_tensor,
sample_rate,
channels_first=channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
assert args["format"] is None
tensor_result = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor
#self.assertEqual(args["data"], tensor_result.numpy())
np.testing.assert_array_almost_equal(args["data"].numpy(), tensor_result.numpy())
@patch("soundfile.write")
def assert_non_wav(
self,
fmt,
dtype,
sample_rate,
num_channels,
channels_first,
mocked_write,
encoding=None,
bits_per_sample=None,
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=False,
channels_first=channels_first,
)
input_tensor = paddle.transpose(input_tensor, [1, 0])
expected_data = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor
soundfile_backend.save(
filepath,
input_tensor,
sample_rate,
channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
if fmt in ["sph", "nist", "nis"]:
assert args["format"] == "NIST"
else:
assert args["format"] is None
np.testing.assert_array_almost_equal(args["data"].numpy(), expected_data.numpy())
#self.assertEqual(args["data"], expected_data)
@nested_params(
["sph", "nist", "nis"],
["int32"],
[8000, 16000],
[1, 2],
[False, True],
[
("PCM_S", 8),
("PCM_S", 16),
("PCM_S", 24),
("PCM_S", 32),
("ULAW", 8),
("ALAW", 8),
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
encoding, bits_per_sample = enc_params
self.assert_non_wav(
fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample
)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
[False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
[False, True],
)
def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
class SaveTestBase(TempDirMixin, unittest.TestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`soundfile_backend.save` can save wav format."""
path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
soundfile_backend.save(path, expected, sample_rate)
found, sr = load_wav(path, normalize=False)
assert sample_rate == sr
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save non-wav format.
Due to precision missmatch, and the lack of alternative way to decode the
resulting files without using soundfile, only meta data are validated.
"""
num_frames = sample_rate * 3
path = self.get_temp_path(f"data.{fmt}")
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
soundfile_backend.save(path, expected, sample_rate)
sinfo = soundfile.info(path)
assert sinfo.format == fmt.upper()
#assert sinfo.frames == num_frames this go wrong
assert sinfo.channels == num_channels
assert sinfo.samplerate == sample_rate
def assert_flac(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save flac format."""
self._assert_non_wav("flac", dtype, sample_rate, num_channels)
def assert_sphere(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save sph format."""
self._assert_non_wav("nist", dtype, sample_rate, num_channels)
def assert_ogg(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save ogg format.
As we cannot inspect the OGG format (it's lossy), we only check the metadata.
"""
self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
class TestSave(SaveTestBase):
@parameterize(
["float32", "int32"],
[8000, 16000],
[1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["float32", "int32"],
[4, 8, 16, 32],
)
def test_multiple_channels(self, dtype, num_channels):
"""`soundfile_backend.save` can save wav with more than 2 channels."""
sample_rate = 8000
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save sph format."""
self.assert_sphere(dtype, sample_rate, num_channels)
@parameterize(
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`soundfile_backend.save` can save flac format."""
self.assert_flac("float32", sample_rate, num_channels)
@parameterize(
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
"""`soundfile_backend.save` can save ogg/vorbis format."""
self.assert_ogg("float32", sample_rate, num_channels)
class TestSaveParams(TempDirMixin, unittest.TestCase):
"""Test the correctness of optional parameters of `soundfile_backend.save`"""
@parameterize([True, False])
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path("data.wav")
data = get_wav_data("int32", 2, channels_first=channels_first)
soundfile_backend.save(path, data, 8000, channels_first=channels_first)
found = load_wav(path)[0]
expected = data if channels_first else data.transpose([1, 0])
#self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
class TestFileObject(TempDirMixin, unittest.TestCase):
def _test_fileobj(self, ext):
"""Saving audio to file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f"test.{ext}")
subtype = "FLOAT" if ext == "wav" else None
data = get_wav_data("float32", num_channels=2)
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
expected = soundfile.read(path, dtype="float32")[0]
fileobj = io.BytesIO()
soundfile_backend.save(fileobj, data, sample_rate, format=ext)
fileobj.seek(0)
found, sr = soundfile.read(fileobj, dtype="float32")
assert sr == sample_rate
#self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
np.testing.assert_array_almost_equal(found, expected)
def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Saving audio via file-like object works"""
self._test_fileobj("flac")
@skipIfFormatNotSupported("NIST")
def test_fileobj_nist(self):
"""Saving audio via file-like object works"""
self._test_fileobj("NIST")
@skipIfFormatNotSupported("OGG")
def test_fileobj_ogg(self):
"""Saving audio via file-like object works"""
self._test_fileobj("OGG")
if __name__ == '__main__':
unittest.main()

@ -9,6 +9,7 @@ import os
import io
from parameterized import parameterized
from tests.unit.audio.backends.common import get_bits_per_sample, get_encoding
from paddlespeech.audio.backends import sox_io_backend
from tests.unit.common_utils import (
@ -20,8 +21,6 @@ from tests.unit.common_utils import (
data_utils
)
from common import get_encoding, get_bits_per_sample
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/info_test.py
class TestInfo(TempDirMixin, unittest.TestCase):
@ -287,4 +286,4 @@ class TestFileObject(FileObjTestBase, unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main()

Loading…
Cancel
Save