Merge pull request #2195 from SmileGoat/add_pitch2

[audio] add sox effects, load audio, save audio
pull/2382/head
Hui Zhang 3 years ago committed by GitHub
commit 4dc5f25738
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,8 +3,8 @@ include(ExternalProject)
FetchContent_Declare( FetchContent_Declare(
pybind pybind
URL https://github.com/pybind/pybind11/archive/refs/tags/v2.9.0.zip URL https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.zip
URL_HASH SHA256=1c6e0141f7092867c5bf388bc3acdb2689ed49f59c3977651394c6c87ae88232 URL_HASH SHA256=225df6e6dea7cea7c5754d4ed954e9ca7c43947b849b3795f87cb56437f1bd19
) )
FetchContent_MakeAvailable(pybind) FetchContent_MakeAvailable(pybind)
include_directories(${pybind_SOURCE_DIR}/include) include_directories(${pybind_SOURCE_DIR}/include)

@ -103,7 +103,6 @@ def _load_lib(lib: str) -> bool:
If a dependency is missing, then users have to install it. If a dependency is missing, then users have to install it.
""" """
path = _get_lib_path(lib) path = _get_lib_path(lib)
warnings.warn("lib path is :" + str(path))
if not path.exists(): if not path.exists():
warnings.warn("lib path is not exists:" + str(path)) warnings.warn("lib path is not exists:" + str(path))
return False return False

@ -145,4 +145,4 @@ def requires_sox():
return wrapped return wrapped
return return decorator

