diff --git a/paddlespeech/audio/src/CMakeLists.txt b/paddlespeech/audio/src/CMakeLists.txt index 7448225ef..4c46fbe24 100644 --- a/paddlespeech/audio/src/CMakeLists.txt +++ b/paddlespeech/audio/src/CMakeLists.txt @@ -35,11 +35,11 @@ if(BUILD_SOX) list( APPEND LIBPADDLEAUDIO_SOURCES - sox/io.cpp - sox/utils.cpp - sox/effects.cpp - sox/effects_chain.cpp - sox/types.cpp + #sox/io.cpp + #sox/utils.cpp + #sox/effects.cpp + #sox/effects_chain.cpp + #sox/types.cpp ) list( APPEND @@ -147,6 +147,7 @@ if(BUILD_SOX) pybind/sox/effects.cpp pybind/sox/effects_chain.cpp pybind/sox/io.cpp + pybind/sox/types.cpp pybind/sox/utils.cpp ) endif() diff --git a/paddlespeech/audio/src/pybind/sox/effects.cpp b/paddlespeech/audio/src/pybind/sox/effects.cpp index 96907a670..b69c5358a 100644 --- a/paddlespeech/audio/src/pybind/sox/effects.cpp +++ b/paddlespeech/audio/src/pybind/sox/effects.cpp @@ -1,3 +1,6 @@ +#include +#include + #include "paddlespeech/audio/src/pybind/sox/effects.h" #include "paddlespeech/audio/src/pybind/sox/effects_chain.h" #include "paddlespeech/audio/src/pybind/sox/utils.h" @@ -118,4 +121,137 @@ auto apply_effects_fileobj( tensor, static_cast(chain.getOutputSampleRate())); } +namespace { + +enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown }; +SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized; +std::mutex SOX_RESOUCE_STATE_MUTEX; + +} // namespace + +void initialize_sox_effects() { + const std::lock_guard lock(SOX_RESOUCE_STATE_MUTEX); + + switch (SOX_RESOURCE_STATE) { + case NotInitialized: + if (sox_init() != SOX_SUCCESS) { + throw std::runtime_error("Failed to initialize sox effects."); + }; + SOX_RESOURCE_STATE = Initialized; + break; + case Initialized: + break; + case ShutDown: + throw std::runtime_error( + "SoX Effects has been shut down. Cannot initialize again."); + } +}; + +void shutdown_sox_effects() { + const std::lock_guard lock(SOX_RESOUCE_STATE_MUTEX); + + switch (SOX_RESOURCE_STATE) { + case NotInitialized: + throw std::runtime_error( + "SoX Effects is not initialized. Cannot shutdown."); + case Initialized: + if (sox_quit() != SOX_SUCCESS) { + throw std::runtime_error("Failed to initialize sox effects."); + }; + SOX_RESOURCE_STATE = ShutDown; + break; + case ShutDown: + break; + } +} + +auto apply_effects_tensor( + py::array waveform, + int64_t sample_rate, + const std::vector>& effects, + bool channels_first) -> std::tuple { + validate_input_tensor(waveform); + + // Create SoxEffectsChain + const auto dtype = waveform.dtype(); + paddleaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/get_tensor_encodinginfo(dtype), + /*output_encoding=*/get_tensor_encodinginfo(dtype)); + + // Prepare output buffer + std::vector out_buffer; + out_buffer.reserve(waveform.size()); + + // Build and run effects chain + chain.addInputTensor(&waveform, sample_rate, channels_first); + for (const auto& effect : effects) { + chain.addEffect(effect); + } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + auto out_tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + /*normalize=*/false, + channels_first); + + return std::tuple( + out_tensor, chain.getOutputSampleRate()); +} + +auto apply_effects_file( + const std::string& path, + const std::vector>& effects, + tl::optional normalize, + tl::optional channels_first, + const tl::optional& format) + -> tl::optional> { + // Open input file + SoxFormat sf(sox_open_read( + path.c_str(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + + if (static_cast(sf) == nullptr || + sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + return {}; + } + + const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); + + // Prepare output + std::vector out_buffer; + out_buffer.reserve(sf->signal.length); + + // Create and run SoxEffectsChain + paddleaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/sf->encoding, + /*output_encoding=*/get_tensor_encodinginfo(dtype)); + + chain.addInputFile(sf); + for (const auto& effect : effects) { + chain.addEffect(effect); + } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + bool channels_first_ = channels_first.value_or(true); + auto tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + normalize.value_or(true), + channels_first_); + + return std::tuple( + tensor, chain.getOutputSampleRate()); +} + } // namespace paddleaudio::sox_effects diff --git a/paddlespeech/audio/src/pybind/sox/effects.h b/paddlespeech/audio/src/pybind/sox/effects.h index 5e67cb011..6ba53d008 100644 --- a/paddlespeech/audio/src/pybind/sox/effects.h +++ b/paddlespeech/audio/src/pybind/sox/effects.h @@ -15,4 +15,22 @@ auto apply_effects_fileobj( tl::optional format) -> tl::optional>; +void initialize_sox_effects(); + +void shutdown_sox_effects(); + +auto apply_effects_tensor( + py::array waveform, + int64_t sample_rate, + const std::vector>& effects, + bool channels_first) -> std::tuple; + +auto apply_effects_file( + const std::string& path, + const std::vector>& effects, + tl::optional normalize, + tl::optional channels_first, + const tl::optional& format) + -> tl::optional>; + } // namespace paddleaudio::sox_effects diff --git a/paddlespeech/audio/src/pybind/sox/effects_chain.cpp b/paddlespeech/audio/src/pybind/sox/effects_chain.cpp index a106209d6..4ad90da36 100644 --- a/paddlespeech/audio/src/pybind/sox/effects_chain.cpp +++ b/paddlespeech/audio/src/pybind/sox/effects_chain.cpp @@ -9,6 +9,336 @@ namespace paddleaudio::sox_effects_chain { namespace { +/// helper classes for passing the location of input tensor and output buffer +/// +/// drain/flow callback functions require plaing C style function signature and +/// the way to pass extra data is to attach data to sox_effect_t::priv pointer. +/// The following structs will be assigned to sox_effect_t::priv pointer which +/// gives sox_effect_t an access to input Tensor and output buffer object. +struct TensorInputPriv { + size_t index; + py::array* waveform; + int64_t sample_rate; + bool channels_first; +}; + +struct TensorOutputPriv { + std::vector* buffer; +}; +struct FileOutputPriv { + sox_format_t* sf; +}; + +/// Callback function to feed Tensor data to SoxEffectChain. +int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { + // Retrieve the input Tensor and current index + auto priv = static_cast(effp->priv); + auto index = priv->index; + auto tensor = *(priv->waveform); + auto num_channels = effp->out_signal.channels; + + // Adjust the number of samples to read + const size_t num_samples = tensor.size(); + if (index + *osamp > num_samples) { + *osamp = num_samples - index; + } + // Ensure that it's a multiple of the number of channels + *osamp -= *osamp % num_channels; + + // Slice the input Tensor + // refacor this module, chunk + auto i_frame = index / num_channels; + auto num_frames = *osamp / num_channels; + py::array chunk(tensor.dtype(), {num_frames*num_channels}); + py::buffer_info ori_info = tensor.request(); + py::buffer_info info = chunk.request(); + char* ori_start_ptr = (char*)ori_info.ptr + index * chunk.itemsize() / sizeof(char); + std::memcpy(info.ptr, ori_start_ptr, chunk.nbytes()); + + py::dtype chunk_type = py::dtype("i"); // dtype int32 + py::array new_chunk = py::array(chunk_type, chunk.shape()); + py::buffer_info new_info = new_chunk.request(); + void* ptr = (void*) info.ptr; + int* new_ptr = (int*) new_info.ptr; + // Convert to sox_sample_t (int32_t) + switch (chunk.dtype().num()) { + //case c10::ScalarType::Float: { + case 11: { + // Need to convert to 64-bit precision so that + // values around INT32_MIN/MAX are handled correctly. + float* ptr_f = (float*)ptr; + for (int idx = 0; idx < chunk.size(); ++idx) { + double elem = *ptr_f * 2147483648.; + // *new_ptr = std::clamp(elem, INT32_MIN, INT32_MAX); + if (elem > INT32_MAX) { + *new_ptr = INT32_MAX; + } else if (elem < INT32_MIN) { + *new_ptr = INT32_MIN; + } else { *new_ptr = elem; } + } + break; + } + //case c10::ScalarType::Int: { + case 5: { + break; + } + // case short + case 3: { + int16_t* ptr_s = (int16_t*) ptr; + for (int idx = 0; idx < chunk.size(); ++idx) { + *new_ptr = *ptr_s * 65536; + } + break; + } + // case byte + case 1: { + int8_t* ptr_b = (int8_t*) ptr; + for (int idx = 0; idx < chunk.size(); ++idx) { + *new_ptr = (*ptr_b - 128) * 16777216; + } + break; + } + default: + throw std::runtime_error("Unexpected dtype."); + } + // Write to buffer + memcpy(obuf, (int*)new_info.ptr, *osamp * 4); + priv->index += *osamp; + return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; +} + +/// Callback function to fetch data from SoxEffectChain. +int tensor_output_flow( + sox_effect_t* effp, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) { + *osamp = 0; + // Get output buffer + auto out_buffer = static_cast(effp->priv)->buffer; + // Append at the end + out_buffer->insert(out_buffer->end(), ibuf, ibuf + *isamp); + return SOX_SUCCESS; +} + +int file_output_flow( + sox_effect_t* effp, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) { + *osamp = 0; + if (*isamp) { + auto sf = static_cast(effp->priv)->sf; + if (sox_write(sf, ibuf, *isamp) != *isamp) { + if (sf->sox_errno) { + std::ostringstream stream; + stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " " + << sf->filename; + throw std::runtime_error(stream.str()); + } + return SOX_EOF; + } + } + return SOX_SUCCESS; +} + +sox_effect_handler_t* get_tensor_input_handler() { + static sox_effect_handler_t handler{ + /*name=*/"input_tensor", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/NULL, + /*drain=*/tensor_input_drain, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(TensorInputPriv)}; + return &handler; +} + +sox_effect_handler_t* get_tensor_output_handler() { + static sox_effect_handler_t handler{ + /*name=*/"output_tensor", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/tensor_output_flow, + /*drain=*/NULL, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(TensorOutputPriv)}; + return &handler; +} + +sox_effect_handler_t* get_file_output_handler() { + static sox_effect_handler_t handler{ + /*name=*/"output_file", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/file_output_flow, + /*drain=*/NULL, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(FileOutputPriv)}; + return &handler; +} + +} // namespace + +SoxEffect::SoxEffect(sox_effect_t* se) noexcept : se_(se) {} + +SoxEffect::~SoxEffect() { + if (se_ != nullptr) { + free(se_); + } +} + +SoxEffect::operator sox_effect_t*() const { + return se_; +} + +auto SoxEffect::operator->() noexcept -> sox_effect_t* { + return se_; +} + +SoxEffectsChain::SoxEffectsChain( + sox_encodinginfo_t input_encoding, + sox_encodinginfo_t output_encoding) + : in_enc_(input_encoding), + out_enc_(output_encoding), + in_sig_(), + interm_sig_(), + out_sig_(), + sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) { + if (!sec_) { + throw std::runtime_error("Failed to create effect chain."); + } +} + +SoxEffectsChain::~SoxEffectsChain() { + if (sec_ != nullptr) { + sox_delete_effects_chain(sec_); + } +} + +void SoxEffectsChain::run() { + sox_flow_effects(sec_, NULL, NULL); +} + +void SoxEffectsChain::addInputTensor( + py::array* waveform, + int64_t sample_rate, + bool channels_first) { + in_sig_ = get_signalinfo(waveform, sample_rate, "wav", channels_first); + interm_sig_ = in_sig_; + SoxEffect e(sox_create_effect(get_tensor_input_handler())); + auto priv = static_cast(e->priv); + priv->index = 0; + priv->waveform = waveform; + priv->sample_rate = sample_rate; + priv->channels_first = channels_first; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error( + "Internal Error: Failed to add effect: input_tensor"); + } +} + +void SoxEffectsChain::addOutputBuffer( + std::vector* output_buffer) { + SoxEffect e(sox_create_effect(get_tensor_output_handler())); + static_cast(e->priv)->buffer = output_buffer; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error( + "Internal Error: Failed to add effect: output_tensor"); + } +} + +void SoxEffectsChain::addInputFile(sox_format_t* sf) { + in_sig_ = sf->signal; + interm_sig_ = in_sig_; + SoxEffect e(sox_create_effect(sox_find_effect("input"))); + char* opts[] = {(char*)sf}; + sox_effect_options(e, 1, opts); + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Internal Error: Failed to add effect: input " << sf->filename; + throw std::runtime_error(stream.str()); + } +} + +void SoxEffectsChain::addOutputFile(sox_format_t* sf) { + out_sig_ = sf->signal; + SoxEffect e(sox_create_effect(get_file_output_handler())); + static_cast(e->priv)->sf = sf; + if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Internal Error: Failed to add effect: output " << sf->filename; + throw std::runtime_error(stream.str()); + } +} + +void SoxEffectsChain::addEffect(const std::vector effect) { + const auto num_args = effect.size(); + if (num_args == 0) { + throw std::runtime_error("Invalid argument: empty effect."); + } + const auto name = effect[0]; + if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) { + std::ostringstream stream; + stream << "Unsupported effect: " << name; + throw std::runtime_error(stream.str()); + } + + auto returned_effect = sox_find_effect(name.c_str()); + if (!returned_effect) { + std::ostringstream stream; + stream << "Unsupported effect: " << name; + throw std::runtime_error(stream.str()); + } + SoxEffect e(sox_create_effect(returned_effect)); + const auto num_options = num_args - 1; + + std::vector opts; + for (size_t i = 1; i < num_args; ++i) { + opts.push_back((char*)effect[i].c_str()); + } + if (sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) != + SOX_SUCCESS) { + std::ostringstream stream; + stream << "Invalid effect option:"; + for (const auto& v : effect) { + stream << " " << v; + } + throw std::runtime_error(stream.str()); + } + + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Internal Error: Failed to add effect: \"" << name; + for (size_t i = 1; i < num_args; ++i) { + stream << " " << effect[i]; + } + stream << "\""; + throw std::runtime_error(stream.str()); + } +} + +int64_t SoxEffectsChain::getOutputNumChannels() { + return interm_sig_.channels; +} + +int64_t SoxEffectsChain::getOutputSampleRate() { + return interm_sig_.rate; +} + +namespace { + /// helper classes for passing file-like object to SoxEffectChain struct FileObjInputPriv { sox_format_t* sf; diff --git a/paddlespeech/audio/src/pybind/sox/effects_chain.h b/paddlespeech/audio/src/pybind/sox/effects_chain.h index 3de0161e3..6fb994b5a 100644 --- a/paddlespeech/audio/src/pybind/sox/effects_chain.h +++ b/paddlespeech/audio/src/pybind/sox/effects_chain.h @@ -1,9 +1,60 @@ #pragma once -#include "paddlespeech/audio/src/sox/effects_chain.h" +#include +#include "paddlespeech/audio/src/pybind/sox/utils.h" namespace paddleaudio::sox_effects_chain { +// Helper struct to safely close sox_effect_t* pointer returned by +// sox_create_effect + +struct SoxEffect { + explicit SoxEffect(sox_effect_t* se) noexcept; + SoxEffect(const SoxEffect& other) = delete; + SoxEffect(const SoxEffect&& other) = delete; + auto operator=(const SoxEffect& other) -> SoxEffect& = delete; + auto operator=(SoxEffect&& other) -> SoxEffect& = delete; + ~SoxEffect(); + operator sox_effect_t*() const; + auto operator->() noexcept -> sox_effect_t*; + + private: + sox_effect_t* se_; +}; + +// Helper struct to safely close sox_effects_chain_t with handy methods +class SoxEffectsChain { + const sox_encodinginfo_t in_enc_; + const sox_encodinginfo_t out_enc_; + + protected: + sox_signalinfo_t in_sig_; + sox_signalinfo_t interm_sig_; + sox_signalinfo_t out_sig_; + sox_effects_chain_t* sec_; + + public: + explicit SoxEffectsChain( + sox_encodinginfo_t input_encoding, + sox_encodinginfo_t output_encoding); + SoxEffectsChain(const SoxEffectsChain& other) = delete; + SoxEffectsChain(const SoxEffectsChain&& other) = delete; + SoxEffectsChain& operator=(const SoxEffectsChain& other) = delete; + SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete; + ~SoxEffectsChain(); + void run(); + void addInputTensor( + py::array* waveform, + int64_t sample_rate, + bool channels_first); + void addInputFile(sox_format_t* sf); + void addOutputBuffer(std::vector* output_buffer); + void addOutputFile(sox_format_t* sf); + void addEffect(const std::vector effect); + int64_t getOutputNumChannels(); + int64_t getOutputSampleRate(); +}; + class SoxEffectsChainPyBind : public SoxEffectsChain { using SoxEffectsChain::SoxEffectsChain; diff --git a/paddlespeech/audio/src/pybind/sox/io.cpp b/paddlespeech/audio/src/pybind/sox/io.cpp index 6e3230f27..4c27e6aab 100644 --- a/paddlespeech/audio/src/pybind/sox/io.cpp +++ b/paddlespeech/audio/src/pybind/sox/io.cpp @@ -3,14 +3,11 @@ #include "paddlespeech/audio/src/pybind/sox/io.h" #include "paddlespeech/audio/src/pybind/sox/effects.h" +#include "paddlespeech/audio/src/pybind/sox/types.h" #include "paddlespeech/audio/src/pybind/sox/effects_chain.h" #include "paddlespeech/audio/src/pybind/sox/utils.h" #include "paddlespeech/audio/src/optional/optional.hpp" -#include "paddlespeech/audio/src/sox/io.h" -#include "paddlespeech/audio/src/sox/types.h" -#include "paddlespeech/audio/src/sox/utils.h" - using namespace paddleaudio::sox_utils; namespace paddleaudio { @@ -108,6 +105,73 @@ tl::optional> load_audio_fileobj( std::move(fileobj), effects, normalize, channels_first, std::move(format)); } +tl::optional> load_audio_file( + const std::string& path, + const tl::optional& frame_offset, + const tl::optional& num_frames, + tl::optional normalize, + tl::optional channels_first, + const tl::optional& format) { + auto effects = get_effects(frame_offset, num_frames); + return paddleaudio::sox_effects::apply_effects_file( + path, effects, normalize, channels_first, format); +} + +void save_audio_file(const std::string& path, + py::array tensor, + int64_t sample_rate, + bool channels_first, + tl::optional compression, + tl::optional format, + tl::optional encoding, + tl::optional bits_per_sample) { + validate_input_tensor(tensor); + + const auto filetype = [&]() { + if (format.has_value()) return format.value(); + return get_filetype(path); + }(); + + if (filetype == "amr-nb") { + const auto num_channels = tensor.shape(channels_first ? 0 : 1); + //TORCH_CHECK(num_channels == 1, + // "amr-nb format only supports single channel audio."); + } else if (filetype == "htk") { + const auto num_channels = tensor.shape(channels_first ? 0 : 1); + // TORCH_CHECK(num_channels == 1, + // "htk format only supports single channel audio."); + } else if (filetype == "gsm") { + const auto num_channels = tensor.shape(channels_first ? 0 : 1); + //TORCH_CHECK(num_channels == 1, + // "gsm format only supports single channel audio."); + //TORCH_CHECK(sample_rate == 8000, + // "gsm format only supports a sampling rate of 8kHz."); + } + const auto signal_info = + get_signalinfo(&tensor, sample_rate, filetype, channels_first); + const auto encoding_info = get_encodinginfo_for_save( + filetype, tensor.dtype(), compression, encoding, bits_per_sample); + + SoxFormat sf(sox_open_write(path.c_str(), + &signal_info, + &encoding_info, + /*filetype=*/filetype.c_str(), + /*oob=*/nullptr, + /*overwrite_permitted=*/nullptr)); + + if (static_cast(sf) == nullptr) { + throw std::runtime_error( + "Error saving audio file: failed to open file " + path); + } + + paddleaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), + /*output_encoding=*/sf->encoding); + chain.addInputTensor(&tensor, sample_rate, channels_first); + chain.addOutputFile(sf); + chain.run(); +} + namespace { // helper class to automatically release buffer, to be used by // save_audio_fileobj diff --git a/paddlespeech/audio/src/pybind/sox/io.h b/paddlespeech/audio/src/pybind/sox/io.h index ca03b5db3..94ce18f22 100644 --- a/paddlespeech/audio/src/pybind/sox/io.h +++ b/paddlespeech/audio/src/pybind/sox/io.h @@ -16,14 +16,13 @@ auto get_info_file(const std::string &path, const std::string &format) auto get_info_fileobj(py::object fileobj, const std::string &format) -> std::tuple; -auto load_audio_fileobj( +tl::optional> load_audio_fileobj( py::object fileobj, - tl::optional frame_offset, - tl::optional num_frames, + const tl::optional& frame_offset, + const tl::optional& num_frames, tl::optional normalize, tl::optional channels_first, - tl::optional format) - -> tl::optional>; + const tl::optional& format); void save_audio_fileobj( py::object fileobj, @@ -35,5 +34,28 @@ void save_audio_fileobj( tl::optional encoding, tl::optional bits_per_sample); +auto get_effects(const tl::optional& frame_offset, + const tl::optional& num_frames) + -> std::vector>; + + +tl::optional> load_audio_file( + const std::string& path, + const tl::optional& frame_offset, + const tl::optional& num_frames, + tl::optional normalize, + tl::optional channels_first, + const tl::optional& format); + +void save_audio_file(const std::string& path, + py::array tensor, + int64_t sample_rate, + bool channels_first, + tl::optional compression, + tl::optional format, + tl::optional encoding, + tl::optional bits_per_sample); + + } // namespace paddleaudio } // namespace sox_io diff --git a/paddlespeech/audio/src/pybind/sox/types.cpp b/paddlespeech/audio/src/pybind/sox/types.cpp new file mode 100644 index 000000000..8e3e61373 --- /dev/null +++ b/paddlespeech/audio/src/pybind/sox/types.cpp @@ -0,0 +1,143 @@ +//code is from: https://github.com/pytorch/audio/blob/main/torchaudio/csrc/sox/types.cpp + +#include "paddlespeech/audio/src/pybind/sox/types.h" +#include +#include + +namespace paddleaudio { +namespace sox_utils { + +Format get_format_from_string(const std::string& format) { + if (format == "wav") + return Format::WAV; + if (format == "mp3") + return Format::MP3; + if (format == "flac") + return Format::FLAC; + if (format == "ogg" || format == "vorbis") + return Format::VORBIS; + if (format == "amr-nb") + return Format::AMR_NB; + if (format == "amr-wb") + return Format::AMR_WB; + if (format == "amb") + return Format::AMB; + if (format == "sph") + return Format::SPHERE; + if (format == "htk") + return Format::HTK; + if (format == "gsm") + return Format::GSM; + std::ostringstream stream; + stream << "Internal Error: unexpected format value: " << format; + throw std::runtime_error(stream.str()); +} + +std::string to_string(Encoding v) { + switch (v) { + case Encoding::UNKNOWN: + return "UNKNOWN"; + case Encoding::PCM_SIGNED: + return "PCM_S"; + case Encoding::PCM_UNSIGNED: + return "PCM_U"; + case Encoding::PCM_FLOAT: + return "PCM_F"; + case Encoding::FLAC: + return "FLAC"; + case Encoding::ULAW: + return "ULAW"; + case Encoding::ALAW: + return "ALAW"; + case Encoding::MP3: + return "MP3"; + case Encoding::VORBIS: + return "VORBIS"; + case Encoding::AMR_WB: + return "AMR_WB"; + case Encoding::AMR_NB: + return "AMR_NB"; + case Encoding::OPUS: + return "OPUS"; + default: + throw std::runtime_error("Internal Error: unexpected encoding."); + } +} + +Encoding get_encoding_from_option(const tl::optional encoding) { + if (!encoding.has_value()) + return Encoding::NOT_PROVIDED; + std::string v = encoding.value(); + if (v == "PCM_S") + return Encoding::PCM_SIGNED; + if (v == "PCM_U") + return Encoding::PCM_UNSIGNED; + if (v == "PCM_F") + return Encoding::PCM_FLOAT; + if (v == "ULAW") + return Encoding::ULAW; + if (v == "ALAW") + return Encoding::ALAW; + std::ostringstream stream; + stream << "Internal Error: unexpected encoding value: " << v; + throw std::runtime_error(stream.str()); +} + +BitDepth get_bit_depth_from_option(const tl::optional bit_depth) { + if (!bit_depth.has_value()) + return BitDepth::NOT_PROVIDED; + int64_t v = bit_depth.value(); + switch (v) { + case 8: + return BitDepth::B8; + case 16: + return BitDepth::B16; + case 24: + return BitDepth::B24; + case 32: + return BitDepth::B32; + case 64: + return BitDepth::B64; + default: { + std::ostringstream s; + s << "Internal Error: unexpected bit depth value: " << v; + throw std::runtime_error(s.str()); + } + } +} + +std::string get_encoding(sox_encoding_t encoding) { + switch (encoding) { + case SOX_ENCODING_UNKNOWN: + return "UNKNOWN"; + case SOX_ENCODING_SIGN2: + return "PCM_S"; + case SOX_ENCODING_UNSIGNED: + return "PCM_U"; + case SOX_ENCODING_FLOAT: + return "PCM_F"; + case SOX_ENCODING_FLAC: + return "FLAC"; + case SOX_ENCODING_ULAW: + return "ULAW"; + case SOX_ENCODING_ALAW: + return "ALAW"; + case SOX_ENCODING_MP3: + return "MP3"; + case SOX_ENCODING_VORBIS: + return "VORBIS"; + case SOX_ENCODING_AMR_WB: + return "AMR_WB"; + case SOX_ENCODING_AMR_NB: + return "AMR_NB"; + case SOX_ENCODING_OPUS: + return "OPUS"; + case SOX_ENCODING_GSM: + return "GSM"; + default: + return "UNKNOWN"; + } +} + +} // namespace sox_utils +} // namespace paddleaudio diff --git a/paddlespeech/audio/src/pybind/sox/types.h b/paddlespeech/audio/src/pybind/sox/types.h new file mode 100644 index 000000000..824c0f632 --- /dev/null +++ b/paddlespeech/audio/src/pybind/sox/types.h @@ -0,0 +1,58 @@ +//code is from: https://github.com/pytorch/audio/blob/main/torchaudio/csrc/sox/types.h +#pragma once + +#include +#include "paddlespeech/audio/src/optional/optional.hpp" + +namespace paddleaudio { +namespace sox_utils { + +enum class Format { + WAV, + MP3, + FLAC, + VORBIS, + AMR_NB, + AMR_WB, + AMB, + SPHERE, + GSM, + HTK, +}; + +Format get_format_from_string(const std::string& format); + +enum class Encoding { + NOT_PROVIDED, + UNKNOWN, + PCM_SIGNED, + PCM_UNSIGNED, + PCM_FLOAT, + FLAC, + ULAW, + ALAW, + MP3, + VORBIS, + AMR_WB, + AMR_NB, + OPUS, +}; + +std::string to_string(Encoding v); +Encoding get_encoding_from_option(const tl::optional encoding); + +enum class BitDepth : unsigned { + NOT_PROVIDED = 0, + B8 = 8, + B16 = 16, + B24 = 24, + B32 = 32, + B64 = 64, +}; + +BitDepth get_bit_depth_from_option(const tl::optional bit_depth); + +std::string get_encoding(sox_encoding_t encoding); + +} // namespace sox_utils +} // namespace torchaudio \ No newline at end of file diff --git a/paddlespeech/audio/src/pybind/sox/utils.cpp b/paddlespeech/audio/src/pybind/sox/utils.cpp index 24a2817d2..a930f8cdd 100644 --- a/paddlespeech/audio/src/pybind/sox/utils.cpp +++ b/paddlespeech/audio/src/pybind/sox/utils.cpp @@ -1,7 +1,9 @@ // Copyright (c) 2017 Facebook Inc. (Soumith Chintala), // All rights reserved. +#include #include "paddlespeech/audio/src/pybind/sox/utils.h" +#include "paddlespeech/audio/src/pybind/sox/types.h" #include @@ -35,6 +37,485 @@ auto read_fileobj(py::object *fileobj, const uint64_t size, char *buffer) return num_read; } + +void set_seed(const int64_t seed) { + sox_get_globals()->ranqd1 = static_cast(seed); +} + +void set_verbosity(const int64_t verbosity) { + sox_get_globals()->verbosity = static_cast(verbosity); +} + +void set_use_threads(const bool use_threads) { + sox_get_globals()->use_threads = static_cast(use_threads); +} + +void set_buffer_size(const int64_t buffer_size) { + sox_get_globals()->bufsiz = static_cast(buffer_size); +} + +int64_t get_buffer_size() { + return sox_get_globals()->bufsiz; +} + +std::vector> list_effects() { + std::vector> effects; + for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) { + const sox_effect_handler_t* handler = (*fns)(); + if (handler && handler->name) { + if (UNSUPPORTED_EFFECTS.find(handler->name) == + UNSUPPORTED_EFFECTS.end()) { + effects.emplace_back(std::vector{ + handler->name, + handler->usage ? std::string(handler->usage) : std::string("")}); + } + } + } + return effects; +} + +std::vector list_write_formats() { + std::vector formats; + for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { + const sox_format_handler_t* handler = fns->fn(); + for (const char* const* names = handler->names; *names; ++names) { + if (!strchr(*names, '/') && handler->write) + formats.emplace_back(*names); + } + } + return formats; +} + +std::vector list_read_formats() { + std::vector formats; + for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { + const sox_format_handler_t* handler = fns->fn(); + for (const char* const* names = handler->names; *names; ++names) { + if (!strchr(*names, '/') && handler->read) + formats.emplace_back(*names); + } + } + return formats; +} + +SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} +SoxFormat::~SoxFormat() { + close(); +} + +sox_format_t* SoxFormat::operator->() const noexcept { + return fd_; +} +SoxFormat::operator sox_format_t*() const noexcept { + return fd_; +} + +void SoxFormat::close() { + if (fd_ != nullptr) { + sox_close(fd_); + fd_ = nullptr; + } +} + +void validate_input_file(const SoxFormat& sf, const std::string& path) { + if (static_cast(sf) == nullptr) { + throw std::runtime_error( + "Error loading audio file: failed to open file " + path); + } + if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + throw std::runtime_error("Error loading audio file: unknown encoding."); + } +} + +void validate_input_memfile(const SoxFormat &sf) { + return validate_input_file(sf, ""); +} + +void validate_input_tensor(const py::array tensor) { + if (tensor.ndim() != 2) { + throw std::runtime_error("Input tensor has to be 2D."); + } + + char dtype = tensor.dtype().char_(); + bool flag = (dtype == 'f') || (dtype == 'd') || (dtype == 'l') || (dtype == 'i'); + if (flag == false) { + throw std::runtime_error( + "Input tensor has to be one of float32, int32, int16 or uint8 type."); + } +} + +py::dtype get_dtype( + const sox_encoding_t encoding, + const unsigned precision) { + switch (encoding) { + case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV + return py::dtype('u1'); + case SOX_ENCODING_SIGN2: // 16-bit, 24-bit, or 32-bit PCM WAV + switch (precision) { + case 16: + return py::dtype("i2"); + case 24: // Cast 24-bit to 32-bit. + case 32: + return py::dtype('i'); + default: + throw std::runtime_error( + "Only 16, 24, and 32 bits are supported for signed PCM."); + } + default: + // default to float32 for the other formats, including + // 32-bit flaoting-point WAV, + // MP3, + // FLAC, + // VORBIS etc... + return py::dtype("f"); + } +} + +py::array convert_to_tensor( + sox_sample_t* buffer, + const int32_t num_samples, + const int32_t num_channels, + const py::dtype dtype, + const bool normalize, + const bool channels_first) { + py::array t; + uint64_t dummy = 0; + SOX_SAMPLE_LOCALS; + if (normalize || dtype.char_() == 'f') { + t = py::array(dtype, {num_samples / num_channels, num_channels}); + auto ptr = (float*)t.mutable_data(0, 0); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy); + } + } else if (dtype.char_() == 'i') { + //t = torch::from_blob( + // buffer, {num_samples / num_channels, num_channels}, torch::kInt32) + // .clone(); + t = py::array(dtype, {num_samples / num_channels, num_channels}); + auto ptr = (int*)t.mutable_data(0, 0); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = buffer[i]; + } + } else if (dtype.char_() == 'h') { // int16 + t = py::array(dtype, {num_samples / num_channels, num_channels}); + auto ptr = (int16_t*)t.mutable_data(0, 0); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy); + } + } else if (dtype.char_() == 'b') { + //t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8); + auto ptr = (uint8_t*)t.mutable_data(0,0); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy); + } + } else { + throw std::runtime_error("Unsupported dtype."); + } + return t; +} + +const std::string get_filetype(const std::string path) { + std::string ext = path.substr(path.find_last_of(".") + 1); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + return ext; +} + +namespace { + +std::tuple get_save_encoding_for_wav( + const std::string format, + py::dtype dtype, + const Encoding& encoding, + const BitDepth& bits_per_sample) { + switch (encoding) { + case Encoding::NOT_PROVIDED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + switch (dtype.num()) { + case 11: // float32 numpy dtype num + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + case 5: // int numpy dtype num + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + case 3: // int16 numpy + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + case 1: // byte numpy + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + throw std::runtime_error("Internal Error: Unexpected dtype."); + } + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_SIGNED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + case BitDepth::B8: + throw std::runtime_error( + format + " does not support 8-bit signed PCM encoding."); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_UNSIGNED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for unsigned PCM encoding."); + } + case Encoding::PCM_FLOAT: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B32: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + case BitDepth::B64: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); + default: + throw std::runtime_error( + format + + " only supports 32-bit or 64-bit for floating-point PCM encoding."); + } + case Encoding::ULAW: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for a-law encoding."); + } + default: + throw std::runtime_error( + format + " does not support encoding: " + to_string(encoding)); + } +} + +std::tuple get_save_encoding( + const std::string& format, + const py::dtype dtype, + const tl::optional encoding, + const tl::optional bits_per_sample) { + const Format fmt = get_format_from_string(format); + const Encoding enc = get_encoding_from_option(encoding); + const BitDepth bps = get_bit_depth_from_option(bits_per_sample); + + switch (fmt) { + case Format::WAV: + case Format::AMB: + return get_save_encoding_for_wav(format, dtype, enc, bps); + case Format::MP3: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("mp3 does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "mp3 does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_MP3, 16); + case Format::HTK: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("htk does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "htk does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + case Format::VORBIS: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("vorbis does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "vorbis does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); + case Format::AMR_NB: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("amr-nb does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "amr-nb does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); + case Format::FLAC: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("flac does not support `encoding` option."); + switch (bps) { + case BitDepth::B32: + case BitDepth::B64: + throw std::runtime_error( + "flac does not support `bits_per_sample` larger than 24."); + default: + return std::make_tuple<>( + SOX_ENCODING_FLAC, static_cast(bps)); + } + case Format::SPHERE: + switch (enc) { + case Encoding::NOT_PROVIDED: + case Encoding::PCM_SIGNED: + switch (bps) { + case BitDepth::NOT_PROVIDED: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bps)); + } + case Encoding::PCM_UNSIGNED: + throw std::runtime_error( + "sph does not support unsigned integer PCM."); + case Encoding::PCM_FLOAT: + throw std::runtime_error("sph does not support floating point PCM."); + case Encoding::ULAW: + switch (bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + "sph only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch (bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + return std::make_tuple<>( + SOX_ENCODING_ALAW, static_cast(bps)); + } + default: + throw std::runtime_error( + "sph does not support encoding: " + encoding.value()); + } + case Format::GSM: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("gsm does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "gsm does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_GSM, 16); + + default: + throw std::runtime_error("Unsupported format: " + format); + } +} + +unsigned get_precision(const std::string filetype, py::dtype dtype) { + if (filetype == "mp3") + return SOX_UNSPEC; + if (filetype == "flac") + return 24; + if (filetype == "ogg" || filetype == "vorbis") + return SOX_UNSPEC; + if (filetype == "wav" || filetype == "amb") { + switch (dtype.num()) { + case 1: // byte in numpy dype num + return 8; + case 3: // short, in numpy dtype num + return 16; + case 5: // int, numpy dtype + return 32; + case 11: // float, numpy dtype + return 32; + default: + throw std::runtime_error("Unsupported dtype."); + } + } + if (filetype == "sph") + return 32; + if (filetype == "amr-nb") { + return 16; + } + if (filetype == "gsm") { + return 16; + } + if (filetype == "htk") { + return 16; + } + throw std::runtime_error("Unsupported file type: " + filetype); +} + +} // namespace + +sox_signalinfo_t get_signalinfo( + const py::array* waveform, + const int64_t sample_rate, + const std::string filetype, + const bool channels_first) { + return sox_signalinfo_t{ + /*rate=*/static_cast(sample_rate), + /*channels=*/ + static_cast(waveform->shape(channels_first ? 0 : 1)), + /*precision=*/get_precision(filetype, waveform->dtype()), + /*length=*/static_cast(waveform->size())}; +} + +sox_encodinginfo_t get_tensor_encodinginfo(py::dtype dtype) { + sox_encoding_t encoding = [&]() { + switch (dtype.num()) { + case 1: // byte + return SOX_ENCODING_UNSIGNED; + case 3: // short + return SOX_ENCODING_SIGN2; + case 5: // int32 + return SOX_ENCODING_SIGN2; + case 11: // float + return SOX_ENCODING_FLOAT; + default: + throw std::runtime_error("Unsupported dtype."); + } + }(); + unsigned bits_per_sample = [&]() { + switch (dtype.num()) { + case 1: // byte + return 8; + case 3: //short + return 16; + case 5: // int32 + return 32; + case 11: // float + return 32; + default: + throw std::runtime_error("Unsupported dtype."); + } + }(); + return sox_encodinginfo_t{ + /*encoding=*/encoding, + /*bits_per_sample=*/bits_per_sample, + /*compression=*/HUGE_VAL, + /*reverse_bytes=*/sox_option_default, + /*reverse_nibbles=*/sox_option_default, + /*reverse_bits=*/sox_option_default, + /*opposite_endian=*/sox_false}; +} + +sox_encodinginfo_t get_encodinginfo_for_save( + const std::string& format, + const py::dtype dtype, + const tl::optional compression, + const tl::optional encoding, + const tl::optional bits_per_sample) { + auto enc = get_save_encoding(format, dtype, encoding, bits_per_sample); + return sox_encodinginfo_t{ + /*encoding=*/std::get<0>(enc), + /*bits_per_sample=*/std::get<1>(enc), + /*compression=*/compression.value_or(HUGE_VAL), + /*reverse_bytes=*/sox_option_default, + /*reverse_nibbles=*/sox_option_default, + /*reverse_bits=*/sox_option_default, + /*opposite_endian=*/sox_false}; +} + + /* SoxFormat::SoxFormat(sox_format_t *fd) noexcept : fd_(fd) {} SoxFormat::~SoxFormat() { close(); } diff --git a/paddlespeech/audio/src/pybind/sox/utils.h b/paddlespeech/audio/src/pybind/sox/utils.h index fa931b1a9..65223bc0c 100644 --- a/paddlespeech/audio/src/pybind/sox/utils.h +++ b/paddlespeech/audio/src/pybind/sox/utils.h @@ -7,8 +7,6 @@ #include #include #include "paddlespeech/audio/src/optional/optional.hpp" -#include "paddlespeech/audio/src/sox/utils.h" -#include "paddlespeech/audio/src/sox/types.h" namespace py = pybind11; @@ -17,5 +15,102 @@ namespace sox_utils { auto read_fileobj(py::object *fileobj, uint64_t size, char *buffer) -> uint64_t; +void set_seed(const int64_t seed); + +void set_verbosity(const int64_t verbosity); + +void set_use_threads(const bool use_threads); + +void set_buffer_size(const int64_t buffer_size); + +int64_t get_buffer_size(); + +std::vector> list_effects(); + +std::vector list_read_formats(); + +std::vector list_write_formats(); + +//////////////////////////////////////////////////////////////////////////////// +// Utilities for sox_io / sox_effects implementations +//////////////////////////////////////////////////////////////////////////////// + +const std::unordered_set UNSUPPORTED_EFFECTS = + {"input", "output", "spectrogram", "noiseprof", "noisered", "splice"}; + +/// helper class to automatically close sox_format_t* +struct SoxFormat { + explicit SoxFormat(sox_format_t* fd) noexcept; + SoxFormat(const SoxFormat& other) = delete; + SoxFormat(SoxFormat&& other) = delete; + SoxFormat& operator=(const SoxFormat& other) = delete; + SoxFormat& operator=(SoxFormat&& other) = delete; + ~SoxFormat(); + sox_format_t* operator->() const noexcept; + operator sox_format_t*() const noexcept; + + void close(); + + private: + sox_format_t* fd_; +}; + +/// +/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32 +void validate_input_tensor(const py::array); + +void validate_input_file(const SoxFormat& sf, const std::string& path); + +void validate_input_memfile(const SoxFormat &sf); +/// +/// Get target dtype for the given encoding and precision. +py::dtype get_dtype( + const sox_encoding_t encoding, + const unsigned precision); + +/// +/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor +/// NOTE: This function might modify the values in the input buffer to +/// reduce the number of memory copy. +/// @param buffer Pointer to buffer that contains audio data. +/// @param num_samples The number of samples to read. +/// @param num_channels The number of channels. Used to reshape the resulting +/// Tensor. +/// @param dtype Target dtype. Determines the output dtype and value range in +/// conjunction with normalization. +/// @param noramlize Perform normalization. Only effective when dtype is not +/// kFloat32. When effective, the output tensor is kFloat32 type and value range +/// is [-1.0, 1.0] +/// @param channels_first When True, output Tensor has shape of [num_channels, +/// num_frames]. +py::array convert_to_tensor( + sox_sample_t* buffer, + const int32_t num_samples, + const int32_t num_channels, + const py::dtype dtype, + const bool normalize, + const bool channels_first); + +/// Extract extension from file path +const std::string get_filetype(const std::string path); + +/// Get sox_signalinfo_t for passing a py::array object. +sox_signalinfo_t get_signalinfo( + const py::array* waveform, + const int64_t sample_rate, + const std::string filetype, + const bool channels_first); + +/// Get sox_encodinginfo_t for Tensor I/O +sox_encodinginfo_t get_tensor_encodinginfo(const py::dtype dtype); + +/// Get sox_encodinginfo_t for saving to file/file object +sox_encodinginfo_t get_encodinginfo_for_save( + const std::string& format, + const py::dtype dtype, + const tl::optional compression, + const tl::optional encoding, + const tl::optional bits_per_sample); + } // namespace paddleaudio } // namespace sox_utils