@ -1,15 +1,14 @@
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
from typing import Optional from typing import Optional, Tuple, Union
from typing import Tuple
from typing import Union
import paddle
from paddle import Tensor from paddle import Tensor
from .common import AudioMetaData from .common import AudioMetaData
import os
from paddlespeech.audio._internal import module_utils as _mod_utils from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio._paddleaudio import get_info_file from paddlespeech.audio import _paddleaudio as paddleaudio
from paddlespeech.audio._paddleaudio import get_info_fileobj
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py #https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py
@ -29,7 +28,7 @@ def _fail_load(
normalize: bool = True, normalize: bool = True,
channels_first: bool = True, channels_first: bool = True,
format: Optional[str] = None, format: Optional[str] = None,
) -> Tuple[paddle.Tensor, int]: ) -> Tuple[Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath)) raise RuntimeError("Failed to load audio from {}".format(filepath))
@ -41,26 +40,62 @@ _fallback_info_fileobj = _fail_info_fileobj
_fallback_load = _fail_load _fallback_load = _fail_load
_fallback_load_filebj = _fail_load_fileobj _fallback_load_filebj = _fail_load_fileobj
@_mod_utils.requires_sox()
def load( def load(
filepath: Union[str, Path], filepath: str,
out: Optional[Tensor]=None, frame_offset: int = 0,
normalization: Union[bool, float, Callable]=True, num_frames: int=-1,
channels_first: bool=True, normalize: bool = True,
num_frames: int=0, channels_first: bool = True,
offset: int=0, format: Optional[str]=None, ) -> Tuple[Tensor, int]:
filetype: Optional[str]=None, ) -> Tuple[Tensor, int]: if hasattr(filepath, "read"):
raise RuntimeError("No audio I/O backend is available.") ret = paddleaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = paddleaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_sox()
def save(filepath: str, def save(filepath: str,
src: Tensor, src: Tensor,
sample_rate: int, sample_rate: int,
precision: int = 16, channels_first: bool = True,
channels_first: bool = True) -> None: compression: Optional[float] = None,
raise RuntimeError("No audio I/O backend is available.") format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
src_arr = src.numpy()
if hasattr(filepath, "write"):
paddleaudio.save_audio_fileobj(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
return
filepath = os.fspath(filepath)
paddleaudio.sox_io_save_audio_file(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
@_mod_utils.requires_sox() @_mod_utils.requires_sox()
def info(filepath: str, format: Optional[str]) -> None: def info(filepath: str, format: Optional[str] = None,) -> AudioMetaData:
sinfo = paddleaudio._paddleaudio.get_info_file(filepath, format) if hasattr(filepath, "read"):
sinfo = paddleaudio.get_info_fileobj(filepath, format)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath)
sinfo = paddleaudio.get_info_file(filepath, format)
if sinfo is not None: if sinfo is not None:
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
return _fallback_info(filepath, format) return _fallback_info(filepath, format)

@ -0,0 +1,25 @@
from paddlespeech.audio._internal import module_utils as _mod_utils
from .sox_effects import (
apply_effects_file,
apply_effects_tensor,
effect_names,
init_sox_effects,
shutdown_sox_effects,
)
if _mod_utils.is_sox_available():
import atexit
init_sox_effects()
atexit.register(shutdown_sox_effects)
__all__ = [
"init_sox_effects",
"shutdown_sox_effects",
"effect_names",
"apply_effects_tensor",
"apply_effects_file",
]

@ -0,0 +1,238 @@
import os
from typing import List, Optional, Tuple
import paddle
import numpy
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio.utils.sox_utils import list_effects
from paddlespeech.audio import _paddleaudio as paddleaudio
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/sox_effects/sox_effects.py
@_mod_utils.requires_sox()
def init_sox_effects():
"""Initialize resources required to use sox effects.
Note:
You do not need to call this function manually. It is called automatically.
Once initialized, you do not need to call this function again across the multiple uses of
sox effects though it is safe to do so as long as :func:`shutdown_sox_effects` is not called yet.
Once :func:`shutdown_sox_effects` is called, you can no longer use SoX effects and initializing
again will result in error.
"""
paddleaudio.sox_effects_initialize_sox_effects()
@_mod_utils.requires_sox()
def shutdown_sox_effects():
"""Clean up resources required to use sox effects.
Note:
You do not need to call this function manually. It is called automatically.
It is safe to call this function multiple times.
Once :py:func:`shutdown_sox_effects` is called, you can no longer use SoX effects and
initializing again will result in error.
"""
paddleaudio.sox_effects_shutdown_sox_effects()
@_mod_utils.requires_sox()
def effect_names() -> List[str]:
"""Gets list of valid sox effect names
Returns:
List[str]: list of available effect names.
Example
>>> paddleaudio.sox_effects.effect_names()
['allpass', 'band', 'bandpass', ... ]
"""
return list(list_effects().keys())
@_mod_utils.requires_sox()
def apply_effects_tensor(
tensor: paddle.Tensor,
sample_rate: int,
effects: List[List[str]],
channels_first: bool = True,
) -> Tuple[paddle.Tensor, int]:
"""Apply sox effects to given Tensor
.. devices:: CPU
Note:
This function only works on CPU Tensors.
This function works in the way very similar to ``sox`` command, however there are slight
differences. For example, ``sox`` command adds certain effects automatically (such as
``rate`` effect after ``speed`` and ``pitch`` and other effects), but this function does
only applies the given effects. (Therefore, to actually apply ``speed`` effect, you also
need to give ``rate`` effect with desired sampling rate.).
Args:
tensor (paddle.Tensor): Input 2D CPU Tensor.
sample_rate (int): Sample rate
effects (List[List[str]]): List of effects.
channels_first (bool, optional): Indicates if the input Tensor's dimension is
`[channels, time]` or `[time, channels]`
Returns:
(Tensor, int): Resulting Tensor and sample rate.
The resulting Tensor has the same ``dtype`` as the input Tensor, and
the same channels order. The shape of the Tensor can be different based on the
effects applied. Sample rate can also be different based on the effects applied.
Example - Basic usage
>>>
>>> # Defines the effects to apply
>>> effects = [
... ['gain', '-n'], # normalises to 0dB
... ['pitch', '5'], # 5 cent pitch shift
... ['rate', '8000'], # resample to 8000 Hz
... ]
>>>
>>> # Generate pseudo wave:
>>> # normalized, channels first, 2ch, sampling rate 16000, 1 second
>>> sample_rate = 16000
>>> waveform = 2 * paddle.rand([2, sample_rate * 1]) - 1
>>> waveform.shape
paddle.Size([2, 16000])
>>> waveform
tensor([[ 0.3138, 0.7620, -0.9019, ..., -0.7495, -0.4935, 0.5442],
[-0.0832, 0.0061, 0.8233, ..., -0.5176, -0.9140, -0.2434]])
>>>
>>> # Apply effects
>>> waveform, sample_rate = apply_effects_tensor(
... wave_form, sample_rate, effects, channels_first=True)
>>>
>>> # Check the result
>>> # The new waveform is sampling rate 8000, 1 second.
>>> # normalization and channel order are preserved
>>> waveform.shape
paddle.Size([2, 8000])
>>> waveform
tensor([[ 0.5054, -0.5518, -0.4800, ..., -0.0076, 0.0096, -0.0110],
[ 0.1331, 0.0436, -0.3783, ..., -0.0035, 0.0012, 0.0008]])
>>> sample_rate
8000
"""
tensor_np = tensor.numpy()
ret = paddleaudio.sox_effects_apply_effects_tensor(tensor_np, sample_rate, effects, channels_first)
if ret is not None:
return (paddle.to_tensor(ret[0]), ret[1])
raise RuntimeError("Failed to apply sox effect")
@_mod_utils.requires_sox()
def apply_effects_file(
path: str,
effects: List[List[str]],
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[paddle.Tensor, int]:
"""Apply sox effects to the audio file and load the resulting data as Tensor
Note:
This function works in the way very similar to ``sox`` command, however there are slight
differences. For example, ``sox`` commnad adds certain effects automatically (such as
``rate`` effect after ``speed``, ``pitch`` etc), but this function only applies the given
effects. Therefore, to actually apply ``speed`` effect, you also need to give ``rate``
effect with desired sampling rate, because internally, ``speed`` effects only alter sampling
rate and leave samples untouched.
Args:
path (path-like object or file-like object):
effects (List[List[str]]): List of effects.
normalize (bool, optional):
When ``True``, this function always return ``float32``, and sample values are
normalized to ``[-1.0, 1.0]``.
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
integer type. This argument has no effect for formats other
than integer WAV type.
channels_first (bool, optional): When True, the returned Tensor has dimension `[channel, time]`.
Otherwise, the returned Tensor's dimension is `[time, channel]`.
format (str or None, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
from header or extension,
Returns:
(Tensor, int): Resulting Tensor and sample rate.
If ``normalize=True``, the resulting Tensor is always ``float32`` type.
If ``normalize=False`` and the input audio file is of integer WAV file, then the
resulting Tensor has corresponding integer type. (Note 24 bit integer type is not supported)
If ``channels_first=True``, the resulting Tensor has dimension `[channel, time]`,
otherwise `[time, channel]`.
Example - Basic usage
>>>
>>> # Defines the effects to apply
>>> effects = [
... ['gain', '-n'], # normalises to 0dB
... ['pitch', '5'], # 5 cent pitch shift
... ['rate', '8000'], # resample to 8000 Hz
... ]
>>>
>>> # Apply effects and load data with channels_first=True
>>> waveform, sample_rate = apply_effects_file("data.wav", effects, channels_first=True)
>>>
>>> # Check the result
>>> waveform.shape
paddle.Size([2, 8000])
>>> waveform
tensor([[ 5.1151e-03, 1.8073e-02, 2.2188e-02, ..., 1.0431e-07,
-1.4761e-07, 1.8114e-07],
[-2.6924e-03, 2.1860e-03, 1.0650e-02, ..., 6.4122e-07,
-5.6159e-07, 4.8103e-07]])
>>> sample_rate
8000
Example - Apply random speed perturbation to dataset
>>>
>>> # Load data from file, apply random speed perturbation
>>> class RandomPerturbationFile(paddle.utils.data.Dataset):
... \"\"\"Given flist, apply random speed perturbation
...
... Suppose all the input files are at least one second long.
... \"\"\"
... def __init__(self, flist: List[str], sample_rate: int):
... super().__init__()
... self.flist = flist
... self.sample_rate = sample_rate
...
... def __getitem__(self, index):
... speed = 0.5 + 1.5 * random.randn()
... effects = [
... ['gain', '-n', '-10'], # apply 10 db attenuation
... ['remix', '-'], # merge all the channels
... ['speed', f'{speed:.5f}'], # duration is now 0.5 ~ 2.0 seconds.
... ['rate', f'{self.sample_rate}'],
... ['pad', '0', '1.5'], # add 1.5 seconds silence at the end
... ['trim', '0', '2'], # get the first 2 seconds
... ]
... waveform, _ = paddleaudio.sox_effects.apply_effects_file(
... self.flist[index], effects)
... return waveform
...
... def __len__(self):
... return len(self.flist)
...
>>> dataset = RandomPerturbationFile(file_list, sample_rate=8000)
>>> loader = paddle.utils.data.DataLoader(dataset, batch_size=32)
>>> for batch in loader:
>>> pass
"""
if hasattr(path, "read"):
ret = paddleaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
if ret is None:
raise RuntimeError("Failed to load audio from {}".format(path))
return (paddle.to_tensor(ret[0]), ret[1])
path = os.fspath(path)
ret = paddleaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
if ret is not None:
return (paddle.to_tensor(ret[0]), ret[1])
raise RuntimeError("Failed to load audio from {}".format(path))

@ -35,6 +35,11 @@ if(BUILD_SOX)
list( list(
APPEND APPEND
LIBPADDLEAUDIO_SOURCES LIBPADDLEAUDIO_SOURCES
#sox/io.cpp
#sox/utils.cpp
#sox/effects.cpp
#sox/effects_chain.cpp
#sox/types.cpp
) )
list( list(
APPEND APPEND
@ -139,9 +144,10 @@ if(BUILD_SOX)
list( list(
APPEND APPEND
EXTENSION_SOURCES EXTENSION_SOURCES
# pybind/sox/effects.cpp pybind/sox/effects.cpp
# pybind/sox/effects_chain.cpp pybind/sox/effects_chain.cpp
pybind/sox/io.cpp pybind/sox/io.cpp
pybind/sox/types.cpp
pybind/sox/utils.cpp pybind/sox/utils.cpp
) )
endif() endif()

@ -3,16 +3,82 @@
#include "paddlespeech/audio/src/pybind/kaldi/kaldi_feature.h" #include "paddlespeech/audio/src/pybind/kaldi/kaldi_feature.h"
#include "paddlespeech/audio/src/pybind/sox/io.h" #include "paddlespeech/audio/src/pybind/sox/io.h"
#include "paddlespeech/audio/src/pybind/sox/effects.h"
#include "paddlespeech/audio/third_party/kaldi/feat/feature-fbank.h" #include "paddlespeech/audio/third_party/kaldi/feat/feature-fbank.h"
#include <pybind11/stl.h>
#include <pybind11/pybind11.h>
// `tl::optional`
namespace pybind11 { namespace detail {
template <typename T>
struct type_caster<tl::optional<T>> : optional_caster<tl::optional<T>> {};
}}
PYBIND11_MODULE(_paddleaudio, m) { PYBIND11_MODULE(_paddleaudio, m) {
#ifdef INCLUDE_SOX #ifdef INCLUDE_SOX
m.def("get_info_file", m.def("get_info_file",
&paddleaudio::sox_io::get_info_file, &paddleaudio::sox_io::get_info_file,
"Get metadata of audio file."); "Get metadata of audio file.");
// support obj later
m.def("get_info_fileobj", m.def("get_info_fileobj",
&paddleaudio::sox_io::get_info_fileobj, &paddleaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object."); "Get metadata of audio in file object.");
m.def("load_audio_fileobj",
&paddleaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
m.def("save_audio_fileobj",
&paddleaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
// sox io
m.def("sox_io_get_info", &paddleaudio::sox_io::get_info_file);
m.def(
"sox_io_load_audio_file",
&paddleaudio::sox_io::load_audio_file);
m.def(
"sox_io_save_audio_file",
&paddleaudio::sox_io::save_audio_file);
// sox utils
m.def("sox_utils_set_seed", &paddleaudio::sox_utils::set_seed);
m.def(
"sox_utils_set_verbosity",
&paddleaudio::sox_utils::set_verbosity);
m.def(
"sox_utils_set_use_threads",
&paddleaudio::sox_utils::set_use_threads);
m.def(
"sox_utils_set_buffer_size",
&paddleaudio::sox_utils::set_buffer_size);
m.def(
"sox_utils_list_effects",
&paddleaudio::sox_utils::list_effects);
m.def(
"sox_utils_list_read_formats",
&paddleaudio::sox_utils::list_read_formats);
m.def(
"sox_utils_list_write_formats",
&paddleaudio::sox_utils::list_write_formats);
m.def(
"sox_utils_get_buffer_size",
&paddleaudio::sox_utils::get_buffer_size);
// effect
m.def("apply_effects_fileobj",
&paddleaudio::sox_effects::apply_effects_fileobj,
"Decode audio data from file-like obj and apply effects.");
m.def("sox_effects_initialize_sox_effects",
&paddleaudio::sox_effects::initialize_sox_effects);
m.def(
"sox_effects_shutdown_sox_effects",
&paddleaudio::sox_effects::shutdown_sox_effects);
m.def(
"sox_effects_apply_effects_tensor",
&paddleaudio::sox_effects::apply_effects_tensor);
m.def(
"sox_effects_apply_effects_file",
&paddleaudio::sox_effects::apply_effects_file);
#endif #endif
#ifdef INCLUDE_KALDI #ifdef INCLUDE_KALDI

@ -0,0 +1,257 @@
#include <mutex>
#include <sox.h>
#include "paddlespeech/audio/src/pybind/sox/effects.h"
#include "paddlespeech/audio/src/pybind/sox/effects_chain.h"
#include "paddlespeech/audio/src/pybind/sox/utils.h"
using namespace paddleaudio::sox_utils;
namespace paddleaudio::sox_effects {
// Streaming decoding over file-like object is tricky because libsox operates on
// FILE pointer. The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// chunk of byte string
// This will perform header-based format detection, if necessary, then fill
// the metadata of sox_format_t. Internally, sox_open_mem_read uses fmemopen,
// which returns FILE* which points the buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// For Step 2. see `fileobj_input_drain` function in effects_chain.cpp
auto apply_effects_fileobj(
py::object fileobj,
const std::vector<std::vector<std::string>>& effects,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
tl::optional<std::string> format)
-> tl::optional<std::tuple<py::array, int64_t>> {
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
//
// For certain format (such as FLAC), libsox keeps reading the content at
// the initialization unless it reaches EOF even when the header is properly
// parsed. (Making buffer size 8192, which is way bigger than the header,
// resulted in libsox consuming all the buffer content at the time it opens
// the file.) Therefore buffer has to always contain valid data, except after
// EOF. We default to `sox_get_globals()->bufsiz`* for buffer size and we
// first check if there is enough data to fill the buffer. `read_fileobj`
// repeatedly calls `read` method until it receives the requested length of
// bytes or it reaches EOF. If we get bytes shorter than requested, that means
// the whole audio data are fetched.
//
// * This can be changed with `paddleaudio.utils.sox_utils.set_buffer_size`.
const auto capacity = [&]() {
// NOTE:
// Use the abstraction provided by `libpaddleaudio` to access the global
// config defined by libsox. Directly using `sox_get_globals` function will
// end up retrieving the static variable defined in `_paddleaudio`, which is
// not correct.
const auto bufsiz = get_buffer_size();
const int64_t kDefaultCapacityInBytes = 256;
return (bufsiz > kDefaultCapacityInBytes) ? bufsiz
: kDefaultCapacityInBytes;
}();
std::string buffer(capacity, '\0');
auto* in_buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, in_buf);
// If the file is shorter than 256, then libsox cannot read the header.
auto in_buffer_size = (num_read > 256) ? num_read : 256;
// Open file (this starts reading the header)
// When opening a file there are two functions that can touches FILE*.
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
SoxFormat sf(sox_open_mem_read(
in_buf,
in_buffer_size,
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
out_buffer.reserve(sf->signal.length);
// Create and run SoxEffectsChain
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
paddleaudio::sox_effects_chain::SoxEffectsChainPyBind chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_tensor_encodinginfo(dtype));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) {
chain.addEffect(effect);
}
chain.addOutputBuffer(&out_buffer);
chain.run();
// Create tensor from buffer
bool channels_first_ = channels_first.value_or(true);
auto tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(),
dtype,
normalize.value_or(true),
channels_first_);
return std::forward_as_tuple(
tensor, static_cast<int64_t>(chain.getOutputSampleRate()));
}
namespace {
enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown };
SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized;
std::mutex SOX_RESOUCE_STATE_MUTEX;
} // namespace
void initialize_sox_effects() {
const std::lock_guard<std::mutex> lock(SOX_RESOUCE_STATE_MUTEX);
switch (SOX_RESOURCE_STATE) {
case NotInitialized:
if (sox_init() != SOX_SUCCESS) {
throw std::runtime_error("Failed to initialize sox effects.");
};
SOX_RESOURCE_STATE = Initialized;
break;
case Initialized:
break;
case ShutDown:
throw std::runtime_error(
"SoX Effects has been shut down. Cannot initialize again.");
}
};
void shutdown_sox_effects() {
const std::lock_guard<std::mutex> lock(SOX_RESOUCE_STATE_MUTEX);
switch (SOX_RESOURCE_STATE) {
case NotInitialized:
throw std::runtime_error(
"SoX Effects is not initialized. Cannot shutdown.");
case Initialized:
if (sox_quit() != SOX_SUCCESS) {
throw std::runtime_error("Failed to initialize sox effects.");
};
SOX_RESOURCE_STATE = ShutDown;
break;
case ShutDown:
break;
}
}
auto apply_effects_tensor(
py::array waveform,
int64_t sample_rate,
const std::vector<std::vector<std::string>>& effects,
bool channels_first) -> std::tuple<py::array, int64_t> {
validate_input_tensor(waveform);
// Create SoxEffectsChain
const auto dtype = waveform.dtype();
paddleaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(dtype),
/*output_encoding=*/get_tensor_encodinginfo(dtype));
// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
out_buffer.reserve(waveform.size());
// Build and run effects chain
chain.addInputTensor(&waveform, sample_rate, channels_first);
for (const auto& effect : effects) {
chain.addEffect(effect);
}
chain.addOutputBuffer(&out_buffer);
chain.run();
// Create tensor from buffer
auto out_tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(),
dtype,
/*normalize=*/false,
channels_first);
return std::tuple<py::array, int64_t>(
out_tensor, chain.getOutputSampleRate());
}
auto apply_effects_file(
const std::string& path,
const std::vector<std::vector<std::string>>& effects,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format)
-> tl::optional<std::tuple<py::array, int64_t>> {
// Open input file
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
// Prepare output
std::vector<sox_sample_t> out_buffer;
out_buffer.reserve(sf->signal.length);
// Create and run SoxEffectsChain
paddleaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_tensor_encodinginfo(dtype));
chain.addInputFile(sf);
for (const auto& effect : effects) {
chain.addEffect(effect);
}
chain.addOutputBuffer(&out_buffer);
chain.run();
// Create tensor from buffer
bool channels_first_ = channels_first.value_or(true);
auto tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(),
dtype,
normalize.value_or(true),
channels_first_);
return std::tuple<py::array, int64_t>(
tensor, chain.getOutputSampleRate());
}
} // namespace paddleaudio::sox_effects

@ -0,0 +1,36 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "paddlespeech/audio/src/optional/optional.hpp"
namespace py = pybind11;
namespace paddleaudio::sox_effects {
auto apply_effects_fileobj(
py::object fileobj,
const std::vector<std::vector<std::string>>& effects,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
tl::optional<std::string> format)
-> tl::optional<std::tuple<py::array, int64_t>>;
void initialize_sox_effects();
void shutdown_sox_effects();
auto apply_effects_tensor(
py::array waveform,
int64_t sample_rate,
const std::vector<std::vector<std::string>>& effects,
bool channels_first) -> std::tuple<py::array, int64_t>;
auto apply_effects_file(
const std::string& path,
const std::vector<std::vector<std::string>>& effects,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format)
-> tl::optional<std::tuple<py::array, int64_t>>;
} // namespace paddleaudio::sox_effects

@ -0,0 +1,595 @@
#include <sox.h>
#include <iostream>
#include <vector>
#include "paddlespeech/audio/src/pybind/sox/effects_chain.h"
#include "paddlespeech/audio/src/pybind/sox/utils.h"
using namespace paddleaudio::sox_utils;
namespace paddleaudio::sox_effects_chain {
namespace {
/// helper classes for passing the location of input tensor and output buffer
///
/// drain/flow callback functions require plaing C style function signature and
/// the way to pass extra data is to attach data to sox_effect_t::priv pointer.
/// The following structs will be assigned to sox_effect_t::priv pointer which
/// gives sox_effect_t an access to input Tensor and output buffer object.
struct TensorInputPriv {
size_t index;
py::array* waveform;
int64_t sample_rate;
bool channels_first;
};
struct TensorOutputPriv {
std::vector<sox_sample_t>* buffer;
};
struct FileOutputPriv {
sox_format_t* sf;
};
/// Callback function to feed Tensor data to SoxEffectChain.
int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Retrieve the input Tensor and current index
auto priv = static_cast<TensorInputPriv*>(effp->priv);
auto index = priv->index;
auto tensor = *(priv->waveform);
auto num_channels = effp->out_signal.channels;
// Adjust the number of samples to read
const size_t num_samples = tensor.size();
if (index + *osamp > num_samples) {
*osamp = num_samples - index;
}
// Ensure that it's a multiple of the number of channels
*osamp -= *osamp % num_channels;
// Slice the input Tensor
// refacor this module, chunk
auto i_frame = index / num_channels;
auto num_frames = *osamp / num_channels;
std::vector<int> chunk(num_frames*num_channels);
py::buffer_info ori_info = tensor.request();
void* ptr = ori_info.ptr;
// Convert to sox_sample_t (int32_t)
switch (tensor.dtype().num()) {
//case c10::ScalarType::Float: {
case 11: {
// Need to convert to 64-bit precision so that
// values around INT32_MIN/MAX are handled correctly.
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
double elem = 0;
if (priv->channels_first) {
elem = *(float*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(float*)tensor.data(frame_idx, channels_idx);
}
elem = elem * 2147483648.;
// *new_ptr = std::clamp(elem, INT32_MIN, INT32_MAX);
if (elem > INT32_MAX) {
chunk[idx] = INT32_MAX;
} else if (elem < INT32_MIN) {
chunk[idx] = INT32_MIN;
} else {
chunk[idx] = elem;
}
}
break;
}
//case c10::ScalarType::Int: {
case 5: {
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
int elem = 0;
if (priv->channels_first) {
elem = *(int*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(int*)tensor.data(frame_idx, channels_idx);
}
chunk[idx] = elem;
}
break;
}
// case short
case 3: {
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
int16_t elem = 0;
if (priv->channels_first) {
elem = *(int16_t*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(int16_t*)tensor.data(frame_idx, channels_idx);
}
chunk[idx] = elem * 65536;
}
break;
}
// case byte
case 1: {
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
int8_t elem = 0;
if (priv->channels_first) {
elem = *(int8_t*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(int8_t*)tensor.data(frame_idx, channels_idx);
}
chunk[idx] = (elem - 128) * 16777216;
}
break;
}
default:
throw std::runtime_error("Unexpected dtype.");
}
// Write to buffer
memcpy(obuf, chunk.data(), *osamp * 4);
priv->index += *osamp;
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;
}
/// Callback function to fetch data from SoxEffectChain.
int tensor_output_flow(
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
size_t* osamp) {
*osamp = 0;
// Get output buffer
auto out_buffer = static_cast<TensorOutputPriv*>(effp->priv)->buffer;
// Append at the end
out_buffer->insert(out_buffer->end(), ibuf, ibuf + *isamp);
return SOX_SUCCESS;
}
int file_output_flow(
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
size_t* osamp) {
*osamp = 0;
if (*isamp) {
auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf;
if (sox_write(sf, ibuf, *isamp) != *isamp) {
if (sf->sox_errno) {
std::ostringstream stream;
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
<< sf->filename;
throw std::runtime_error(stream.str());
}
return SOX_EOF;
}
}
return SOX_SUCCESS;
}
sox_effect_handler_t* get_tensor_input_handler() {
static sox_effect_handler_t handler{
/*name=*/"input_tensor",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
/*start=*/NULL,
/*flow=*/NULL,
/*drain=*/tensor_input_drain,
/*stop=*/NULL,
/*kill=*/NULL,
/*priv_size=*/sizeof(TensorInputPriv)};
return &handler;
}
sox_effect_handler_t* get_tensor_output_handler() {
static sox_effect_handler_t handler{
/*name=*/"output_tensor",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
/*start=*/NULL,
/*flow=*/tensor_output_flow,
/*drain=*/NULL,
/*stop=*/NULL,
/*kill=*/NULL,
/*priv_size=*/sizeof(TensorOutputPriv)};
return &handler;
}
sox_effect_handler_t* get_file_output_handler() {
static sox_effect_handler_t handler{
/*name=*/"output_file",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
/*start=*/NULL,
/*flow=*/file_output_flow,
/*drain=*/NULL,
/*stop=*/NULL,
/*kill=*/NULL,
/*priv_size=*/sizeof(FileOutputPriv)};
return &handler;
}
} // namespace
SoxEffect::SoxEffect(sox_effect_t* se) noexcept : se_(se) {}
SoxEffect::~SoxEffect() {
if (se_ != nullptr) {
free(se_);
}
}
SoxEffect::operator sox_effect_t*() const {
return se_;
}
auto SoxEffect::operator->() noexcept -> sox_effect_t* {
return se_;
}
SoxEffectsChain::SoxEffectsChain(
sox_encodinginfo_t input_encoding,
sox_encodinginfo_t output_encoding)
: in_enc_(input_encoding),
out_enc_(output_encoding),
in_sig_(),
interm_sig_(),
out_sig_(),
sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
if (!sec_) {
throw std::runtime_error("Failed to create effect chain.");
}
}
SoxEffectsChain::~SoxEffectsChain() {
if (sec_ != nullptr) {
sox_delete_effects_chain(sec_);
}
}
void SoxEffectsChain::run() {
sox_flow_effects(sec_, NULL, NULL);
}
void SoxEffectsChain::addInputTensor(
py::array* waveform,
int64_t sample_rate,
bool channels_first) {
in_sig_ = get_signalinfo(waveform, sample_rate, "wav", channels_first);
interm_sig_ = in_sig_;
SoxEffect e(sox_create_effect(get_tensor_input_handler()));
auto priv = static_cast<TensorInputPriv*>(e->priv);
priv->index = 0;
priv->waveform = waveform;
priv->sample_rate = sample_rate;
priv->channels_first = channels_first;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: input_tensor");
}
}
void SoxEffectsChain::addOutputBuffer(
std::vector<sox_sample_t>* output_buffer) {
SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: output_tensor");
}
}
void SoxEffectsChain::addInputFile(sox_format_t* sf) {
in_sig_ = sf->signal;
interm_sig_ = in_sig_;
SoxEffect e(sox_create_effect(sox_find_effect("input")));
char* opts[] = {(char*)sf};
sox_effect_options(e, 1, opts);
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: input " << sf->filename;
throw std::runtime_error(stream.str());
}
}
void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
out_sig_ = sf->signal;
SoxEffect e(sox_create_effect(get_file_output_handler()));
static_cast<FileOutputPriv*>(e->priv)->sf = sf;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: output " << sf->filename;
throw std::runtime_error(stream.str());
}
}
void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
const auto num_args = effect.size();
if (num_args == 0) {
throw std::runtime_error("Invalid argument: empty effect.");
}
const auto name = effect[0];
if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) {
std::ostringstream stream;
stream << "Unsupported effect: " << name;
throw std::runtime_error(stream.str());
}
auto returned_effect = sox_find_effect(name.c_str());
if (!returned_effect) {
std::ostringstream stream;
stream << "Unsupported effect: " << name;
throw std::runtime_error(stream.str());
}
SoxEffect e(sox_create_effect(returned_effect));
const auto num_options = num_args - 1;
std::vector<char*> opts;
for (size_t i = 1; i < num_args; ++i) {
opts.push_back((char*)effect[i].c_str());
}
if (sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) !=
SOX_SUCCESS) {
std::ostringstream stream;
stream << "Invalid effect option:";
for (const auto& v : effect) {
stream << " " << v;
}
throw std::runtime_error(stream.str());
}
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: \"" << name;
for (size_t i = 1; i < num_args; ++i) {
stream << " " << effect[i];
}
stream << "\"";
throw std::runtime_error(stream.str());
}
}
int64_t SoxEffectsChain::getOutputNumChannels() {
return interm_sig_.channels;
}
int64_t SoxEffectsChain::getOutputSampleRate() {
return interm_sig_.rate;
}
namespace {
/// helper classes for passing file-like object to SoxEffectChain
struct FileObjInputPriv {
sox_format_t* sf;
py::object* fileobj;
bool eof_reached;
char* buffer;
uint64_t buffer_size;
};
struct FileObjOutputPriv {
sox_format_t* sf;
py::object* fileobj;
char** buffer;
size_t* buffer_size;
};
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
auto fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp)
-> int {
auto priv = static_cast<FileObjInputPriv*>(effp->priv);
auto sf = priv->sf;
auto buffer = priv->buffer;
// 1. Refresh the buffer
//
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't
// help) Therefore we need to align the content at the end of buffer,
// otherwise, libsox will keep reading the content beyond intended length.
//
// Before:
//
// |<-------consumed------>|<---remaining--->|
// |***********************|-----------------|
// ^ ftell
//
// After:
//
// |<-offset->|<---remaining--->|<-new data->|
// |**********|-----------------|++++++++++++|
// ^ ftell
// NOTE:
// Do not use `sf->tell_off` here. Presumably, `tell_off` and `fseek` are
// supposed to be in sync, but there are cases (Vorbis) they are not
// in sync and `tell_off` has seemingly uninitialized value, which
// leads num_remain to be negative and cause segmentation fault
// in `memmove`.
const auto tell = ftell((FILE*)sf->fp);
if (tell < 0) {
throw std::runtime_error("Internal Error: ftell failed.");
}
const auto num_consumed = static_cast<size_t>(tell);
if (num_consumed > priv->buffer_size) {
throw std::runtime_error("Internal Error: buffer overrun.");
}
const auto num_remain = priv->buffer_size - num_consumed;
// 1.1. Fetch the data to see if there is data to fill the buffer
size_t num_refill = 0;
std::string chunk(num_consumed, '\0');
if (num_consumed && !priv->eof_reached) {
num_refill = read_fileobj(
priv->fileobj, num_consumed, const_cast<char*>(chunk.data()));
if (num_refill < num_consumed) {
priv->eof_reached = true;
}
}
const auto offset = num_consumed - num_refill;
// 1.2. Move the unconsumed data towards the beginning of buffer.
if (num_remain) {
auto src = static_cast<void*>(buffer + num_consumed);
auto dst = static_cast<void*>(buffer + offset);
memmove(dst, src, num_remain);
}
// 1.3. Refill the remaining buffer.
if (num_refill) {
auto src = static_cast<void*>(const_cast<char*>(chunk.c_str()));
auto dst = buffer + offset + num_remain;
memcpy(dst, src, num_refill);
}
// 1.4. Set the file pointer to the new offset
sf->tell_off = offset;
fseek((FILE*)sf->fp, offset, SEEK_SET);
// 2. Perform decoding operation
// The following part is practically same as "input" effect
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48
// At this point, osamp represents the buffer size in bytes,
// but sox_read expects the maximum number of samples ready to read.
// Normally, this is fine, but in case when the samples are not 4-byte
// aligned, (e.g. sample is 24bits), the resulting signal is not correct.
// https://github.com/pytorch/audio/issues/2083
if (sf->encoding.bits_per_sample > 0)
*osamp /= (sf->encoding.bits_per_sample / 8);
// Ensure that it's a multiple of the number of channels
*osamp -= *osamp % effp->out_signal.channels;
// Read up to *osamp samples into obuf;
// store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp);
// Decoding is finished when fileobject is exhausted and sox can no longer
// decode a sample.
return (priv->eof_reached && !*osamp) ? SOX_EOF : SOX_SUCCESS;
}
auto fileobj_output_flow(
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
size_t* osamp) -> int {
*osamp = 0;
if (*isamp) {
auto priv = static_cast<FileObjOutputPriv*>(effp->priv);
auto sf = priv->sf;
auto fp = static_cast<FILE*>(sf->fp);
auto fileobj = priv->fileobj;
auto buffer = priv->buffer;
// Encode chunk
auto num_samples_written = sox_write(sf, ibuf, *isamp);
fflush(fp);
// Copy the encoded chunk to python object.
fileobj->attr("write")(py::bytes(*buffer, ftell(fp)));
// Reset FILE*
sf->tell_off = 0;
fseek(fp, 0, SEEK_SET);
if (num_samples_written != *isamp) {
if (sf->sox_errno) {
std::ostringstream stream;
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
<< sf->filename;
throw std::runtime_error(stream.str());
}
return SOX_EOF;
}
}
return SOX_SUCCESS;
}
auto get_fileobj_input_handler() -> sox_effect_handler_t* {
static sox_effect_handler_t handler{
/*name=*/"input_fileobj_object",
/*usage=*/nullptr,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/nullptr,
/*start=*/nullptr,
/*flow=*/nullptr,
/*drain=*/fileobj_input_drain,
/*stop=*/nullptr,
/*kill=*/nullptr,
/*priv_size=*/sizeof(FileObjInputPriv)};
return &handler;
}
auto get_fileobj_output_handler() -> sox_effect_handler_t* {
static sox_effect_handler_t handler{
/*name=*/"output_fileobj_object",
/*usage=*/nullptr,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/nullptr,
/*start=*/nullptr,
/*flow=*/fileobj_output_flow,
/*drain=*/nullptr,
/*stop=*/nullptr,
/*kill=*/nullptr,
/*priv_size=*/sizeof(FileObjOutputPriv)};
return &handler;
}
} // namespace
void SoxEffectsChainPyBind::addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj) {
in_sig_ = sf->signal;
interm_sig_ = in_sig_;
SoxEffect e(sox_create_effect(get_fileobj_input_handler()));
auto priv = static_cast<FileObjInputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->eof_reached = false;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: input fileobj");
}
}
void SoxEffectsChainPyBind::addOutputFileObj(
sox_format_t* sf,
char** buffer,
size_t* buffer_size,
py::object* fileobj) {
out_sig_ = sf->signal;
SoxEffect e(sox_create_effect(get_fileobj_output_handler()));
auto priv = static_cast<FileObjOutputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: output fileobj");
}
}
} // namespace paddleaudio::sox_effects_chain

@ -0,0 +1,76 @@
#pragma once
#include <sox.h>
#include "paddlespeech/audio/src/pybind/sox/utils.h"
namespace paddleaudio::sox_effects_chain {
// Helper struct to safely close sox_effect_t* pointer returned by
// sox_create_effect
struct SoxEffect {
explicit SoxEffect(sox_effect_t* se) noexcept;
SoxEffect(const SoxEffect& other) = delete;
SoxEffect(const SoxEffect&& other) = delete;
auto operator=(const SoxEffect& other) -> SoxEffect& = delete;
auto operator=(SoxEffect&& other) -> SoxEffect& = delete;
~SoxEffect();
operator sox_effect_t*() const;
auto operator->() noexcept -> sox_effect_t*;
private:
sox_effect_t* se_;
};
// Helper struct to safely close sox_effects_chain_t with handy methods
class SoxEffectsChain {
const sox_encodinginfo_t in_enc_;
const sox_encodinginfo_t out_enc_;
protected:
sox_signalinfo_t in_sig_;
sox_signalinfo_t interm_sig_;
sox_signalinfo_t out_sig_;
sox_effects_chain_t* sec_;
public:
explicit SoxEffectsChain(
sox_encodinginfo_t input_encoding,
sox_encodinginfo_t output_encoding);
SoxEffectsChain(const SoxEffectsChain& other) = delete;
SoxEffectsChain(const SoxEffectsChain&& other) = delete;
SoxEffectsChain& operator=(const SoxEffectsChain& other) = delete;
SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete;
~SoxEffectsChain();
void run();
void addInputTensor(
py::array* waveform,
int64_t sample_rate,
bool channels_first);
void addInputFile(sox_format_t* sf);
void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
void addOutputFile(sox_format_t* sf);
void addEffect(const std::vector<std::string> effect);
int64_t getOutputNumChannels();
int64_t getOutputSampleRate();
};
class SoxEffectsChainPyBind : public SoxEffectsChain {
using SoxEffectsChain::SoxEffectsChain;
public:
void addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj);
void addOutputFileObj(
sox_format_t* sf,
char** buffer,
size_t* buffer_size,
py::object* fileobj);
};
} // namespace paddleaudio::sox_effects_chain

@ -2,20 +2,25 @@
// All rights reserved. // All rights reserved.
#include "paddlespeech/audio/src/pybind/sox/io.h" #include "paddlespeech/audio/src/pybind/sox/io.h"
#include "paddlespeech/audio/src/pybind/sox/effects.h"
#include "paddlespeech/audio/src/pybind/sox/types.h"
#include "paddlespeech/audio/src/pybind/sox/effects_chain.h"
#include "paddlespeech/audio/src/pybind/sox/utils.h" #include "paddlespeech/audio/src/pybind/sox/utils.h"
#include "paddlespeech/audio/src/optional/optional.hpp"
using namespace paddleaudio::sox_utils; using namespace paddleaudio::sox_utils;
namespace paddleaudio { namespace paddleaudio {
namespace sox_io { namespace sox_io {
auto get_info_file(const std::string &path, const std::string &format) auto get_info_file(const std::string &path,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> { -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
SoxFormat sf( SoxFormat sf(
sox_open_read(path.data(), sox_open_read(path.data(),
/*signal=*/nullptr, /*signal=*/nullptr,
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.empty() ? nullptr : format.data())); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf, path); validate_input_file(sf, path);
@ -28,7 +33,37 @@ auto get_info_file(const std::string &path, const std::string &format)
get_encoding(sf->encoding.encoding)); get_encoding(sf->encoding.encoding));
} }
auto get_info_fileobj(py::object fileobj, const std::string &format) std::vector<std::vector<std::string>> get_effects(
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
"Invalid argument: frame_offset must be non-negative.");
}
const auto frames = num_frames.value_or(-1);
if (frames == 0 || frames < -1) {
throw std::runtime_error(
"Invalid argument: num_frames must be -1 or greater than 0.");
}
std::vector<std::vector<std::string>> effects;
if (frames != -1) {
std::ostringstream os_offset, os_frames;
os_offset << offset << "s";
os_frames << "+" << frames << "s";
effects.emplace_back(
std::vector<std::string>{"trim", os_offset.str(), os_frames.str()});
} else if (offset != 0) {
std::ostringstream os_offset;
os_offset << offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
}
return effects;
}
auto get_info_fileobj(py::object fileobj,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> { -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
const auto capacity = [&]() { const auto capacity = [&]() {
const auto bufsiz = get_buffer_size(); const auto bufsiz = get_buffer_size();
@ -47,7 +82,7 @@ auto get_info_fileobj(py::object fileobj, const std::string &format)
buf_size, buf_size,
/*signal=*/nullptr, /*signal=*/nullptr,
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.empty() ? nullptr : format.data())); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0 // In case of streamed data, length can be 0
validate_input_memfile(sf); validate_input_memfile(sf);
@ -60,5 +95,186 @@ auto get_info_fileobj(py::object fileobj, const std::string &format)
get_encoding(sf->encoding.encoding)); get_encoding(sf->encoding.encoding));
} }
tl::optional<std::tuple<py::array, int64_t>> load_audio_fileobj(
py::object fileobj,
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return paddleaudio::sox_effects::apply_effects_fileobj(
std::move(fileobj), effects, normalize, channels_first, std::move(format));
}
tl::optional<std::tuple<py::array, int64_t>> load_audio_file(
const std::string& path,
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return paddleaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first, format);
}
void save_audio_file(const std::string& path,
py::array tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample) {
validate_input_tensor(tensor);
const auto filetype = [&]() {
if (format.has_value()) return format.value();
return get_filetype(path);
}();
if (filetype == "amr-nb") {
const auto num_channels = tensor.shape(channels_first ? 0 : 1);
//TORCH_CHECK(num_channels == 1,
// "amr-nb format only supports single channel audio.");
assert(num_channels == 1);
} else if (filetype == "htk") {
const auto num_channels = tensor.shape(channels_first ? 0 : 1);
// TORCH_CHECK(num_channels == 1,
// "htk format only supports single channel audio.");
assert(num_channels == 1);
} else if (filetype == "gsm") {
const auto num_channels = tensor.shape(channels_first ? 0 : 1);
assert(num_channels == 1);
assert(sample_rate == 8000);
//TORCH_CHECK(num_channels == 1,
// "gsm format only supports single channel audio.");
//TORCH_CHECK(sample_rate == 8000,
// "gsm format only supports a sampling rate of 8kHz.");
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo_for_save(
filetype, tensor.dtype(), compression, encoding, bits_per_sample);
SoxFormat sf(sox_open_write(path.c_str(),
&signal_info,
&encoding_info,
/*filetype=*/filetype.c_str(),
/*oob=*/nullptr,
/*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open file " + path);
}
paddleaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFile(sf);
chain.run();
}
namespace {
// helper class to automatically release buffer, to be used by
// save_audio_fileobj
struct AutoReleaseBuffer {
char* ptr;
size_t size;
AutoReleaseBuffer() : ptr(nullptr), size(0) {}
AutoReleaseBuffer(const AutoReleaseBuffer& other) = delete;
AutoReleaseBuffer(AutoReleaseBuffer&& other) = delete;
auto operator=(const AutoReleaseBuffer& other) -> AutoReleaseBuffer& = delete;
auto operator=(AutoReleaseBuffer&& other) -> AutoReleaseBuffer& = delete;
~AutoReleaseBuffer() {
if (ptr) {
free(ptr);
}
}
};
} // namespace
void save_audio_fileobj(
py::object fileobj,
py::array tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample) {
if (!format.has_value()) {
throw std::runtime_error(
"`format` is required when saving to file object.");
}
const auto filetype = format.value();
if (filetype == "amr-nb") {
const auto num_channels = tensor.shape(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"amr-nb format only supports single channel audio.");
}
} else if (filetype == "htk") {
const auto num_channels = tensor.shape(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"htk format only supports single channel audio.");
}
} else if (filetype == "gsm") {
const auto num_channels = tensor.shape(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"gsm format only supports single channel audio.");
}
if (sample_rate != 8000) {
throw std::runtime_error(
"gsm format only supports a sampling rate of 8kHz.");
}
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo_for_save(
filetype,
tensor.dtype(),
compression,
std::move(encoding),
bits_per_sample);
AutoReleaseBuffer buffer;
SoxFormat sf(sox_open_memstream_write(
&buffer.ptr,
&buffer.size,
&signal_info,
&encoding_info,
filetype.c_str(),
/*oob=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open memory stream.");
}
paddleaudio::sox_effects_chain::SoxEffectsChainPyBind chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
chain.run();
// Closing the sox_format_t is necessary for flushing the last chunk to the
// buffer
sf.close();
fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size));
}
} // namespace paddleaudio } // namespace paddleaudio
} // namespace sox_io } // namespace sox_io

@ -1,21 +1,63 @@
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala), // Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved. // All rights reserved.
#ifndef PADDLEAUDIO_PYBIND_SOX_IO_H #pragma once
#define PADDLEAUDIO_PYBIND_SOX_IO_H
#include "paddlespeech/audio/src/pybind/sox/utils.h" #include "paddlespeech/audio/src/pybind/sox/utils.h"
namespace py = pybind11;
namespace paddleaudio { namespace paddleaudio {
namespace sox_io { namespace sox_io {
auto get_info_file(const std::string &path, const std::string &format) auto get_info_file(const std::string &path,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>; -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
auto get_info_fileobj(py::object fileobj, const std::string &format) auto get_info_fileobj(py::object fileobj,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>; -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
tl::optional<std::tuple<py::array, int64_t>> load_audio_fileobj(
py::object fileobj,
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format);
void save_audio_fileobj(
py::object fileobj,
py::array tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample);
auto get_effects(const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames)
-> std::vector<std::vector<std::string>>;
tl::optional<std::tuple<py::array, int64_t>> load_audio_file(
const std::string& path,
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format);
void save_audio_file(const std::string& path,
py::array tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample);
} // namespace paddleaudio } // namespace paddleaudio
} // namespace sox_io } // namespace sox_io
#endif

@ -0,0 +1,143 @@
//code is from: https://github.com/pytorch/audio/blob/main/torchaudio/csrc/sox/types.cpp
#include "paddlespeech/audio/src/pybind/sox/types.h"
#include <ostream>
#include <sstream>
namespace paddleaudio {
namespace sox_utils {
Format get_format_from_string(const std::string& format) {
if (format == "wav")
return Format::WAV;
if (format == "mp3")
return Format::MP3;
if (format == "flac")
return Format::FLAC;
if (format == "ogg" || format == "vorbis")
return Format::VORBIS;
if (format == "amr-nb")
return Format::AMR_NB;
if (format == "amr-wb")
return Format::AMR_WB;
if (format == "amb")
return Format::AMB;
if (format == "sph")
return Format::SPHERE;
if (format == "htk")
return Format::HTK;
if (format == "gsm")
return Format::GSM;
std::ostringstream stream;
stream << "Internal Error: unexpected format value: " << format;
throw std::runtime_error(stream.str());
}
std::string to_string(Encoding v) {
switch (v) {
case Encoding::UNKNOWN:
return "UNKNOWN";
case Encoding::PCM_SIGNED:
return "PCM_S";
case Encoding::PCM_UNSIGNED:
return "PCM_U";
case Encoding::PCM_FLOAT:
return "PCM_F";
case Encoding::FLAC:
return "FLAC";
case Encoding::ULAW:
return "ULAW";
case Encoding::ALAW:
return "ALAW";
case Encoding::MP3:
return "MP3";
case Encoding::VORBIS:
return "VORBIS";
case Encoding::AMR_WB:
return "AMR_WB";
case Encoding::AMR_NB:
return "AMR_NB";
case Encoding::OPUS:
return "OPUS";
default:
throw std::runtime_error("Internal Error: unexpected encoding.");
}
}
Encoding get_encoding_from_option(const tl::optional<std::string> encoding) {
if (!encoding.has_value())
return Encoding::NOT_PROVIDED;
std::string v = encoding.value();
if (v == "PCM_S")
return Encoding::PCM_SIGNED;
if (v == "PCM_U")
return Encoding::PCM_UNSIGNED;
if (v == "PCM_F")
return Encoding::PCM_FLOAT;
if (v == "ULAW")
return Encoding::ULAW;
if (v == "ALAW")
return Encoding::ALAW;
std::ostringstream stream;
stream << "Internal Error: unexpected encoding value: " << v;
throw std::runtime_error(stream.str());
}
BitDepth get_bit_depth_from_option(const tl::optional<int64_t> bit_depth) {
if (!bit_depth.has_value())
return BitDepth::NOT_PROVIDED;
int64_t v = bit_depth.value();
switch (v) {
case 8:
return BitDepth::B8;
case 16:
return BitDepth::B16;
case 24:
return BitDepth::B24;
case 32:
return BitDepth::B32;
case 64:
return BitDepth::B64;
default: {
std::ostringstream s;
s << "Internal Error: unexpected bit depth value: " << v;
throw std::runtime_error(s.str());
}
}
}
std::string get_encoding(sox_encoding_t encoding) {
switch (encoding) {
case SOX_ENCODING_UNKNOWN:
return "UNKNOWN";
case SOX_ENCODING_SIGN2:
return "PCM_S";
case SOX_ENCODING_UNSIGNED:
return "PCM_U";
case SOX_ENCODING_FLOAT:
return "PCM_F";
case SOX_ENCODING_FLAC:
return "FLAC";
case SOX_ENCODING_ULAW:
return "ULAW";
case SOX_ENCODING_ALAW:
return "ALAW";
case SOX_ENCODING_MP3:
return "MP3";
case SOX_ENCODING_VORBIS:
return "VORBIS";
case SOX_ENCODING_AMR_WB:
return "AMR_WB";
case SOX_ENCODING_AMR_NB:
return "AMR_NB";
case SOX_ENCODING_OPUS:
return "OPUS";
case SOX_ENCODING_GSM:
return "GSM";
default:
return "UNKNOWN";
}
}
} // namespace sox_utils
} // namespace paddleaudio

@ -0,0 +1,58 @@
//code is from: https://github.com/pytorch/audio/blob/main/torchaudio/csrc/sox/types.h
#pragma once
#include <sox.h>
#include "paddlespeech/audio/src/optional/optional.hpp"
namespace paddleaudio {
namespace sox_utils {
enum class Format {
WAV,
MP3,
FLAC,
VORBIS,
AMR_NB,
AMR_WB,
AMB,
SPHERE,
GSM,
HTK,
};
Format get_format_from_string(const std::string& format);
enum class Encoding {
NOT_PROVIDED,
UNKNOWN,
PCM_SIGNED,
PCM_UNSIGNED,
PCM_FLOAT,
FLAC,
ULAW,
ALAW,
MP3,
VORBIS,
AMR_WB,
AMR_NB,
OPUS,
};
std::string to_string(Encoding v);
Encoding get_encoding_from_option(const tl::optional<std::string> encoding);
enum class BitDepth : unsigned {
NOT_PROVIDED = 0,
B8 = 8,
B16 = 16,
B24 = 24,
B32 = 32,
B64 = 64,
};
BitDepth get_bit_depth_from_option(const tl::optional<int64_t> bit_depth);
std::string get_encoding(sox_encoding_t encoding);
} // namespace sox_utils
} // namespace paddleaudio

@ -1,13 +1,554 @@
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala), // Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved. // All rights reserved.
#include <sox.h>
#include "paddlespeech/audio/src/pybind/sox/utils.h" #include "paddlespeech/audio/src/pybind/sox/utils.h"
#include "paddlespeech/audio/src/pybind/sox/types.h"
#include <sstream> #include <sstream>
namespace paddleaudio { namespace paddleaudio {
namespace sox_utils { namespace sox_utils {
auto read_fileobj(py::object *fileobj, const uint64_t size, char *buffer)
-> uint64_t {
uint64_t num_read = 0;
while (num_read < size) {
auto request = size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file "
"object.";
throw std::runtime_error(message.str());
}
memcpy(buffer, chunk.data(), chunk_len);
buffer += chunk_len;
num_read += chunk_len;
}
return num_read;
}
void set_seed(const int64_t seed) {
sox_get_globals()->ranqd1 = static_cast<sox_int32_t>(seed);
}
void set_verbosity(const int64_t verbosity) {
sox_get_globals()->verbosity = static_cast<unsigned>(verbosity);
}
void set_use_threads(const bool use_threads) {
sox_get_globals()->use_threads = static_cast<sox_bool>(use_threads);
}
void set_buffer_size(const int64_t buffer_size) {
sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size);
}
int64_t get_buffer_size() {
return sox_get_globals()->bufsiz;
}
std::vector<std::vector<std::string>> list_effects() {
std::vector<std::vector<std::string>> effects;
for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
const sox_effect_handler_t* handler = (*fns)();
if (handler && handler->name) {
if (UNSUPPORTED_EFFECTS.find(handler->name) ==
UNSUPPORTED_EFFECTS.end()) {
effects.emplace_back(std::vector<std::string>{
handler->name,
handler->usage ? std::string(handler->usage) : std::string("")});
}
}
}
return effects;
}
std::vector<std::string> list_write_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->write)
formats.emplace_back(*names);
}
}
return formats;
}
std::vector<std::string> list_read_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->read)
formats.emplace_back(*names);
}
}
return formats;
}
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
SoxFormat::~SoxFormat() {
close();
}
sox_format_t* SoxFormat::operator->() const noexcept {
return fd_;
}
SoxFormat::operator sox_format_t*() const noexcept {
return fd_;
}
void SoxFormat::close() {
if (fd_ != nullptr) {
sox_close(fd_);
fd_ = nullptr;
}
}
void validate_input_file(const SoxFormat& sf, const std::string& path) {
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error loading audio file: failed to open file " + path);
}
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding.");
}
}
void validate_input_memfile(const SoxFormat &sf) {
return validate_input_file(sf, "<in memory buffer>");
}
void validate_input_tensor(const py::array tensor) {
if (tensor.ndim() != 2) {
throw std::runtime_error("Input tensor has to be 2D.");
}
char dtype = tensor.dtype().char_();
bool flag = (dtype == 'f') || (dtype == 'd') || (dtype == 'l') || (dtype == 'i');
if (flag == false) {
throw std::runtime_error(
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
}
}
py::dtype get_dtype(
const sox_encoding_t encoding,
const unsigned precision) {
switch (encoding) {
case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV
return py::dtype('u1');
case SOX_ENCODING_SIGN2: // 16-bit, 24-bit, or 32-bit PCM WAV
switch (precision) {
case 16:
return py::dtype("i2");
case 24: // Cast 24-bit to 32-bit.
case 32:
return py::dtype('i');
default:
throw std::runtime_error(
"Only 16, 24, and 32 bits are supported for signed PCM.");
}
default:
// default to float32 for the other formats, including
// 32-bit flaoting-point WAV,
// MP3,
// FLAC,
// VORBIS etc...
return py::dtype("f");
}
}
py::array convert_to_tensor(
sox_sample_t* buffer,
const int32_t num_samples,
const int32_t num_channels,
const py::dtype dtype,
const bool normalize,
const bool channels_first) {
// todo refector later(SGoat)
py::array t;
uint64_t dummy = 0;
SOX_SAMPLE_LOCALS;
int32_t num_rows = num_samples / num_channels;
if (normalize || dtype.char_() == 'f') {
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (float*)t.mutable_data(0, 0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy);
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(float*)t2.mutable_data(row_idx, col_idx) = *(float*)t.data(col_idx, row_idx);
}
return t2;
}
} else if (dtype.char_() == 'i') {
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (int*)t.mutable_data(0, 0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = buffer[i];
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(int*)t2.mutable_data(row_idx, col_idx) = *(int*)t.data(col_idx, row_idx);
}
return t2;
}
} else if (dtype.char_() == 'h') { // int16
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (int16_t*)t.mutable_data(0, 0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy);
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(int16_t*)t2.mutable_data(row_idx, col_idx) = *(int16_t*)t.data(col_idx, row_idx);
}
return t2;
}
} else if (dtype.char_() == 'b') {
//t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8);
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (uint8_t*)t.mutable_data(0,0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(uint8_t*)t2.mutable_data(row_idx, col_idx) = *(uint8_t*)t.data(col_idx, row_idx);
}
return t2;
}
} else {
throw std::runtime_error("Unsupported dtype.");
}
return t;
}
const std::string get_filetype(const std::string path) {
std::string ext = path.substr(path.find_last_of(".") + 1);
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext;
}
namespace {
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
const std::string format,
py::dtype dtype,
const Encoding& encoding,
const BitDepth& bits_per_sample) {
switch (encoding) {
case Encoding::NOT_PROVIDED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
switch (dtype.num()) {
case 11: // float32 numpy dtype num
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case 5: // int numpy dtype num
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case 3: // int16 numpy
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case 1: // byte numpy
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error("Internal Error: Unexpected dtype.");
}
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
case Encoding::PCM_SIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8:
throw std::runtime_error(
format + " does not support 8-bit signed PCM encoding.");
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
case Encoding::PCM_UNSIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for unsigned PCM encoding.");
}
case Encoding::PCM_FLOAT:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B32:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default:
throw std::runtime_error(
format +
" only supports 32-bit or 64-bit for floating-point PCM encoding.");
}
case Encoding::ULAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for a-law encoding.");
}
default:
throw std::runtime_error(
format + " does not support encoding: " + to_string(encoding));
}
}
std::tuple<sox_encoding_t, unsigned> get_save_encoding(
const std::string& format,
const py::dtype dtype,
const tl::optional<std::string> encoding,
const tl::optional<int64_t> bits_per_sample) {
const Format fmt = get_format_from_string(format);
const Encoding enc = get_encoding_from_option(encoding);
const BitDepth bps = get_bit_depth_from_option(bits_per_sample);
switch (fmt) {
case Format::WAV:
case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps);
case Format::MP3:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("mp3 does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::HTK:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("htk does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"htk does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("amr-nb does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
case Format::FLAC:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("flac does not support `encoding` option.");
switch (bps) {
case BitDepth::B32:
case BitDepth::B64:
throw std::runtime_error(
"flac does not support `bits_per_sample` larger than 24.");
default:
return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
}
case Format::SPHERE:
switch (enc) {
case Encoding::NOT_PROVIDED:
case Encoding::PCM_SIGNED:
switch (bps) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
}
case Encoding::PCM_UNSIGNED:
throw std::runtime_error(
"sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT:
throw std::runtime_error("sph does not support floating point PCM.");
case Encoding::ULAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
"sph only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
}
default:
throw std::runtime_error(
"sph does not support encoding: " + encoding.value());
}
case Format::GSM:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("gsm does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"gsm does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_GSM, 16);
default:
throw std::runtime_error("Unsupported format: " + format);
}
}
unsigned get_precision(const std::string filetype, py::dtype dtype) {
if (filetype == "mp3")
return SOX_UNSPEC;
if (filetype == "flac")
return 24;
if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC;
if (filetype == "wav" || filetype == "amb") {
switch (dtype.num()) {
case 1: // byte in numpy dype num
return 8;
case 3: // short, in numpy dtype num
return 16;
case 5: // int, numpy dtype
return 32;
case 11: // float, numpy dtype
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
}
if (filetype == "sph")
return 32;
if (filetype == "amr-nb") {
return 16;
}
if (filetype == "gsm") {
return 16;
}
if (filetype == "htk") {
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
}
} // namespace
sox_signalinfo_t get_signalinfo(
const py::array* waveform,
const int64_t sample_rate,
const std::string filetype,
const bool channels_first) {
return sox_signalinfo_t{
/*rate=*/static_cast<sox_rate_t>(sample_rate),
/*channels=*/
static_cast<unsigned>(waveform->shape(channels_first ? 0 : 1)),
/*precision=*/get_precision(filetype, waveform->dtype()),
/*length=*/static_cast<uint64_t>(waveform->size())};
}
sox_encodinginfo_t get_tensor_encodinginfo(py::dtype dtype) {
sox_encoding_t encoding = [&]() {
switch (dtype.num()) {
case 1: // byte
return SOX_ENCODING_UNSIGNED;
case 3: // short
return SOX_ENCODING_SIGN2;
case 5: // int32
return SOX_ENCODING_SIGN2;
case 11: // float
return SOX_ENCODING_FLOAT;
default:
throw std::runtime_error("Unsupported dtype.");
}
}();
unsigned bits_per_sample = [&]() {
switch (dtype.num()) {
case 1: // byte
return 8;
case 3: //short
return 16;
case 5: // int32
return 32;
case 11: // float
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
}();
return sox_encodinginfo_t{
/*encoding=*/encoding,
/*bits_per_sample=*/bits_per_sample,
/*compression=*/HUGE_VAL,
/*reverse_bytes=*/sox_option_default,
/*reverse_nibbles=*/sox_option_default,
/*reverse_bits=*/sox_option_default,
/*opposite_endian=*/sox_false};
}
sox_encodinginfo_t get_encodinginfo_for_save(
const std::string& format,
const py::dtype dtype,
const tl::optional<double> compression,
const tl::optional<std::string> encoding,
const tl::optional<int64_t> bits_per_sample) {
auto enc = get_save_encoding(format, dtype, encoding, bits_per_sample);
return sox_encodinginfo_t{
/*encoding=*/std::get<0>(enc),
/*bits_per_sample=*/std::get<1>(enc),
/*compression=*/compression.value_or(HUGE_VAL),
/*reverse_bytes=*/sox_option_default,
/*reverse_nibbles=*/sox_option_default,
/*reverse_bits=*/sox_option_default,
/*opposite_endian=*/sox_false};
}
/*
SoxFormat::SoxFormat(sox_format_t *fd) noexcept : fd_(fd) {} SoxFormat::SoxFormat(sox_format_t *fd) noexcept : fd_(fd) {}
SoxFormat::~SoxFormat() { close(); } SoxFormat::~SoxFormat() { close(); }
@ -96,6 +637,6 @@ std::string get_encoding(sox_encoding_t encoding) {
return "UNKNOWN"; return "UNKNOWN";
} }
} }
*/
} // namespace paddleaudio } // namespace paddleaudio
} // namespace sox_utils } // namespace sox_utils

@ -4,39 +4,113 @@
#pragma once #pragma once
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <sox.h> #include <sox.h>
#include "paddlespeech/audio/src/optional/optional.hpp"
namespace py = pybind11; namespace py = pybind11;
namespace paddleaudio { namespace paddleaudio {
namespace sox_utils { namespace sox_utils {
auto read_fileobj(py::object *fileobj, uint64_t size, char *buffer) -> uint64_t;
void set_seed(const int64_t seed);
void set_verbosity(const int64_t verbosity);
void set_use_threads(const bool use_threads);
void set_buffer_size(const int64_t buffer_size);
int64_t get_buffer_size();
std::vector<std::vector<std::string>> list_effects();
std::vector<std::string> list_read_formats();
std::vector<std::string> list_write_formats();
////////////////////////////////////////////////////////////////////////////////
// Utilities for sox_io / sox_effects implementations
////////////////////////////////////////////////////////////////////////////////
const std::unordered_set<std::string> UNSUPPORTED_EFFECTS =
{"input", "output", "spectrogram", "noiseprof", "noisered", "splice"};
/// helper class to automatically close sox_format_t* /// helper class to automatically close sox_format_t*
struct SoxFormat { struct SoxFormat {
explicit SoxFormat(sox_format_t *fd) noexcept; explicit SoxFormat(sox_format_t* fd) noexcept;
SoxFormat(const SoxFormat &other) = delete; SoxFormat(const SoxFormat& other) = delete;
SoxFormat(SoxFormat &&other) = delete; SoxFormat(SoxFormat&& other) = delete;
SoxFormat &operator=(const SoxFormat &other) = delete; SoxFormat& operator=(const SoxFormat& other) = delete;
SoxFormat &operator=(SoxFormat &&other) = delete; SoxFormat& operator=(SoxFormat&& other) = delete;
~SoxFormat(); ~SoxFormat();
sox_format_t *operator->() const noexcept; sox_format_t* operator->() const noexcept;
operator sox_format_t *() const noexcept; operator sox_format_t*() const noexcept;
void close();
private:
sox_format_t *fd_;
};
auto read_fileobj(py::object *fileobj, uint64_t size, char *buffer) -> uint64_t; void close();
int64_t get_buffer_size(); private:
sox_format_t* fd_;
};
void validate_input_file(const SoxFormat &sf, const std::string &path); ///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
void validate_input_tensor(const py::array);
void validate_input_file(const SoxFormat& sf, const std::string& path);
void validate_input_memfile(const SoxFormat &sf); void validate_input_memfile(const SoxFormat &sf);
///
/// Get target dtype for the given encoding and precision.
py::dtype get_dtype(
const sox_encoding_t encoding,
const unsigned precision);
///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
/// reduce the number of memory copy.
/// @param buffer Pointer to buffer that contains audio data.
/// @param num_samples The number of samples to read.
/// @param num_channels The number of channels. Used to reshape the resulting
/// Tensor.
/// @param dtype Target dtype. Determines the output dtype and value range in
/// conjunction with normalization.
/// @param noramlize Perform normalization. Only effective when dtype is not
/// kFloat32. When effective, the output tensor is kFloat32 type and value range
/// is [-1.0, 1.0]
/// @param channels_first When True, output Tensor has shape of [num_channels,
/// num_frames].
py::array convert_to_tensor(
sox_sample_t* buffer,
const int32_t num_samples,
const int32_t num_channels,
const py::dtype dtype,
const bool normalize,
const bool channels_first);
/// Extract extension from file path
const std::string get_filetype(const std::string path);
/// Get sox_signalinfo_t for passing a py::array object.
sox_signalinfo_t get_signalinfo(
const py::array* waveform,
const int64_t sample_rate,
const std::string filetype,
const bool channels_first);
/// Get sox_encodinginfo_t for Tensor I/O
sox_encodinginfo_t get_tensor_encodinginfo(const py::dtype dtype);
std::string get_encoding(sox_encoding_t encoding); /// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t get_encodinginfo_for_save(
const std::string& format,
const py::dtype dtype,
const tl::optional<double> compression,
const tl::optional<std::string> encoding,
const tl::optional<int64_t> bits_per_sample);
} // namespace paddleaudio } // namespace paddleaudio
} // namespace sox_utils } // namespace sox_utils

@ -1,139 +0,0 @@
// #include "sox/effects.h"
// #include "sox/effects_chain.h"
#include "sox/io.h"
#include "sox/types.h"
#include "sox/utils.h"
using namespace torch::indexing;
using namespace paddleaudio::sox_utils;
namespace paddleaudio {
namespace sox_io {
tl::optional<MetaDataTuple> get_info_file(
const std::string& path, const tl::optional<std::string>& format) {
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample),
get_encoding(sf->encoding.encoding));
}
std::vector<std::vector<std::string>> get_effects(
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
"Invalid argument: frame_offset must be non-negative.");
}
const auto frames = num_frames.value_or(-1);
if (frames == 0 || frames < -1) {
throw std::runtime_error(
"Invalid argument: num_frames must be -1 or greater than 0.");
}
std::vector<std::vector<std::string>> effects;
if (frames != -1) {
std::ostringstream os_offset, os_frames;
os_offset << offset << "s";
os_frames << "+" << frames << "s";
effects.emplace_back(
std::vector<std::string>{"trim", os_offset.str(), os_frames.str()});
} else if (offset != 0) {
std::ostringstream os_offset;
os_offset << offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
}
return effects;
}
tl::optional<std::tuple<torch::Tensor, int64_t>> load_audio_file(
const std::string& path,
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return paddleaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first, format);
}
void save_audio_file(const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample) {
validate_input_tensor(tensor);
const auto filetype = [&]() {
if (format.has_value()) return format.value();
return get_filetype(path);
}();
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(num_channels == 1,
"amr-nb format only supports single channel audio.");
} else if (filetype == "htk") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(num_channels == 1,
"htk format only supports single channel audio.");
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(num_channels == 1,
"gsm format only supports single channel audio.");
TORCH_CHECK(sample_rate == 8000,
"gsm format only supports a sampling rate of 8kHz.");
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo_for_save(
filetype, tensor.dtype(), compression, encoding, bits_per_sample);
SoxFormat sf(sox_open_write(path.c_str(),
&signal_info,
&encoding_info,
/*filetype=*/filetype.c_str(),
/*oob=*/nullptr,
/*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open file " + path);
}
paddleaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFile(sf);
chain.run();
}
TORCH_LIBRARY_FRAGMENT(paddleaudio, m) {
m.def("paddleaudio::sox_io_get_info", &paddleaudio::sox_io::get_info_file);
m.def("paddleaudio::sox_io_load_audio_file",
&paddleaudio::sox_io::load_audio_file);
m.def("paddleaudio::sox_io_save_audio_file",
&paddleaudio::sox_io::save_audio_file);
}
} // namespace sox_io
} // namespace paddleaudio

@ -1,44 +0,0 @@
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#ifndef PADDLEAUDIO_SOX_IO_H
#define PADDLEAUDIO_SOX_IO_H
// #include "sox/utils.h"
#include "optional/optional.hpp"
namespace paddleaudio {
namespace sox_io {
auto get_effects(const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames)
-> std::vector<std::vector<std::string>>;
using MetaDataTuple =
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
tl::optional<MetaDataTuple> get_info_file(
const std::string& path, const tl::optional<std::string>& format);
tl::optional<std::tuple<torch::Tensor, int64_t>> load_audio_file(
const std::string& path,
const tl::optional<int64_t>& frame_offset,
const tl::optional<int64_t>& num_frames,
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format);
void save_audio_file(const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample);
} // namespace sox_io
} // namespace paddleaudio
#endif

@ -0,0 +1,101 @@
from typing import Dict, List
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio import _paddleaudio
@_mod_utils.requires_sox()
def set_seed(seed: int):
"""Set libsox's PRNG
Args:
seed (int): seed value. valid range is int32.
See Also:
http://sox.sourceforge.net/sox.html
"""
_paddleaudio.sox_utils_set_seed(seed)
@_mod_utils.requires_sox()
def set_verbosity(verbosity: int):
"""Set libsox's verbosity
Args:
verbosity (int): Set verbosity level of libsox.
* ``1`` failure messages
* ``2`` warnings
* ``3`` details of processing
* ``4``-``6`` increasing levels of debug messages
See Also:
http://sox.sourceforge.net/sox.html
"""
_paddleaudio.sox_utils_set_verbosity(verbosity)
@_mod_utils.requires_sox()
def set_buffer_size(buffer_size: int):
"""Set buffer size for sox effect chain
Args:
buffer_size (int): Set the size in bytes of the buffers used for processing audio.
See Also:
http://sox.sourceforge.net/sox.html
"""
_paddleaudio.sox_utils_set_buffer_size(buffer_size)
@_mod_utils.requires_sox()
def set_use_threads(use_threads: bool):
"""Set multithread option for sox effect chain
Args:
use_threads (bool): When ``True``, enables ``libsox``'s parallel effects channels processing.
To use mutlithread, the underlying ``libsox`` has to be compiled with OpenMP support.
See Also:
http://sox.sourceforge.net/sox.html
"""
_paddleaudio.sox_utils_set_use_threads(use_threads)
@_mod_utils.requires_sox()
def list_effects() -> Dict[str, str]:
"""List the available sox effect names
Returns:
Dict[str, str]: Mapping from ``effect name`` to ``usage``
"""
return dict(_paddleaudio.sox_utils_list_effects())
@_mod_utils.requires_sox()
def list_read_formats() -> List[str]:
"""List the supported audio formats for read
Returns:
List[str]: List of supported audio formats
"""
return _paddleaudio.sox_utils_list_read_formats()
@_mod_utils.requires_sox()
def list_write_formats() -> List[str]:
"""List the supported audio formats for write
Returns:
List[str]: List of supported audio formats
"""
return _paddleaudio.sox_utils_list_write_formats()
@_mod_utils.requires_sox()
def get_buffer_size() -> int:
"""Get buffer size for sox effect chain
Returns:
int: size in bytes of buffers used for processing audio.
"""
return _paddleaudio.sox_utils_get_buffer_size()

@ -43,7 +43,7 @@ base = [
"pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2", "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
"sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10",
"textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad", "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
"yacs~=0.1.8", "prettytable", "zhon", "colorlog", "pathos == 0.2.8" "yacs~=0.1.8", "prettytable", "zhon", "colorlog", "pathos == 0.2.8", "Ninja"
] ]
server = [ server = [

@ -0,0 +1,78 @@
{"effects": [["allpass", "300", "10"]]}
{"effects": [["band", "300", "10"]]}
{"effects": [["bandpass", "300", "10"]]}
{"effects": [["bandreject", "300", "10"]]}
{"effects": [["bass", "-10"]]}
{"effects": [["biquad", "0.4", "0.2", "0.9", "0.7", "0.2", "0.6"]]}
{"effects": [["chorus", "0.7", "0.9", "55", "0.4", "0.25", "2", "-t"]]}
{"effects": [["chorus", "0.6", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "1.3", "-s"]]}
{"effects": [["chorus", "0.5", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "2.3", "-t", "40", "0.3", "0.3", "1.3", "-s"]]}
{"effects": [["channels", "1"]]}
{"effects": [["channels", "2"]]}
{"effects": [["channels", "3"]]}
{"effects": [["compand", "0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"]]}
{"effects": [["compand", ".1,.2", "-inf,-50.1,-inf,-50,-50", "0", "-90", ".1"]]}
{"effects": [["compand", ".1,.1", "-45.1,-45,-inf,0,-inf", "45", "-90", ".1"]]}
{"effects": [["contrast", "0"]]}
{"effects": [["contrast", "25"]]}
{"effects": [["contrast", "50"]]}
{"effects": [["contrast", "75"]]}
{"effects": [["contrast", "100"]]}
{"effects": [["dcshift", "1.0"]]}
{"effects": [["dcshift", "-1.0"]]}
{"effects": [["deemph"]], "input_sample_rate": 44100}
{"effects": [["dither", "-s"]]}
{"effects": [["dither", "-S"]]}
{"effects": [["divide"]]}
{"effects": [["downsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 4000}
{"effects": [["earwax"]], "input_sample_rate": 44100}
{"effects": [["echo", "0.8", "0.88", "60", "0.4"]]}
{"effects": [["echo", "0.8", "0.88", "6", "0.4"]]}
{"effects": [["echo", "0.8", "0.9", "1000", "0.3"]]}
{"effects": [["echo", "0.8", "0.9", "1000", "0.3", "1800", "0.25"]]}
{"effects": [["echos", "0.8", "0.7", "700", "0.25", "700", "0.3"]]}
{"effects": [["echos", "0.8", "0.7", "700", "0.25", "900", "0.3"]]}
{"effects": [["echos", "0.8", "0.7", "40", "0.25", "63", "0.3"]]}
{"effects": [["equalizer", "300", "10", "5"]]}
{"effects": [["fade", "q", "3"]]}
{"effects": [["fade", "h", "3"]]}
{"effects": [["fade", "t", "3"]]}
{"effects": [["fade", "l", "3"]]}
{"effects": [["fade", "p", "3"]]}
{"effects": [["fir", "0.0195", "-0.082", "0.234", "0.891", "-0.145", "0.043"]]}
{"effects": [["fir", "<ASSET_DIR>/sox_effect_test_fir_coeffs.txt"]]}
{"effects": [["flanger"]]}
{"effects": [["gain", "-l", "-6"]]}
{"effects": [["highpass", "-1", "300"]]}
{"effects": [["highpass", "-2", "300"]]}
{"effects": [["hilbert"]]}
{"effects": [["loudness"]]}
{"effects": [["lowpass", "-1", "300"]]}
{"effects": [["lowpass", "-2", "300"]]}
{"effects": [["mcompand", "0.005,0.1 -47,-40,-34,-34,-17,-33", "100", "0.003,0.05 -47,-40,-34,-34,-17,-33", "400", "0.000625,0.0125 -47,-40,-34,-34,-15,-33", "1600", "0.0001,0.025 -47,-40,-34,-34,-31,-31,-0,-30", "6400", "0,0.025 -38,-31,-28,-28,-0,-25"]], "input_sample_rate": 44100}
{"effects": [["oops"]]}
{"effects": [["overdrive"]]}
{"effects": [["pad"]]}
{"effects": [["phaser"]]}
{"effects": [["remix", "6", "7", "8", "0"]], "num_channels": 8}
{"effects": [["remix", "1-3,7", "3"]], "num_channels": 8}
{"effects": [["repeat"]]}
{"effects": [["reverb"]]}
{"effects": [["reverse"]]}
{"effects": [["riaa"]], "input_sample_rate": 44100}
{"effects": [["silence", "0"]]}
{"effects": [["speed", "1.3"]], "input_sample_rate": 4000, "output_sample_rate": 5200}
{"effects": [["speed", "0.7"]], "input_sample_rate": 4000, "output_sample_rate": 2800}
{"effects": [["stat"]]}
{"effects": [["stats"]]}
{"effects": [["stretch"]]}
{"effects": [["swap"]]}
{"effects": [["synth"]]}
{"effects": [["tempo", "0.9"]]}
{"effects": [["tempo", "1.1"]]}
{"effects": [["treble", "3"]]}
{"effects": [["tremolo", "300", "40"]]}
{"effects": [["tremolo", "300", "50"]]}
{"effects": [["trim", "0", "0.1"]]}
{"effects": [["upsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 16000}
{"effects": [["vol", "3"]]}

@ -0,0 +1,32 @@
def get_encoding(ext, dtype):
exts = {
"mp3",
"flac",
"vorbis",
}
encodings = {
"float32": "PCM_F",
"int32": "PCM_S",
"int16": "PCM_S",
"uint8": "PCM_U",
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bit_depth(dtype):
bit_depths = {
"float32": 32,
"int32": 32,
"int16": 16,
"uint8": 8,
}
return bit_depths[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
"flac": 24,
"mp3": 0,
"vorbis": 0,
}
return bits_per_samples.get(ext, get_bit_depth(dtype))

@ -0,0 +1,290 @@
import unittest
import itertools
import tarfile
from contextlib import contextmanager
import numpy as np
import paddle
import os
import io
from parameterized import parameterized
from paddlespeech.audio.backends import sox_io_backend
from tests.unit.common_utils import (
get_wav_data,
load_wav,
save_wav,
TempDirMixin,
sox_utils,
data_utils
)
from common import get_encoding, get_bits_per_sample
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/info_test.py
class TestInfo(TempDirMixin, unittest.TestCase):
@parameterized.expand(
list(
itertools.product(
["float32", "int32",],
[8000, 16000],
[1, 2],
)
),
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(
list(
itertools.product(
["float32", "int32"],
[8000, 16000],
[4, 8, 16, 32],
)
),
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
def test_ulaw(self):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="u-law", duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ULAW"
def test_alaw(self):
"""`sox_io_backend.info` can check alaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="a-law", duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ALAW"
#class TestInfoOpus(unittest.TestCase):
#@parameterized.expand(
#list(
#itertools.product(
#["96k"],
#[1, 2],
#[0, 5, 10],
#)
#),
#)
#def test_opus(self, bitrate, num_channels, compression_level):
#"""`sox_io_backend.info` can check opus file correcty"""
#path = data_utils.get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
#info = sox_io_backend.info(path)
#assert info.sample_rate == 48000
#assert info.num_frames == 32768
#assert info.num_channels == num_channels
#assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
#assert info.encoding == "OPUS"
class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f"test.{ext}")
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file(
path,
sample_rate,
num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration,
comment_file=comment_file,
)
return path
def _gen_comment_file(self, comments):
comment_path = self.get_temp_path("comment.txt")
with open(comment_path, "w") as file_:
file_.writelines(comments)
return comment_path
class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj
def read(self, n):
return self.fileobj.read(n)
class TestFileObject(FileObjTestBase, unittest.TestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as fileobj:
return sox_io_backend.info(fileobj, format_)
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_)
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path("archive.tar.gz")
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ["mp3"] else None
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)
@contextmanager
def _set_buffer_size(self, buffer_size):
try:
original_buffer_size = get_buffer_size()
set_buffer_size(buffer_size)
yield
finally:
set_buffer_size(original_buffer_size)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
]
)
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
]
)
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
]
)
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 8000
num_frames = 4
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,47 @@
import unittest
import itertools
from parameterized import parameterized
import numpy as np
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio.backends import sox_io_backend
from tests.unit.common_utils import (
get_wav_data,
load_wav,
save_wav,
)
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/load_test.py
class TestLoad(unittest.TestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Wav data loaded with sox_io backend should match those with scipy
"""
path = 'testdata/reference.wav'
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
data, sr = sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
np.testing.assert_array_almost_equal(data, expected, decimal=4)
@parameterized.expand(
list(
itertools.product(
["float64", "float32", "int32",],
[8000, 16000],
[1, 2],
[False, True],
)
),
)
def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,175 @@
import io
import os
import unittest
import numpy as np
import paddle
from parameterized import parameterized
from paddlespeech.audio.backends import sox_io_backend
from tests.unit.common_utils import (
get_wav_data,
load_wav,
save_wav,
nested_params,
TempDirMixin,
sox_utils
)
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/save_test.py
def _get_sox_encoding(encoding):
encodings = {
"PCM_F": "floating-point",
"PCM_S": "signed-integer",
"PCM_U": "unsigned-integer",
"ULAW": "u-law",
"ALAW": "a-law",
}
return encodings.get(encoding)
class TestSaveBase(TempDirMixin):
def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = "int32",
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `sox` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `sox` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `sox` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `sox` preserve the data well.
x
| 1. Generate source wav file with SciPy
|
v
-------------- wav ----------------
| |
| 2.1. load with scipy | 3.1. Convert to the target
| then save it into the target | format depth with sox
| format with paddleaudio |
v v
target format target format
| |
| 2.2. Convert to wav with sox | 3.2. Convert to wav with sox
| |
v v
wav wav
| |
| 2.3. load with scipy | 3.3. load with scipy
| |
v v
tensor -------> compare <--------- tensor
"""
cmp_encoding = "floating-point"
cmp_bit_depth = 32
src_path = self.get_temp_path("1.source.wav")
tgt_path = self.get_temp_path(f"2.1.paddleaudio.{format}")
tst_path = self.get_temp_path("2.2.result.wav")
sox_path = self.get_temp_path(f"3.1.sox.{format}")
ref_path = self.get_temp_path("3.2.ref.wav")
# 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to target format with paddleaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
sox_io_backend.save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
sox_io_backend.save(
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
sox_io_backend.save(
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0)
with open(tgt_path, "bw") as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]
np.testing.assert_array_almost_equal(found, expected)
class TestSave(TestSaveBase, unittest.TestCase):
@nested_params(
["path",],
[
("PCM_U", 8),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", 8),
("ALAW", 8),
],
)
def test_save_wav(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", ],
[
("float32",),
("int32",),
],
)
def test_save_wav_dtype(self, test_mode, params):
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,183 @@
import io
import itertools
import unittest
from parameterized import parameterized
from paddlespeech.audio.backends import sox_io_backend
from tests.unit.common_utils import (
get_wav_data,
TempDirMixin,
name_func
)
class SmokeTest(TempDirMixin, unittest.TestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
duration = 1
num_frames = sample_rate * duration
#path = self.get_temp_path(f"test.{ext}")
path = self.get_temp_path(f"test.{ext}")
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
# 1. run save
sox_io_backend.save(path, original, sample_rate, compression=compression)
# 2. run info
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
loaded, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(
list(
itertools.product(
["float32", "int32" ],
#["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
#@parameterized.expand(
#list(
#itertools.product(
#[8000, 16000],
#[1, 2],
#[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
#)
#)
#)
#def test_mp3(self, sample_rate, num_channels, bit_rate):
#"""Run smoke test on mp3 format"""
#self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
#@parameterized.expand(
#list(
#itertools.product(
#[8000, 16000],
#[1, 2],
#[-1, 0, 1, 2, 3, 3.6, 5, 10],
#)
#)
#)
#def test_vorbis(self, sample_rate, num_channels, quality_level):
#"""Run smoke test on vorbis format"""
#self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
class SmokeTestFileObj(unittest.TestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
duration = 1
num_frames = sample_rate * duration
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
fileobj = io.BytesIO()
# 1. run save
sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext)
# 2. run info
fileobj.seek(0)
info = sox_io_backend.info(fileobj, format=ext)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
fileobj.seek(0)
loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(
list(
itertools.product(
["float32", "int32"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
# not support yet
#@parameterized.expand(
#list(
#itertools.product(
#[8000, 16000],
#[1, 2],
#[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
#)
#)
#)
#def test_mp3(self, sample_rate, num_channels, bit_rate):
#"""Run smoke test on mp3 format"""
#self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
#@parameterized.expand(
#list(
#itertools.product(
#[8000, 16000],
#[1, 2],
#[-1, 0, 1, 2, 3, 3.6, 5, 10],
#)
#)
#)
#def test_vorbis(self, sample_rate, num_channels, quality_level):
#"""Run smoke test on vorbis format"""
#self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
#"""Run smoke test on flac format"""
self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
if __name__ == '__main__':
#test_func()
unittest.main()

@ -0,0 +1,347 @@
#code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/sox_effect/sox_effect_test.py
import io
import itertools
import tarfile
import unittest
from pathlib import Path
import numpy as np
from parameterized import parameterized
from paddlespeech.audio import sox_effects
from paddlespeech.audio._internal import module_utils as _mod_utils
from tests.unit.common_utils import (
get_sinusoid,
get_wav_data,
load_wav,
save_wav,
sox_utils,
TempDirMixin,
name_func,
load_effects_params
)
if _mod_utils.is_module_available("requests"):
import requests
class TestSoxEffects(unittest.TestCase):
def test_init(self):
"""Calling init_sox_effects multiple times does not crush"""
for _ in range(3):
sox_effects.init_sox_effects()
class TestSoxEffectsTensor(TempDirMixin, unittest.TestCase):
"""Test suite for `apply_effects_tensor` function"""
@parameterized.expand(
list(itertools.product(["float32", "int32"], [8000, 16000], [1, 2, 4, 8], [True, False])),
)
def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first):
"""`apply_effects_tensor` without effects should return identical data as input"""
original = get_wav_data(dtype, num_channels, channels_first=channels_first)
expected = original.clone()
found, output_sample_rate = sox_effects.apply_effects_tensor(expected, sample_rate, [], channels_first)
assert (output_sample_rate == sample_rate)
# SoxEffect should not alter the input Tensor object
#self.assertEqual(original, expected)
np.testing.assert_array_almost_equal(original.numpy(), expected.numpy())
# SoxEffect should not return the same Tensor object
assert expected is not found
# Returned Tensor should equal to the input Tensor
#self.assertEqual(expected, found)
np.testing.assert_array_almost_equal(expected.numpy(), found.numpy())
@parameterized.expand(
load_effects_params("sox_effect_test_args.jsonl"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects(self, args):
"""`apply_effects_tensor` should return identical data as sox command"""
effects = args["effects"]
num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000)
output_sr = args.get("output_sample_rate")
input_path = self.get_temp_path("input.wav")
reference_path = self.get_temp_path("reference.wav")
original = get_sinusoid(frequency=800, sample_rate=input_sr, n_channels=num_channels, dtype="float32")
save_wav(input_path, original, input_sr)
sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr)
expected, expected_sr = load_wav(reference_path)
found, sr = sox_effects.apply_effects_tensor(original, input_sr, effects)
assert sr == expected_sr
#self.assertEqual(expected, found)
np.testing.assert_array_almost_equal(expected.numpy(), found.numpy())
class TestSoxEffectsFile(TempDirMixin, unittest.TestCase):
"""Test suite for `apply_effects_file` function"""
@parameterized.expand(
list(
itertools.product(
["float32", "int32"],
[8000, 16000],
[1, 2, 4, 8],
[False, True],
)
),
#name_func=name_func,
)
def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first):
"""`apply_effects_file` without effects should return identical data as input"""
path = self.get_temp_path("input.wav")
expected = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(path, expected, sample_rate, channels_first=channels_first)
found, output_sample_rate = sox_effects.apply_effects_file(
path, [], normalize=False, channels_first=channels_first
)
assert output_sample_rate == sample_rate
#self.assertEqual(expected, found)
np.testing.assert_array_almost_equal(expected.numpy(), found.numpy())
@parameterized.expand(
load_effects_params("sox_effect_test_args.jsonl"),
#name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects_str(self, args):
"""`apply_effects_file` should return identical data as sox command"""
dtype = "int32"
channels_first = True
effects = args["effects"]
num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000)
output_sr = args.get("output_sample_rate")
input_path = self.get_temp_path("input.wav")
reference_path = self.get_temp_path("reference.wav")
data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, input_sr, channels_first=channels_first)
sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr)
expected, expected_sr = load_wav(reference_path)
found, sr = sox_effects.apply_effects_file(input_path, effects, normalize=False, channels_first=channels_first)
assert sr == expected_sr
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(expected.numpy(), found.numpy())
def test_apply_effects_path(self):
"""`apply_effects_file` should return identical data as sox command when file path is given as a Path Object"""
dtype = "int32"
channels_first = True
effects = [["hilbert"]]
num_channels = 2
input_sr = 8000
output_sr = 8000
input_path = self.get_temp_path("input.wav")
reference_path = self.get_temp_path("reference.wav")
data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, input_sr, channels_first=channels_first)
sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr)
expected, expected_sr = load_wav(reference_path)
found, sr = sox_effects.apply_effects_file(
Path(input_path), effects, normalize=False, channels_first=channels_first
)
assert sr == expected_sr
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(expected.numpy(), found.numpy())
class TestFileFormats(TempDirMixin, unittest.TestCase):
"""`apply_effects_file` gives the same result as sox on various file formats"""
@parameterized.expand(
list(
itertools.product(
["float32", "int32"],
[8000, 16000],
[1, 2],
)
),
#name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}',
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`apply_effects_file` works on various wav format"""
channels_first = True
effects = [["band", "300", "10"]]
input_path = self.get_temp_path("input.wav")
reference_path = self.get_temp_path("reference.wav")
data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, sample_rate, channels_first=channels_first)
sox_utils.run_sox_effect(input_path, reference_path, effects)
expected, expected_sr = load_wav(reference_path)
found, sr = sox_effects.apply_effects_file(input_path, effects, normalize=False, channels_first=channels_first)
assert sr == expected_sr
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
#not support now
#@parameterized.expand(
#list(
#itertools.product(
#[8000, 16000],
#[1, 2],
#)
#),
##name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}',
#)
#def test_flac(self, sample_rate, num_channels):
#"""`apply_effects_file` works on various flac format"""
#channels_first = True
#effects = [["band", "300", "10"]]
#input_path = self.get_temp_path("input.flac")
#reference_path = self.get_temp_path("reference.wav")
#sox_utils.gen_audio_file(input_path, sample_rate, num_channels)
#sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
#expected, expected_sr = load_wav(reference_path)
#found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first)
#save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
#assert sr == expected_sr
##self.assertEqual(found, expected)
#np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
#@parameterized.expand(
#list(
#itertools.product(
#[8000, 16000],
#[1, 2],
#)
#),
##name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}',
#)
#def test_vorbis(self, sample_rate, num_channels):
#"""`apply_effects_file` works on various vorbis format"""
#channels_first = True
#effects = [["band", "300", "10"]]
#input_path = self.get_temp_path("input.vorbis")
#reference_path = self.get_temp_path("reference.wav")
#sox_utils.gen_audio_file(input_path, sample_rate, num_channels)
#sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
#expected, expected_sr = load_wav(reference_path)
#found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first)
#save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
#assert sr == expected_sr
##self.assertEqual(found, expected)
#np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
#@skipIfNoExec("sox")
#@skipIfNoSox
class TestFileObject(TempDirMixin, unittest.TestCase):
@parameterized.expand(
[
("wav", None),
]
)
def test_fileobj(self, ext, compression):
"""Applying effects via file object works"""
sample_rate = 16000
channels_first = True
effects = [["band", "300", "10"]]
input_path = self.get_temp_path(f"input.{ext}")
reference_path = self.get_temp_path("reference.wav")
#sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
data = get_wav_data("int32", 2, channels_first=channels_first)
save_wav(input_path, data, sample_rate, channels_first=channels_first)
sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)
with open(input_path, "rb") as fileobj:
found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first)
save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
assert sr == expected_sr
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
@parameterized.expand(
[
("wav", None),
]
)
def test_bytesio(self, ext, compression):
"""Applying effects via BytesIO object works"""
sample_rate = 16000
channels_first = True
effects = [["band", "300", "10"]]
input_path = self.get_temp_path(f"input.{ext}")
reference_path = self.get_temp_path("reference.wav")
#sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
data = get_wav_data("int32", 2, channels_first=channels_first)
save_wav(input_path, data, sample_rate, channels_first=channels_first)
sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)
with open(input_path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first)
save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
assert sr == expected_sr
#self.assertEqual(found, expected)
print("found")
print(found)
print("expected")
print(expected)
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
@parameterized.expand(
[
("wav", None),
]
)
def test_tarfile(self, ext, compression):
"""Applying effects to compressed audio via file-like file works"""
sample_rate = 16000
channels_first = True
effects = [["band", "300", "10"]]
audio_file = f"input.{ext}"
input_path = self.get_temp_path(audio_file)
reference_path = self.get_temp_path("reference.wav")
archive_path = self.get_temp_path("archive.tar.gz")
data = get_wav_data("int32", 2, channels_first=channels_first)
save_wav(input_path, data, sample_rate, channels_first=channels_first)
# sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(input_path, arcname=audio_file)
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first)
save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
assert sr == expected_sr
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(found.numpy(), expected.numpy())
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,19 @@
from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav
from .parameterized_utils import nested_params
from .data_utils import get_sinusoid, load_params, load_effects_params
from .case_utils import (
TempDirMixin,
name_func
)
__all__ = [
"get_wav_data",
"load_wav",
"save_wav",
"normalize_wav",
"load_params",
"nested_params",
"get_sinusoid",
"name_func",
"load_effects_params"
]

@ -0,0 +1,61 @@
import functools
import os.path
import shutil
import subprocess
import sys
import tempfile
import time
import unittest
#code is from:https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/common_utils/case_utils.py
import paddle
from paddlespeech.audio._internal.module_utils import (
is_kaldi_available,
is_module_available,
is_sox_available,
)
def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
@classmethod
def get_base_temp_dir(cls):
# If PADDLEAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = "PADDLEAUDIO_TEST_TEMP_DIR"
if key in os.environ:
return os.environ[key]
if cls.temp_dir_ is None:
cls.temp_dir_ = tempfile.TemporaryDirectory()
return cls.temp_dir_.name
@classmethod
def tearDownClass(cls):
if cls.temp_dir_ is not None:
try:
cls.temp_dir_.cleanup()
cls.temp_dir_ = None
except PermissionError:
# On Windows there is a know issue with `shutil.rmtree`,
# which fails intermittenly.
#
# https://github.com/python/cpython/issues/74168
#
# We observed this on CircleCI, where Windows job raises
# PermissionError.
#
# Following the above thread, we ignore it.
pass
super().tearDownClass()
def get_temp_path(self, *paths):
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
path = os.path.join(temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path

@ -0,0 +1,63 @@
import json
from itertools import product
import os
from parameterized import param, parameterized
#def get_asset_path(*paths):
#"""Return full path of a test asset"""
#return os.path.join(_TEST_DIR_PATH, "assets", *paths)
#def load_params(*paths):
#with open(get_asset_path(*paths), "r") as file:
#return [param(json.loads(line)) for line in file]
#def load_effects_params(*paths):
#params = []
#with open(get_asset_path(*paths), "r") as file:
#for line in file:
#data = json.loads(line)
#for effect in data["effects"]:
#for i, arg in enumerate(effect):
#if arg.startswith("<ASSET_DIR>"):
#effect[i] = arg.replace("<ASSET_DIR>", get_asset_path())
#params.append(param(data))
#return params
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
# sanitize the test name
name = "_".join(strs)
return parameterized.to_safe_name(f"{func.__name__}_{name}")
def nested_params(*params_set, name_func=_name_func):
"""Generate the cartesian product of the given list of parameters.
Args:
params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
all the parameters have to be specified with the class, only using kwargs.
"""
flatten = [p for params in params_set for p in params]
# Parameters to be nested are given as list of plain objects
if all(not isinstance(p, param) for p in flatten):
args = list(product(*params_set))
return parameterized.expand(args, name_func=_name_func)
# Parameters to be nested are given as list of `parameterized.param`
if not all(isinstance(p, param) for p in flatten):
raise TypeError("When using ``parameterized.param``, " "all the parameters have to be of the ``param`` type.")
if any(p.args for p in flatten):
raise ValueError(
"When using ``parameterized.param``, " "all the parameters have to be provided as keyword argument."
)
args = [param()]
for params in params_set:
args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
return parameterized.expand(args)

@ -0,0 +1,116 @@
import subprocess
import sys
import warnings
def get_encoding(dtype):
encodings = {
"float32": "floating-point",
"int32": "signed-integer",
"int16": "signed-integer",
"uint8": "unsigned-integer",
}
return encodings[dtype]
def get_bit_depth(dtype):
bit_depths = {
"float32": 32,
"int32": 32,
"int16": 16,
"uint8": 8,
}
return bit_depths[dtype]
def gen_audio_file(
path,
sample_rate,
num_channels,
*,
encoding=None,
bit_depth=None,
compression=None,
attenuation=None,
duration=1,
comment_file=None,
):
"""Generate synthetic audio file with `sox` command."""
if path.endswith(".wav"):
warnings.warn("Use get_wav_data and save_wav to generate wav file for accurate result.")
command = [
"sox",
"-V3", # verbose
"--no-dither", # disable automatic dithering
"-R",
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
# search "sox_globals.repeatable"
]
if bit_depth is not None:
command += ["--bits", str(bit_depth)]
command += [
"--rate",
str(sample_rate),
"--null", # no input
"--channels",
str(num_channels),
]
if compression is not None:
command += ["--compression", str(compression)]
if bit_depth is not None:
command += ["--bits", str(bit_depth)]
if encoding is not None:
command += ["--encoding", str(encoding)]
if comment_file is not None:
command += ["--comment-file", str(comment_file)]
command += [
str(path),
"synth",
str(duration), # synthesizes for the given duration [sec]
"sawtooth",
"1",
# saw tooth covers the both ends of value range, which is a good property for test.
# similar to linspace(-1., 1.)
# this introduces bigger boundary effect than sine when converted to mp3
]
if attenuation is not None:
command += ["vol", f"-{attenuation}dB"]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
def convert_audio_file(src_path, dst_path, *, encoding=None, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ["sox", "-V3", "--no-dither", "-R", str(src_path)]
if encoding is not None:
command += ["--encoding", str(encoding)]
if bit_depth is not None:
command += ["--bits", str(bit_depth)]
if compression is not None:
command += ["--compression", str(compression)]
command += [dst_path]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
def _flattern(effects):
if not effects:
return effects
if isinstance(effects[0], str):
return effects
return [item for sublist in effects for item in sublist]
def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None):
"""Run sox effects"""
effect = _flattern(effect)
command = ["sox", "-V", "--no-dither", input_file]
if output_bitdepth:
command += ["--bits", str(output_bitdepth)]
command += [output_file] + effect
if output_sample_rate:
command += ["rate", str(output_sample_rate)]
print(" ".join(command))
subprocess.run(command, check=True)

@ -0,0 +1,102 @@
from typing import Optional
import scipy.io.wavfile
import paddle
import numpy as np
def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:
if tensor.dtype == paddle.float32:
pass
elif tensor.dtype == paddle.int32:
tensor = paddle.cast(tensor, paddle.float32)
tensor[tensor > 0] /= 2147483647.0
tensor[tensor < 0] /= 2147483648.0
elif tensor.dtype == paddle.int16:
tensor = paddle.cast(tensor, paddle.float32)
tensor[tensor > 0] /= 32767.0
tensor[tensor < 0] /= 32768.0
elif tensor.dtype == paddle.uint8:
tensor = paddle.cast(tensor, paddle.float32) - 128
tensor[tensor > 0] /= 127.0
tensor[tensor < 0] /= 128.0
return tensor
def get_wav_data(
dtype: str,
num_channels: int,
*,
num_frames: Optional[int] = None,
normalize: bool = True,
channels_first: bool = True,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483648, 2147483647] for int32
[-32768, 32767] for int16
[0, 255] for uint8
num_frames allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_ = getattr(paddle, dtype)
if num_frames is None:
if dtype == "uint8":
num_frames = 256
else:
num_frames = 1 << 16
# paddle linspace not support uint8, int8, int16
#if dtype == "uint8":
# base = paddle.linspace(0, 255, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(0, 255, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#elif dtype == "int8":
# base = paddle.linspace(-128, 127, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-128, 127, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
if dtype == "float32":
base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
elif dtype == "float64":
base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
elif dtype == "int32":
base = paddle.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
#elif dtype == "int16":
# base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-32768, 32767, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
data = base.tile([num_channels, 1])
if not channels_first:
data = data.transpose([1, 0])
if normalize:
data = normalize_wav(data)
return data
def load_wav(path: str, normalize=True, channels_first=True) -> paddle.Tensor:
"""Load wav file without paddleaudio"""
sample_rate, data = scipy.io.wavfile.read(path)
data = paddle.to_tensor(data.copy())
if data.ndim == 1:
data = data.unsqueeze(1)
if normalize:
data = normalize_wav(data)
if channels_first:
data = data.transpose([1, 0])
return data, sample_rate
def save_wav(path, data, sample_rate, channels_first=True):
"""Save wav file without paddleaudio"""
if channels_first:
data = data.transpose([1, 0])
scipy.io.wavfile.write(path, sample_rate, data.numpy())
Loading…
Cancel
Save