diff --git a/audio/audiotools/__init__.py b/audio/audiotools/__init__.py index 53235f4a6..4191639ee 100644 --- a/audio/audiotools/__init__.py +++ b/audio/audiotools/__init__.py @@ -1,10 +1,11 @@ __version__ = "0.0.1" from .core import AudioSignal from .core import STFTParams -# from .core import Meter +from .core import Meter from .core import util +from .core import highpass_filter, highpass_filters from . import metrics from . import data from . import ml from .data import datasets -from .data import transforms \ No newline at end of file +from .data import transforms diff --git a/audio/audiotools/core/__init__.py b/audio/audiotools/core/__init__.py index ecd2d076a..a4038c4ed 100644 --- a/audio/audiotools/core/__init__.py +++ b/audio/audiotools/core/__init__.py @@ -1,4 +1,15 @@ from . import util +from ._julius import fft_conv1d +from ._julius import FFTConv1d +from ._julius import highpass_filter +from ._julius import highpass_filters +from ._julius import lowpass_filter +from ._julius import LowPassFilter +from ._julius import LowPassFilters +from ._julius import pure_tone +from ._julius import split_bands +from ._julius import SplitBands from .audio_signal import AudioSignal from .audio_signal import STFTParams -from .loudness import Meter \ No newline at end of file +from .loudness import Meter +from .resample import resample_frac diff --git a/audio/audiotools/core/_julius.py b/audio/audiotools/core/_julius.py new file mode 100644 index 000000000..cb23cb656 --- /dev/null +++ b/audio/audiotools/core/_julius.py @@ -0,0 +1,714 @@ +# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details. +# Author: adefossez, 2020 +""" +Implementation of a FFT based 1D convolution in PaddlePaddle. +While FFT is used in some cases for small kernel sizes, it is not the default for long ones, e.g. 512. +This module implements efficient FFT based convolutions for such cases. A typical +application is for evaluating FIR filters with a long receptive field, typically +evaluated with a stride of 1. +""" +import math +import typing +from typing import Optional +from typing import Sequence + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .resample import sinc + + +def pad_to(tensor: paddle.Tensor, + target_length: int, + mode: str="constant", + value: float=0.0): + """ + Pad the given tensor to the given length, with 0s on the right. + """ + return F.pad( + tensor, (0, target_length - tensor.shape[-1]), + mode=mode, + value=value, + data_format="NCL") + + +def pure_tone(freq: float, sr: float=128, dur: float=4, device=None): + """ + Return a pure tone, i.e. cosine. + + Args: + freq (float): frequency (in Hz) + sr (float): sample rate (in Hz) + dur (float): duration (in seconds) + """ + time = paddle.arange(int(sr * dur), dtype="float32") / sr + return paddle.cos(2 * math.pi * freq * time) + + +def unfold(_input, kernel_size: int, stride: int): + """1D only unfolding similar to the one from PyTorch. + However PyTorch unfold is extremely slow. + + Given an _input tensor of size `[*, T]` this will return + a tensor `[*, F, K]` with `K` the kernel size, and `F` the number + of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`. + This will automatically pad the _input to cover at least once all entries in `_input`. + + Args: + _input (Tensor): tensor for which to return the frames. + kernel_size (int): size of each frame. + stride (int): stride between each frame. 
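+
+    Example (added for illustration, not part of the original julius docstring;
+    only the output shape is asserted):
+
+    >>> x = paddle.randn([1, 1, 16])
+    >>> print(list(unfold(x, kernel_size=4, stride=2).shape))
+    [1, 1, 7, 4]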
+ + Shape: + + - Inputs: `_input` is `[*, T]` + - Output: `[*, F, kernel_size]` with `F = 1 + ceil((T - kernel_size) / stride)` + + + ..Warning:: unlike PyTorch unfold, this will pad the _input + so that any position in `_input` is covered by at least one frame. + """ + shape = list(_input.shape) + length = shape.pop(-1) + n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1 + tgt_length = (n_frames - 1) * stride + kernel_size + padded = F.pad(_input, (0, tgt_length - length), data_format="NCL") + strides: typing.List[int] = [] + for dim in range(padded.dim()): + strides.append(padded.strides[dim]) + assert strides.pop(-1) == 1, "data should be contiguous" + strides = strides + [stride, 1] + return padded.as_strided(shape + [n_frames, kernel_size], strides) + + +def _new_rfft(x: paddle.Tensor): + z = paddle.fft.rfft(x, axis=-1) + + z_real = paddle.real(z) + z_imag = paddle.imag(z) + + z_view_as_real = paddle.stack([z_real, z_imag], axis=-1) + return z_view_as_real + + +def _new_irfft(x: paddle.Tensor, length: int): + x_real = x[..., 0] + x_imag = x[..., 1] + x_view_as_complex = paddle.complex(x_real, x_imag) + return paddle.fft.irfft(x_view_as_complex, n=length, axis=-1) + + +def _compl_mul_conjugate(a: paddle.Tensor, b: paddle.Tensor): + """ + Given a and b two tensors of dimension 4 + with the last dimension being the real and imaginary part, + returns a multiplied by the conjugate of b, the multiplication + being with respect to the second dimension. + + PaddlePaddle does not have direct support for complex number operations + using einsum in the same manner as PyTorch, but we can manually compute + the equivalent result. + """ + # Extract the real and imaginary parts of a and b + real_a = a[..., 0] + imag_a = a[..., 1] + real_b = b[..., 0] + imag_b = b[..., 1] + + # Compute the multiplication with respect to the second dimension manually + real_part = paddle.einsum("bcft,dct->bdft", real_a, real_b) + paddle.einsum( + "bcft,dct->bdft", imag_a, imag_b) + imag_part = paddle.einsum("bcft,dct->bdft", imag_a, real_b) - paddle.einsum( + "bcft,dct->bdft", real_a, imag_b) + + # Stack the real and imaginary parts together + result = paddle.stack([real_part, imag_part], axis=-1) + return result + + +def fft_conv1d( + _input: paddle.Tensor, + weight: paddle.Tensor, + bias: Optional[paddle.Tensor]=None, + stride: int=1, + padding: int=0, + block_ratio: float=5, ): + """ + Same as `paddle.nn.functional.conv1d` but using FFT for the convolution. + Please check PaddlePaddle documentation for more information. + + Args: + _input (Tensor): _input signal of shape `[B, C, T]`. + weight (Tensor): weight of the convolution `[D, C, K]` with `D` the number + of output channels. + bias (Tensor or None): if not None, bias term for the convolution. + stride (int): stride of convolution. + padding (int): padding to apply to the _input. + block_ratio (float): can be tuned for speed. The _input is splitted in chunks + with a size of `int(block_ratio * kernel_size)`. + + Shape: + + - Inputs: `_input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`. + - Output: `(*, T)` + + + ..note:: + This function is faster than `paddle.nn.functional.conv1d` only in specific cases. + Typically, the kernel size should be of the order of 256 to see any real gain, + for a stride of 1. + + ..Warning:: + Dilation and groups are not supported at the moment. This function might use + more memory than the default Conv1d implementation. 
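+
+    Example (added for illustration, not part of the original julius docstring;
+    only the output shape is asserted):
+
+    >>> x = paddle.randn([1, 1, 1024])
+    >>> w = paddle.randn([1, 1, 256])
+    >>> print(list(fft_conv1d(x, w).shape))
+    [1, 1, 769]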
+ """ + _input = F.pad(_input, (padding, padding), data_format="NCL") + batch, channels, length = _input.shape + out_channels, _, kernel_size = weight.shape + + if length < kernel_size: + raise RuntimeError( + f"Input should be at least as large as the kernel size {kernel_size}, " + f"but it is only {length} samples long.") + if block_ratio < 1: + raise RuntimeError("Block ratio must be greater than 1.") + + block_size: int = min(int(kernel_size * block_ratio), length) + fold_stride = block_size - kernel_size + 1 + weight = pad_to(weight, block_size) + weight_z = _new_rfft(weight) + + # We pad the _input and get the different frames, on which + frames = unfold(_input, block_size, fold_stride) + + frames_z = _new_rfft(frames) + out_z = _compl_mul_conjugate(frames_z, weight_z) + out = _new_irfft(out_z, block_size) + # The last bit is invalid, because FFT will do a circular convolution. + out = out[..., :-kernel_size + 1] + out = out.reshape([batch, out_channels, -1]) + out = out[..., ::stride] + target_length = (length - kernel_size) // stride + 1 + out = out[..., :target_length] + if bias is not None: + out += bias[:, None] + return out + + +class FFTConv1d(paddle.nn.Layer): + """ + Same as `paddle.nn.Conv1D` but based on a custom FFT-based convolution. + Please check PaddlePaddle documentation for more information on `paddle.nn.Conv1D`. + + Args: + in_channels (int): number of _input channels. + out_channels (int): number of output channels. + kernel_size (int): kernel size of convolution. + stride (int): stride of convolution. + padding (int): padding to apply to the _input. + bias (bool): if True, use a bias term. + + ..note:: + This module is faster than `paddle.nn.Conv1D` only in specific cases. + Typically, `kernel_size` should be of the order of 256 to see any real gain, + for a stride of 1. + + ..warning:: + Dilation and groups are not supported at the moment. This module might use + more memory than the default Conv1D implementation. + + >>> fftconv = FFTConv1d(12, 24, 128, 4) + >>> x = paddle.randn([4, 12, 1024]) + >>> print(list(fftconv(x).shape)) + [4, 24, 225] + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int=1, + padding: int=0, + bias: bool=True, ): + super(FFTConv1d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + # Create a Conv1D layer to initialize weights and bias + conv = paddle.nn.Conv1D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias_attr=bias) + self.weight = conv.weight + if bias: + self.bias = conv.bias + else: + self.bias = None + + def forward(self, _input: paddle.Tensor): + return fft_conv1d(_input, self.weight, self.bias, self.stride, + self.padding) + + +class LowPassFilters(nn.Layer): + """ + Bank of low pass filters. 
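+
+    Example (added for illustration; only the output shape is asserted):
+
+    >>> lowpass = LowPassFilters([1/4, 1/8])
+    >>> x = paddle.randn([4, 12, 1024])
+    >>> list(lowpass(x).shape)
+    [2, 4, 12, 1024]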
+ """ + + def __init__(self, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, + dtype="float32"): + super(LowPassFilters, self).__init__() + self.cutoffs = list(cutoffs) + if min(self.cutoffs) < 0: + raise ValueError("Minimum cutoff must be larger than zero.") + if max(self.cutoffs) > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.stride = stride + self.pad = pad + self.zeros = zeros + self.half_size = int(zeros / min([c for c in self.cutoffs if c > 0]) / + 2) + if fft is None: + fft = self.half_size > 32 + self.fft = fft + + # Create filters + window = paddle.audio.functional.get_window( + "hann", 2 * self.half_size + 1, fftbins=False, dtype=dtype) + time = paddle.arange( + -self.half_size, self.half_size + 1, dtype="float32") + filters = [] + for cutoff in cutoffs: + if cutoff == 0: + filter_ = paddle.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * + time) + # Normalize filter + filter_ /= paddle.sum(filter_) + filters.append(filter_) + filters = paddle.stack(filters)[:, None] + self.filters = self.create_parameter( + shape=filters.shape, + default_initializer=nn.initializer.Constant(value=0.0), + dtype="float32", + is_bias=False, + attr=paddle.ParamAttr(trainable=False), ) + self.filters.set_value(filters) + + def forward(self, _input): + shape = list(_input.shape) + _input = _input.reshape([-1, 1, shape[-1]]) + if self.pad: + _input = F.pad( + _input, (self.half_size, self.half_size), + mode="replicate", + data_format="NCL") + if self.fft: + out = fft_conv1d(_input, self.filters, stride=self.stride) + else: + out = F.conv1d(_input, self.filters, stride=self.stride) + + shape.insert(0, len(self.cutoffs)) + shape[-1] = out.shape[-1] + return out.transpose([1, 0, 2]).reshape(shape) + + +class LowPassFilter(nn.Layer): + """ + Same as `LowPassFilters` but applies a single low pass filter. + """ + + def __init__(self, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + super(LowPassFilter, self).__init__() + self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft) + + @property + def cutoff(self): + return self._lowpasses.cutoffs[0] + + @property + def stride(self): + return self._lowpasses.stride + + @property + def pad(self): + return self._lowpasses.pad + + @property + def zeros(self): + return self._lowpasses.zeros + + @property + def fft(self): + return self._lowpasses.fft + + def forward(self, _input): + return self._lowpasses(_input)[0] + + +def lowpass_filters( + _input: paddle.Tensor, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + """ + Functional version of `LowPassFilters`, refer to this class for more information. + """ + return LowPassFilters(cutoffs, stride, pad, zeros, fft)(_input) + + +def lowpass_filter(_input: paddle.Tensor, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + """ + Same as `lowpass_filters` but with a single cutoff frequency. + Output will not have a dimension inserted in the front. + """ + return lowpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0] + + +class HighPassFilters(paddle.nn.Layer): + """ + Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more + details on the implementation. 
+ + Args: + cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where + f_s is the samplerate and `f` is the cutoff frequency. + The upper limit is 0.5, because a signal sampled at `f_s` contains only + frequencies under `f_s / 2`. + stride (int): how much to decimate the output. Probably not a good idea + to do so with a high pass filters though... + pad (bool): if True, appropriately pad the _input with zero over the edge. If `stride=1`, + the output will have the same length as the _input. + zeros (float): Number of zero crossings to keep. + Controls the receptive field of the Finite Impulse Response filter. + For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz, + it is a bad idea to set this to a high value. + This is likely appropriate for most use. Lower values + will result in a faster filter, but with a slower attenuation around the + cutoff frequency. + fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions. + If False, uses PyTorch convolutions. If None, either one will be chosen automatically + depending on the effective filter size. + + + ..warning:: + All the filters will use the same filter size, aligned on the lowest + frequency provided. If you combine a lot of filters with very diverse frequencies, it might + be more efficient to split them over multiple modules with similar frequencies. + + Shape: + + - Input: `[*, T]` + - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and + `F` is the numer of cutoff frequencies. + + >>> highpass = HighPassFilters([1/4]) + >>> x = paddle.randn([4, 12, 21, 1024]) + >>> list(highpass(x).shape) + [1, 4, 12, 21, 1024] + """ + + def __init__(self, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + super().__init__() + self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft) + + @property + def cutoffs(self): + return self._lowpasses.cutoffs + + @property + def stride(self): + return self._lowpasses.stride + + @property + def pad(self): + return self._lowpasses.pad + + @property + def zeros(self): + return self._lowpasses.zeros + + @property + def fft(self): + return self._lowpasses.fft + + def forward(self, _input): + lows = self._lowpasses(_input) + + # We need to extract the right portion of the _input in case + # pad is False or stride > 1 + if self.pad: + start, end = 0, _input.shape[-1] + else: + start = self._lowpasses.half_size + end = -start + _input = _input[..., start:end:self.stride] + highs = _input - lows + return highs + + +class HighPassFilter(paddle.nn.Layer): + """ + Same as `HighPassFilters` but applies a single high pass filter. + + Shape: + + - Input: `[*, T]` + - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1. 
+ + >>> highpass = HighPassFilter(1/4, stride=1) + >>> x = paddle.randn([4, 124]) + >>> list(highpass(x).shape) + [4, 124] + """ + + def __init__(self, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + super().__init__() + self._highpasses = HighPassFilters([cutoff], stride, pad, zeros, fft) + + @property + def cutoff(self): + return self._highpasses.cutoffs[0] + + @property + def stride(self): + return self._highpasses.stride + + @property + def pad(self): + return self._highpasses.pad + + @property + def zeros(self): + return self._highpasses.zeros + + @property + def fft(self): + return self._highpasses.fft + + def forward(self, _input): + return self._highpasses(_input)[0] + + +def highpass_filters( + _input: paddle.Tensor, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + """ + Functional version of `HighPassFilters`, refer to this class for more information. + """ + return HighPassFilters(cutoffs, stride, pad, zeros, fft)(_input) + + +def highpass_filter(_input: paddle.Tensor, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + """ + Functional version of `HighPassFilter`, refer to this class for more information. + Output will not have a dimension inserted in the front. + """ + return highpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0] + + +import paddle +from typing import Optional, Sequence + + +def hz_to_mel(freqs: paddle.Tensor): + """ + Converts a Tensor of frequencies in hertz to the mel scale. + Uses the simple formula by O'Shaughnessy (1987). + + Args: + freqs (paddle.Tensor): frequencies to convert. + + """ + return 2595 * paddle.log10(1 + freqs / 700) + + +def mel_to_hz(mels: paddle.Tensor): + """ + Converts a Tensor of mel scaled frequencies to Hertz. + Uses the simple formula by O'Shaughnessy (1987). + + Args: + mels (paddle.Tensor): mel frequencies to convert. + """ + return 700 * (10**(mels / 2595) - 1) + + +def mel_frequencies(n_mels: int, fmin: float, fmax: float): + """ + Return frequencies that are evenly spaced in mel scale. + + Args: + n_mels (int): number of frequencies to return. + fmin (float): start from this frequency (in Hz). + fmax (float): finish at this frequency (in Hz). + + """ + low = hz_to_mel(paddle.to_tensor(float(fmin))).item() + high = hz_to_mel(paddle.to_tensor(float(fmax))).item() + mels = paddle.linspace(low, high, n_mels) + return mel_to_hz(mels) + + +class SplitBands(paddle.nn.Layer): + """ + Decomposes a signal over the given frequency bands in the waveform domain using + a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`. + You can either specify explicitly the frequency cutoffs, or just the number of bands, + in which case the frequency cutoffs will be spread out evenly in mel scale. + + Args: + sample_rate (float): Sample rate of the input signal in Hz. + n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`. + In that case, the cutoff frequencies will be evenly spaced in mel-space. + cutoffs (list[float] or None): list of frequency cutoffs in Hz. + pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`, + the output will have the same length as the input. + zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more informations. + fft (bool or None): See `LowPassFilters` for more info. 
+ + ..note:: + The sum of all the bands will always be the input signal. + + ..warning:: + Unlike `julius.lowpass.LowPassFilters`, the cutoffs frequencies must be provided in Hz along + with the sample rate. + + Shape: + + - Input: `[*, T]` + - Output: `[B, *, T']`, with `T'=T` if `pad` is True. + If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1` + + >>> bands = SplitBands(sample_rate=128, n_bands=10) + >>> x = paddle.randn(shape=[6, 4, 1024]) + >>> list(bands(x).shape) + [10, 6, 4, 1024] + """ + + def __init__( + self, + sample_rate: float, + n_bands: Optional[int]=None, + cutoffs: Optional[Sequence[float]]=None, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + super(SplitBands, self).__init__() + if (cutoffs is None) + (n_bands is None) != 1: + raise ValueError( + "You must provide either n_bands, or cutoffs, but not both.") + + self.sample_rate = sample_rate + self.n_bands = n_bands + self._cutoffs = list(cutoffs) if cutoffs is not None else None + self.pad = pad + self.zeros = zeros + self.fft = fft + + if cutoffs is None: + if n_bands is None: + raise ValueError("You must provide one of n_bands or cutoffs.") + if not n_bands >= 1: + raise ValueError( + f"n_bands must be greater than one (got {n_bands})") + cutoffs = mel_frequencies(n_bands + 1, 0, sample_rate / 2)[1:-1] + else: + if max(cutoffs) > 0.5 * sample_rate: + raise ValueError( + "A cutoff above sample_rate/2 does not make sense.") + if len(cutoffs) > 0: + self.lowpass = LowPassFilters( + [c / sample_rate for c in cutoffs], + pad=pad, + zeros=zeros, + fft=fft) + else: + self.lowpass = None # type: ignore + + def forward(self, input): + if self.lowpass is None: + return input[None] + lows = self.lowpass(input) + low = lows[0] + bands = [low] + for low_and_band in lows[1:]: + # Get a bandpass filter by subtracting lowpasses + band = low_and_band - low + bands.append(band) + low = low_and_band + # Last band is whatever is left in the signal + bands.append(input - low) + return paddle.stack(bands) + + @property + def cutoffs(self): + if self._cutoffs is not None: + return self._cutoffs + elif self.lowpass is not None: + return [c * self.sample_rate for c in self.lowpass.cutoffs] + else: + return [] + + +def split_bands( + signal: paddle.Tensor, + sample_rate: float, + n_bands: Optional[int]=None, + cutoffs: Optional[Sequence[float]]=None, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + """ + Functional version of `SplitBands`, refer to this class for more information. + + >>> x = paddle.randn(shape=[6, 4, 1024]) + >>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape) + [3, 6, 4, 1024] + """ + return SplitBands(sample_rate, n_bands, cutoffs, pad, zeros, fft)(signal) diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 64435242b..717959050 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -14,15 +14,17 @@ import librosa import numpy as np import paddle import soundfile + from . 
import util +from .dsp import DSPMixin +from .effects import EffectMixin +from .effects import ImpulseResponseMixin +from .ffmpeg import FFMPEGMixin +from .loudness import LoudnessMixin from .resample import resample_frac # from .display import DisplayMixin -# from .dsp import DSPMixin -# from .effects import EffectMixin -# from .effects import ImpulseResponseMixin -# from .ffmpeg import FFMPEGMixinx -# from loudness import LoudnessMixin + # from .playback import PlayMixin # from .whisper import WhisperMixin @@ -89,13 +91,13 @@ STFTParams.__new__.__defaults__ = (None, None, None, None, None) class AudioSignal( - # EffectMixin, - # LoudnessMixin, + EffectMixin, + LoudnessMixin, # PlayMixin, - # ImpulseResponseMixin, - # DSPMixin, + ImpulseResponseMixin, + DSPMixin, # DisplayMixin, - # FFMPEGMixin, + FFMPEGMixin, # WhisperMixin, ): """This is the core object of this library. Audio is always @@ -525,7 +527,7 @@ class AudioSignal( AudioSignal AudioSignal loaded from file """ - + # need `ffmpeg` data, sample_rate = librosa.load( audio_path, offset=offset, @@ -967,8 +969,7 @@ class AudioSignal( def stft_data(self, data: typing.Union[paddle.Tensor, np.ndarray]): if data is not None: assert paddle.is_tensor(data) and paddle.is_complex(data) - if (self.stft_data is not None and - self.stft_data.shape != data.shape): + if self.stft_data is not None and self.stft_data.shape != data.shape: warnings.warn("stft_data changed shape") self._stft_data = data return @@ -1139,8 +1140,7 @@ class AudioSignal( length = self.signal_length if match_stride: - assert (hop_length == window_length // - 4), "For match_stride, hop must equal n_fft // 4" + assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4" right_pad = math.ceil(length / hop_length) * hop_length - length pad = (window_length - hop_length) // 2 else: @@ -1192,16 +1192,13 @@ class AudioSignal( >>> signal.stft() """ - window_length = (self.stft_params.window_length - if window_length is None else int(window_length)) - hop_length = (self.stft_params.hop_length - if hop_length is None else int(hop_length)) - window_type = (self.stft_params.window_type - if window_type is None else window_type) - match_stride = (self.stft_params.match_stride - if match_stride is None else match_stride) - padding_type = (self.stft_params.padding_type - if padding_type is None else padding_type) + window_length = self.stft_params.window_length if window_length is None else int( + window_length) + hop_length = self.stft_params.hop_length if hop_length is None else int( + hop_length) + window_type = self.stft_params.window_type if window_type is None else window_type + match_stride = self.stft_params.match_stride if match_stride is None else match_stride + padding_type = self.stft_params.padding_type if padding_type is None else padding_type window = self.get_window(window_type, window_length) # window = window.to(self.audio_data.device) @@ -1269,14 +1266,12 @@ class AudioSignal( if self.stft_data is None: raise RuntimeError("Cannot do inverse STFT without self.stft_data!") - window_length = (self.stft_params.window_length - if window_length is None else int(window_length)) - hop_length = (self.stft_params.hop_length - if hop_length is None else int(hop_length)) - window_type = (self.stft_params.window_type - if window_type is None else window_type) - match_stride = (self.stft_params.match_stride - if match_stride is None else match_stride) + window_length = self.stft_params.window_length if window_length is None else int( + window_length) + 
hop_length = self.stft_params.hop_length if hop_length is None else int( + hop_length) + window_type = self.stft_params.window_type if window_type is None else window_type + match_stride = self.stft_params.match_stride if match_stride is None else match_stride window = self.get_window(window_type, window_length, self.stft_data.place) @@ -1409,7 +1404,6 @@ class AudioSignal( paddle.Tensor [shape=(n_mels, n_mfcc)] T The dct transformation matrix. """ - # from torchaudio.functional import create_dct return create_dct(n_mfcc, n_mels, norm) @@ -1575,8 +1569,7 @@ class AudioSignal( # Representation def _info(self): # ✅ - dur = (f"{self.signal_duration:0.3f}" - if self.signal_duration else "[unknown]") + dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]" info = { "duration": f"{dur} seconds", @@ -1654,10 +1647,20 @@ class AudioSignal( def __eq__(self, other): for k, v in list(self.__dict__.items()): if paddle.is_tensor(v): - if not paddle.allclose(v, other.__dict__[k], atol=1e-6): - max_error = (v - other.__dict__[k]).abs().max() - print(f"Max abs error for {k}: {max_error}") - return False + + if paddle.is_complex(v): + if not np.allclose( + v.cpu().numpy(), + other.__dict__[k].cpu().numpy(), + atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False + else: + if not paddle.allclose(v, other.__dict__[k], atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False return True # Indexing @@ -1675,10 +1678,10 @@ class AudioSignal( # Future work: make this work for time-indexing # as well, using the hop length. audio_data = self.audio_data[key] - _loudness = (self._loudness[key] - if self._loudness is not None else None) - stft_data = (self.stft_data[key] - if self.stft_data is not None else None) + _loudness = self._loudness[ + key] if self._loudness is not None else None + stft_data = self.stft_data[ + key] if self.stft_data is not None else None sources = None @@ -1707,7 +1710,12 @@ class AudioSignal( if self.audio_data is not None and value.audio_data is not None: self.audio_data[key] = value.audio_data if self._loudness is not None and value._loudness is not None: - self._loudness[key] = value._loudness + if paddle.is_tensor(key) and key.dtype == paddle.bool: + # FOR Paddle BOOL Index + _key_no_bool = paddle.nonzero(key).flatten() + self._loudness[_key_no_bool] = value._loudness + else: + self._loudness[key] = value._loudness if self.stft_data is not None and value.stft_data is not None: self.stft_data[key] = value.stft_data return diff --git a/audio/audiotools/core/dsp.py b/audio/audiotools/core/dsp.py new file mode 100644 index 000000000..9f3b47f31 --- /dev/null +++ b/audio/audiotools/core/dsp.py @@ -0,0 +1,390 @@ +import typing + +import numpy as np +import paddle + +from . import _julius +from . 
import util + + +class DSPMixin: + _original_batch_size = None + _original_num_channels = None + _padded_signal_length = None + + # def _preprocess_signal_for_windowing(self, window_duration, hop_duration): + # self._original_batch_size = self.batch_size + # self._original_num_channels = self.num_channels + + # window_length = int(window_duration * self.sample_rate) + # hop_length = int(hop_duration * self.sample_rate) + + # if window_length % hop_length != 0: + # factor = window_length // hop_length + # window_length = factor * hop_length + + # self.zero_pad(hop_length, hop_length) + # self._padded_signal_length = self.signal_length + + # return window_length, hop_length + + # def windows( + # self, window_duration: float, hop_duration: float, preprocess: bool = True + # ): + # """Generator which yields windows of specified duration from signal with a specified + # hop length. + + # Parameters + # ---------- + # window_duration : float + # Duration of every window in seconds. + # hop_duration : float + # Hop between windows in seconds. + # preprocess : bool, optional + # Whether to preprocess the signal, so that the first sample is in + # the middle of the first window, by default True + + # Yields + # ------ + # AudioSignal + # Each window is returned as an AudioSignal. + # """ + # if preprocess: + # window_length, hop_length = self._preprocess_signal_for_windowing( + # window_duration, hop_duration + # ) + + # self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length) + + # for b in range(self.batch_size): + # i = 0 + # start_idx = i * hop_length + # while True: + # start_idx = i * hop_length + # i += 1 + # end_idx = start_idx + window_length + # if end_idx > self.signal_length: + # break + # yield self[b, ..., start_idx:end_idx] + + # def collect_windows( + # self, window_duration: float, hop_duration: float, preprocess: bool = True + # ): + # """Reshapes signal into windows of specified duration from signal with a specified + # hop length. Window are placed along the batch dimension. Use with + # :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the + # original signal. + + # Parameters + # ---------- + # window_duration : float + # Duration of every window in seconds. + # hop_duration : float + # Hop between windows in seconds. + # preprocess : bool, optional + # Whether to preprocess the signal, so that the first sample is in + # the middle of the first window, by default True + + # Returns + # ------- + # AudioSignal + # AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)`` + # """ + # if preprocess: + # window_length, hop_length = self._preprocess_signal_for_windowing( + # window_duration, hop_duration + # ) + + # # self.audio_data: (nb, nch, nt). + # unfolded = paddle.nn.functional.unfold( + # self.audio_data.reshape(-1, 1, 1, self.signal_length), + # kernel_size=(1, window_length), + # stride=(1, hop_length), + # ) + # # unfolded: (nb * nch, window_length, num_windows). + # # -> (nb * nch * num_windows, 1, window_length) + # unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length) + # self.audio_data = unfolded + # return self + + # def overlap_and_add(self, hop_duration: float): + # """Function which takes a list of windows and overlap adds them into a + # signal the same length as ``audio_signal``. + + # Parameters + # ---------- + # hop_duration : float + # How much to shift for each window + # (overlap is window_duration - hop_duration) in seconds. 
+ + # Returns + # ------- + # AudioSignal + # overlap-and-added signal. + # """ + # hop_length = int(hop_duration * self.sample_rate) + # window_length = self.signal_length + + # nb, nch = self._original_batch_size, self._original_num_channels + + # unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1) + # folded = paddle.nn.functional.fold( + # unfolded, + # output_size=(1, self._padded_signal_length), + # kernel_size=(1, window_length), + # stride=(1, hop_length), + # ) + + # norm = paddle.ones_like(unfolded, device=unfolded.device) + # norm = paddle.nn.functional.fold( + # norm, + # output_size=(1, self._padded_signal_length), + # kernel_size=(1, window_length), + # stride=(1, hop_length), + # ) + + # folded = folded / norm + + # folded = folded.reshape(nb, nch, -1) + # self.audio_data = folded + # self.trim(hop_length, hop_length) + # return self + + def low_pass(self, + cutoffs: typing.Union[paddle.Tensor, np.ndarray, float], + zeros: int=51): + """Low-passes the signal in-place. Each item in the batch + can have a different low-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same low-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[paddle.Tensor, np.ndarray, float] + Cutoff in Hz of low-pass filter. + zeros : int, optional + Number of taps to use in low-pass filter, by default 51 + + Returns + ------- + AudioSignal + Low-passed AudioSignal. + """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = paddle.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + lp_filter = _julius.LowPassFilter(cutoff.cpu(), zeros=zeros) + filtered[i] = lp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def high_pass(self, + cutoffs: typing.Union[paddle.Tensor, np.ndarray, float], + zeros: int=51): + """High-passes the signal in-place. Each item in the batch + can have a different high-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same high-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[paddle.Tensor, np.ndarray, float] + Cutoff in Hz of high-pass filter. + zeros : int, optional + Number of taps to use in high-pass filter, by default 51 + + Returns + ------- + AudioSignal + High-passed AudioSignal. + """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = paddle.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + hp_filter = _julius.HighPassFilter(cutoff.cpu(), zeros=zeros) + filtered[i] = hp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + # def mask_frequencies( + # self, + # fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float], + # fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float], + # val: float = 0.0, + # ): + # """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them + # with the value specified by ``val``. Useful for implementing SpecAug. + # The min and max can be different for every item in the batch. + + # Parameters + # ---------- + # fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float] + # Lower end of band to mask out. + # fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float] + # Upper end of band to mask out. 
+ # val : float, optional + # Value to fill in, by default 0.0 + + # Returns + # ------- + # AudioSignal + # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + # masked audio data. + # """ + # # SpecAug + # mag, phase = self.magnitude, self.phase + # fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim) + # fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim) + # assert paddle.all(fmin_hz < fmax_hz) + + # # build mask + # nbins = mag.shape[-2] + # bins_hz = paddle.linspace(0, self.sample_rate / 2, nbins, device=self.device) + # bins_hz = bins_hz[None, None, :, None].repeat( + # self.batch_size, 1, 1, mag.shape[-1] + # ) + # mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz) + # mask = mask.to(self.device) + + # mag = mag.masked_fill(mask, val) + # phase = phase.masked_fill(mask, val) + # self.stft_data = mag * paddle.exp(1j * phase) + # return self + + # def mask_timesteps( + # self, + # tmin_s: typing.Union[paddle.Tensor, np.ndarray, float], + # tmax_s: typing.Union[paddle.Tensor, np.ndarray, float], + # val: float = 0.0, + # ): + # """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them + # with the value specified by ``val``. Useful for implementing SpecAug. + # The min and max can be different for every item in the batch. + + # Parameters + # ---------- + # tmin_s : typing.Union[paddle.Tensor, np.ndarray, float] + # Lower end of timesteps to mask out. + # tmax_s : typing.Union[paddle.Tensor, np.ndarray, float] + # Upper end of timesteps to mask out. + # val : float, optional + # Value to fill in, by default 0.0 + + # Returns + # ------- + # AudioSignal + # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + # masked audio data. + # """ + # # SpecAug + # mag, phase = self.magnitude, self.phase + # tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim) + # tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim) + + # assert paddle.all(tmin_s < tmax_s) + + # # build mask + # nt = mag.shape[-1] + # bins_t = paddle.linspace(0, self.signal_duration, nt, device=self.device) + # bins_t = bins_t[None, None, None, :].repeat( + # self.batch_size, 1, mag.shape[-2], 1 + # ) + # mask = (tmin_s <= bins_t) & (bins_t < tmax_s) + + # mag = mag.masked_fill(mask, val) + # phase = phase.masked_fill(mask, val) + # self.stft_data = mag * paddle.exp(1j * phase) + # return self + + # def mask_low_magnitudes( + # self, db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float], val: float = 0.0 + # ): + # """Mask away magnitudes below a specified threshold, which + # can be different for every item in the batch. + + # Parameters + # ---------- + # db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float] + # Decibel value for which things below it will be masked away. + # val : float, optional + # Value to fill in for masked portions, by default 0.0 + + # Returns + # ------- + # AudioSignal + # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + # masked audio data. + # """ + # mag = self.magnitude + # log_mag = self.log_magnitude() + + # db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) + # mask = log_mag < db_cutoff + # mag = mag.masked_fill(mask, val) + + # self.magnitude = mag + # return self + + # def shift_phase(self, shift: typing.Union[paddle.Tensor, np.ndarray, float]): + # """Shifts the phase by a constant value. + + # Parameters + # ---------- + # shift : typing.Union[paddle.Tensor, np.ndarray, float] + # What to shift the phase by. + + # Returns + # ------- + # AudioSignal + # Signal with ``stft_data`` manipulated. 
Apply ``.istft()`` to get the + # masked audio data. + # """ + # shift = util.ensure_tensor(shift, ndim=self.phase.ndim) + # self.phase = self.phase + shift + # return self + + # def corrupt_phase(self, scale: typing.Union[paddle.Tensor, np.ndarray, float]): + # """Corrupts the phase randomly by some scaled value. + + # Parameters + # ---------- + # scale : typing.Union[paddle.Tensor, np.ndarray, float] + # Standard deviation of noise to add to the phase. + + # Returns + # ------- + # AudioSignal + # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + # masked audio data. + # """ + # scale = util.ensure_tensor(scale, ndim=self.phase.ndim) + # self.phase = self.phase + scale * paddle.randn_like(self.phase) + # return self + + # def preemphasis(self, coef: float = 0.85): + # """Applies pre-emphasis to audio signal. + + # Parameters + # ---------- + # coef : float, optional + # How much pre-emphasis to apply, lower values do less. 0 does nothing. + # by default 0.85 + + # Returns + # ------- + # AudioSignal + # Pre-emphasized signal. + # """ + # kernel = paddle.to_tensor([1, -coef, 0]).view(1, 1, -1).to(self.device) + # x = self.audio_data.reshape(-1, 1, self.signal_length) + # x = paddle.nn.functional.conv1d(x, kernel, padding=1) + # self.audio_data = x.reshape(*self.audio_data.shape) + # return self diff --git a/audio/audiotools/core/effects.py b/audio/audiotools/core/effects.py new file mode 100644 index 000000000..561c530e3 --- /dev/null +++ b/audio/audiotools/core/effects.py @@ -0,0 +1,681 @@ +import typing + +import numpy as np +import paddle + +from . import util +from ._julius import SplitBands + +# from . import _julius + + +class EffectMixin: + GAIN_FACTOR = np.log(10) / 20 + """Gain factor for converting between amplitude and decibels.""" + CODEC_PRESETS = { + "8-bit": { + "format": "wav", + "encoding": "ULAW", + "bits_per_sample": 8 + }, + "GSM-FR": { + "format": "gsm" + }, + "MP3": { + "format": "mp3", + "compression": -9 + }, + "Vorbis": { + "format": "vorbis", + "compression": -1 + }, + "Ogg": { + "format": "ogg", + "compression": -1, + }, + "Amr-nb": { + "format": "amr-nb" + }, + } + """Presets for applying codecs via torchaudio.""" + + def mix( + self, + other, + snr: typing.Union[paddle.Tensor, np.ndarray, float]=10, + other_eq: typing.Union[paddle.Tensor, np.ndarray]=None, ): + """Mixes noise with signal at specified + signal-to-noise ratio. Optionally, the + other signal can be equalized in-place. + + + Parameters + ---------- + other : AudioSignal + AudioSignal object to mix with. + snr : typing.Union[paddle.Tensor, np.ndarray, float], optional + Signal to noise ratio, by default 10 + other_eq : typing.Union[paddle.Tensor, np.ndarray], optional + EQ curve to apply to other signal, if any, by default None + + Returns + ------- + AudioSignal + In-place modification of AudioSignal. + """ + snr = util.ensure_tensor(snr) + + pad_len = max(0, self.signal_length - other.signal_length) + other.zero_pad(0, pad_len) + other.truncate_samples(self.signal_length) + if other_eq is not None: + other = other.equalizer(other_eq) + + tgt_loudness = self.loudness() - snr + other = other.normalize(tgt_loudness) + + self.audio_data = self.audio_data + other.audio_data + return self + + def convolve(self, other, start_at_max: bool=True): + """Convolves self with other. + This function uses FFTs to do the convolution. + + Parameters + ---------- + other : AudioSignal + Signal to convolve with. 
+ start_at_max : bool, optional + Whether to start at the max value of other signal, to + avoid inducing delays, by default True + + Returns + ------- + AudioSignal + Convolved signal, in-place. + """ + from . import AudioSignal + + pad_len = self.signal_length - other.signal_length + + if pad_len > 0: + other.zero_pad(0, pad_len) + else: + other.truncate_samples(self.signal_length) + + if start_at_max: + # Use roll to rotate over the max for every item + # so that the impulse responses don't induce any + # delay. + idx = paddle.argmax(paddle.abs(other.audio_data), axis=-1) + irs = paddle.zeros_like(other.audio_data) + for i in range(other.batch_size): + irs[i] = paddle.roll( + other.audio_data[i], shifts=-idx[i].item(), axis=-1) + other = AudioSignal(irs, other.sample_rate) + + delta = paddle.zeros_like(other.audio_data) + delta[..., 0] = 1 + + length = self.signal_length + delta_fft = paddle.fft.rfft(delta, n=length) + other_fft = paddle.fft.rfft(other.audio_data, n=length) + self_fft = paddle.fft.rfft(self.audio_data, n=length) + + convolved_fft = other_fft * self_fft + convolved_audio = paddle.fft.irfft(convolved_fft, n=length) + + delta_convolved_fft = other_fft * delta_fft + delta_audio = paddle.fft.irfft(delta_convolved_fft, n=length) + + # Use the delta to rescale the audio exactly as needed. + delta_max = paddle.max(paddle.abs(delta_audio), axis=-1, keepdim=True) + scale = 1 / paddle.clip(delta_max, min=1e-5) + convolved_audio = convolved_audio * scale + + self.audio_data = convolved_audio + + return self + + def apply_ir( + self, + ir, + drr: typing.Union[paddle.Tensor, np.ndarray, float]=None, + ir_eq: typing.Union[paddle.Tensor, np.ndarray]=None, + use_original_phase: bool=False, ): + """Applies an impulse response to the signal. If ` is`ir_eq`` + is specified, the impulse response is equalized before + it is applied, using the given curve. + + Parameters + ---------- + ir : AudioSignal + Impulse response to convolve with. + drr : typing.Union[paddle.Tensor, np.ndarray, float], optional + Direct-to-reverberant ratio that impulse response will be + altered to, if specified, by default None + ir_eq : typing.Union[paddle.Tensor, np.ndarray], optional + Equalization that will be applied to impulse response + if specified, by default None + use_original_phase : bool, optional + Whether to use the original phase, instead of the convolved + phase, by default False + + Returns + ------- + AudioSignal + Signal with impulse response applied to it + """ + if ir_eq is not None: + ir = ir.equalizer(ir_eq) + if drr is not None: + ir = ir.alter_drr(drr) + + # Save the peak before + max_spk = self.audio_data.abs().max(axis=-1, keepdim=True) + + # Augment the impulse response to simulate microphone effects + # and with varying direct-to-reverberant ratio. + phase = self.phase + self.convolve(ir) + + # Use the input phase + if use_original_phase: + self.stft() + self.stft_data = self.magnitude * paddle.exp(1j * phase) + self.istft() + + # Rescale to the input's amplitude + max_transformed = self.audio_data.abs().max(axis=-1, keepdim=True) + scale_factor = max_spk.clip(1e-8) / max_transformed.clip(1e-8) + self = self * scale_factor + + return self + + def ensure_max_of_audio(self, _max: float=1.0): + """Ensures that ``abs(audio_data) <= max``. + + Parameters + ---------- + max : float, optional + Max absolute value of signal, by default 1.0 + + Returns + ------- + AudioSignal + Signal with values scaled between -max and max. 
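+
+        Example (illustrative sketch; assumes an ``AudioSignal`` built from a tensor)::
+
+            signal = AudioSignal(paddle.randn([1, 1, 44100]) * 3.0, 44100)
+            signal = signal.ensure_max_of_audio()  # per-item peaks rescaled to at most 1.0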
+ """ + peak = self.audio_data.abs().max(axis=-1, keepdim=True) + peak_gain = paddle.ones_like(peak) + peak_gain[peak > _max] = _max / peak[peak > _max] + self.audio_data = self.audio_data * peak_gain + return self + + def normalize(self, + db: typing.Union[paddle.Tensor, np.ndarray, float]=-24.0): + """Normalizes the signal's volume to the specified db, in LUFS. + This is GPU-compatible, making for very fast loudness normalization. + + Parameters + ---------- + db : typing.Union[paddle.Tensor, np.ndarray, float], optional + Loudness to normalize to, by default -24.0 + + Returns + ------- + AudioSignal + Normalized audio signal. + """ + db = util.ensure_tensor(db) + ref_db = self.loudness() + gain = db - ref_db + gain = paddle.exp(gain * self.GAIN_FACTOR) + + self.audio_data = self.audio_data * gain[:, None, None] + return self + + # def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]): + # """Change volume of signal by some amount, in dB. + + # Parameters + # ---------- + # db : typing.Union[paddle.Tensor, np.ndarray, float] + # Amount to change volume by. + + # Returns + # ------- + # AudioSignal + # Signal at new volume. + # """ + # db = util.ensure_tensor(db, ndim=1).to(self.device) + # gain = torch.exp(db * self.GAIN_FACTOR) + # self.audio_data = self.audio_data * gain[:, None, None] + # return self + + # def _to_2d(self): + # waveform = self.audio_data.reshape(-1, self.signal_length) + # return waveform + + # def _to_3d(self, waveform): + # return waveform.reshape(self.batch_size, self.num_channels, -1) + + # def pitch_shift(self, n_semitones: int, quick: bool = True): + # """Pitch shift the signal. All items in the batch + # get the same pitch shift. + + # Parameters + # ---------- + # n_semitones : int + # How many semitones to shift the signal by. + # quick : bool, optional + # Using quick pitch shifting, by default True + + # Returns + # ------- + # AudioSignal + # Pitch shifted audio signal. + # """ + # device = self.device + # effects = [ + # ["pitch", str(n_semitones * 100)], + # ["rate", str(self.sample_rate)], + # ] + # if quick: + # effects[0].insert(1, "-q") + + # waveform = self._to_2d().cpu() + # waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + # waveform, self.sample_rate, effects, channels_first=True + # ) + # self.sample_rate = sample_rate + # self.audio_data = self._to_3d(waveform) + # return self.to(device) + + # def time_stretch(self, factor: float, quick: bool = True): + # """Time stretch the audio signal. + + # Parameters + # ---------- + # factor : float + # Factor by which to stretch the AudioSignal. Typically + # between 0.8 and 1.2. + # quick : bool, optional + # Whether to use quick time stretching, by default True + + # Returns + # ------- + # AudioSignal + # Time-stretched AudioSignal. + # """ + # device = self.device + # effects = [ + # ["tempo", str(factor)], + # ["rate", str(self.sample_rate)], + # ] + # if quick: + # effects[0].insert(1, "-q") + + # waveform = self._to_2d().cpu() + # waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + # waveform, self.sample_rate, effects, channels_first=True + # ) + # self.sample_rate = sample_rate + # self.audio_data = self._to_3d(waveform) + # return self.to(device) + + # def apply_codec( + # self, + # preset: str = None, + # format: str = "wav", + # encoding: str = None, + # bits_per_sample: int = None, + # compression: int = None, + # ): # pragma: no cover + # """Applies an audio codec to the signal. 
+ + # Parameters + # ---------- + # preset : str, optional + # One of the keys in ``self.CODEC_PRESETS``, by default None + # format : str, optional + # Format for audio codec, by default "wav" + # encoding : str, optional + # Encoding to use, by default None + # bits_per_sample : int, optional + # How many bits per sample, by default None + # compression : int, optional + # Compression amount of codec, by default None + + # Returns + # ------- + # AudioSignal + # AudioSignal with codec applied. + + # Raises + # ------ + # ValueError + # If preset is not in ``self.CODEC_PRESETS``, an error + # is thrown. + # """ + # torchaudio_version_070 = "0.7" in torchaudio.__version__ + # if torchaudio_version_070: + # return self + + # kwargs = { + # "format": format, + # "encoding": encoding, + # "bits_per_sample": bits_per_sample, + # "compression": compression, + # } + + # if preset is not None: + # if preset in self.CODEC_PRESETS: + # kwargs = self.CODEC_PRESETS[preset] + # else: + # raise ValueError( + # f"Unknown preset: {preset}. " + # f"Known presets: {list(self.CODEC_PRESETS.keys())}" + # ) + + # waveform = self._to_2d() + # if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]: + # # Apply it in a for loop + # augmented = torch.cat( + # [ + # torchaudio.functional.apply_codec( + # waveform[i][None, :], self.sample_rate, **kwargs + # ) + # for i in range(waveform.shape[0]) + # ], + # dim=0, + # ) + # else: + # augmented = torchaudio.functional.apply_codec( + # waveform, self.sample_rate, **kwargs + # ) + # augmented = self._to_3d(augmented) + + # self.audio_data = augmented + # return self + + def mel_filterbank(self, n_bands: int): + """Breaks signal into mel bands. + + Parameters + ---------- + n_bands : int + Number of mel bands to use. + + Returns + ------- + paddle.Tensor + Mel-filtered bands, with last axis being the band index. + """ + filterbank = SplitBands(self.sample_rate, n_bands).float() + filtered = filterbank(self.audio_data) + return filtered.transpose([1, 2, 3, 0]) + + def equalizer(self, db: typing.Union[paddle.Tensor, np.ndarray]): + """Applies a mel-spaced equalizer to the audio signal. + + Parameters + ---------- + db : typing.Union[paddle.Tensor, np.ndarray] + EQ curve to apply. + + Returns + ------- + AudioSignal + AudioSignal with equalization applied. + """ + db = util.ensure_tensor(db) + n_bands = db.shape[-1] + fbank = self.mel_filterbank(n_bands) + + # If there's a batch dimension, make sure it's the same. + if db.ndim == 2: + if db.shape[0] != 1: + assert db.shape[0] == fbank.shape[0] + else: + db = db.unsqueeze(0) + + weights = (10**db).astype("float32") + fbank = fbank * weights[:, None, None, :] + eq_audio_data = fbank.sum(-1) + self.audio_data = eq_audio_data + return self + + def clip_distortion( + self, + clip_percentile: typing.Union[paddle.Tensor, np.ndarray, float]): + """Clips the signal at a given percentile. The higher it is, + the lower the threshold for clipping. + + Parameters + ---------- + clip_percentile : typing.Union[paddle.Tensor, np.ndarray, float] + Values are between 0.0 to 1.0. Typical values are 0.1 or below. + + Returns + ------- + AudioSignal + Audio signal with clipped audio data. 
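+
+        Example (illustrative sketch)::
+
+            # clip away roughly the top and bottom 2.5% of sample values
+            signal = signal.clip_distortion(0.05)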
+ """ + clip_percentile = util.ensure_tensor(clip_percentile, ndim=1) + clip_percentile = clip_percentile.item() + min_thresh = paddle.quantile( + self.audio_data, clip_percentile / 2, axis=-1)[None] + max_thresh = paddle.quantile( + self.audio_data, 1 - (clip_percentile / 2), axis=-1)[None] + + nc = self.audio_data.shape[1] + min_thresh = min_thresh[:, :nc, :] + max_thresh = max_thresh[:, :nc, :] + + self.audio_data = self.audio_data.clip(min_thresh, max_thresh) + + return self + + # def quantization( + # self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int] + # ): + # """Applies quantization to the input waveform. + + # Parameters + # ---------- + # quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] + # Number of evenly spaced quantization channels to quantize + # to. + + # Returns + # ------- + # AudioSignal + # Quantized AudioSignal. + # """ + # quantization_channels = util.ensure_tensor(quantization_channels, ndim=3) + + # x = self.audio_data + # x = (x + 1) / 2 + # x = x * quantization_channels + # x = x.floor() + # x = x / quantization_channels + # x = 2 * x - 1 + + # residual = (self.audio_data - x).detach() + # self.audio_data = self.audio_data - residual + # return self + + # def mulaw_quantization( + # self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int] + # ): + # """Applies mu-law quantization to the input waveform. + + # Parameters + # ---------- + # quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] + # Number of mu-law spaced quantization channels to quantize + # to. + + # Returns + # ------- + # AudioSignal + # Quantized AudioSignal. + # """ + # mu = quantization_channels - 1.0 + # mu = util.ensure_tensor(mu, ndim=3) + + # x = self.audio_data + + # # quantize + # x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) + # x = ((x + 1) / 2 * mu + 0.5).to(torch.int64) + + # # unquantize + # x = (x / mu) * 2 - 1.0 + # x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu + + # residual = (self.audio_data - x).detach() + # self.audio_data = self.audio_data - residual + # return self + + # def __matmul__(self, other): + # return self.convolve(other) + + +import paddle +import typing +import numpy as np + + +class ImpulseResponseMixin: + """These functions are generally only used with AudioSignals that are derived + from impulse responses, not other sources like music or speech. These methods + are used to replicate the data augmentation described in [1]. + + 1. Bryan, Nicholas J. "Impulse response data augmentation and deep + neural networks for blind room acoustic parameter estimation." + ICASSP 2020-2020 IEEE International Conference on Acoustics, + Speech and Signal Processing (ICASSP). IEEE, 2020. + """ + + def decompose_ir(self): + """Decomposes an impulse response into early and late + field responses. + """ + # Equations 1 and 2 + # ----------------- + # Breaking up into early + # response + late field response. 
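+        # `td` below is the index of the direct-path peak; samples within
+        # +/- t0 (2.5 ms) of it form the early response, the rest the late field.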
+ + td = paddle.argmax(self.audio_data, axis=-1, keepdim=True) + t0 = int(self.sample_rate * 0.0025) + + idx = paddle.arange(self.audio_data.shape[-1])[None, None, :] + idx = idx.expand([self.batch_size, -1, -1]) + early_idx = (idx >= td - t0) * (idx <= td + t0) + + early_response = paddle.zeros_like(self.audio_data) + + # early_response[early_idx] = self.audio_data[early_idx] + early_response = paddle.where(early_idx, self.audio_data, + early_response) + + late_idx = ~early_idx + late_field = paddle.zeros_like(self.audio_data) + # late_field[late_idx] = self.audio_data[late_idx] + late_field = paddle.where(late_idx, self.audio_data, late_field) + + # Equation 4 + # ---------- + # Decompose early response into windowed + # direct path and windowed residual. + + window = paddle.zeros_like(self.audio_data) + for idx in range(self.batch_size): + window_idx = early_idx[idx, 0] + + # ----- Just for this ----- + # window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item()) + indices = paddle.nonzero(window_idx).reshape( + [-1]) # shape: [num_true], dtype: int64 + temp_window = self.get_window("hann", indices.shape[0]) + + window_slice = window[idx, 0] + updated_window_slice = paddle.scatter( + window_slice, index=indices, updates=temp_window) + + window[idx, 0] = updated_window_slice + # ----- Just for that ----- + + return early_response, late_field, window + + def measure_drr(self): + """Measures the direct-to-reverberant ratio of the impulse + response. + + Returns + ------- + float + Direct-to-reverberant ratio + """ + early_response, late_field, _ = self.decompose_ir() + num = (early_response**2).sum(axis=-1) + den = (late_field**2).sum(axis=-1) + drr = 10 * paddle.log10(num / den) + return drr + + @staticmethod + def solve_alpha(early_response, late_field, wd, target_drr): + """Used to solve for the alpha value, which is used + to alter the drr. + """ + # Equation 5 + # ---------- + # Apply the good ol' quadratic formula. + + wd_sq = wd**2 + wd_sq_1 = (1 - wd)**2 + e_sq = early_response**2 + l_sq = late_field**2 + a = (wd_sq * e_sq).sum(axis=-1) + b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1) + c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow( + 10 * paddle.ones_like(target_drr), target_drr / 10) * l_sq.sum( + axis=-1) + + expr = ((b**2) - 4 * a * c).sqrt() + alpha = paddle.maximum( + (-b - expr) / (2 * a), + (-b + expr) / (2 * a), ) + return alpha + + def alter_drr(self, drr: typing.Union[paddle.Tensor, np.ndarray, float]): + """Alters the direct-to-reverberant ratio of the impulse response. + + Parameters + ---------- + drr : typing.Union[paddle.Tensor, np.ndarray, float] + Direct-to-reverberant ratio that impulse response will be + altered to, if specified, by default None + + Returns + ------- + AudioSignal + Altered impulse response. 
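+
+        Example (illustrative sketch; ``ir_signal`` is assumed to be an
+        impulse-response ``AudioSignal``)::
+
+            ir_signal = ir_signal.alter_drr(15.0)  # move the measured DRR towards 15 dB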
+ """ + drr = util.ensure_tensor( + drr, 2, self.batch_size + ) # Assuming util.ensure_tensor is adapted or equivalent exists + + early_response, late_field, window = self.decompose_ir() + alpha = self.solve_alpha(early_response, late_field, window, drr) + min_alpha = late_field.abs().max(axis=-1)[0] / early_response.abs().max( + axis=-1)[0] + alpha = paddle.maximum(alpha, min_alpha)[..., None] + + aug_ir_data = alpha * window * early_response + ( + (1 - window) * early_response) + late_field + self.audio_data = aug_ir_data + self.ensure_max_of_audio( + ) # Assuming ensure_max_of_audio is a method defined elsewhere + return self diff --git a/audio/audiotools/core/ffmpeg.py b/audio/audiotools/core/ffmpeg.py new file mode 100644 index 000000000..d47265ae9 --- /dev/null +++ b/audio/audiotools/core/ffmpeg.py @@ -0,0 +1,115 @@ +import json +import shlex +import subprocess +import tempfile +from pathlib import Path +from typing import Tuple + +import ffmpy +import numpy as np +import paddle + + +def r128stats(filepath: str, quiet: bool): + """Takes a path to an audio file, returns a dict with the loudness + stats computed by the ffmpeg ebur128 filter. + + Parameters + ---------- + filepath : str + Path to compute loudness stats on. + quiet : bool + Whether to show FFMPEG output during computation. + + Returns + ------- + dict + Dictionary containing loudness stats. + """ + ffargs = [ + "ffmpeg", + "-nostats", + "-i", + filepath, + "-filter_complex", + "ebur128", + "-f", + "null", + "-", + ] + if quiet: + ffargs += ["-hide_banner"] + proc = subprocess.Popen( + ffargs, stderr=subprocess.PIPE, universal_newlines=True) + stats = proc.communicate()[1] + summary_index = stats.rfind("Summary:") + + summary_list = stats[summary_index:].split() + i_lufs = float(summary_list[summary_list.index("I:") + 1]) + i_thresh = float(summary_list[summary_list.index("I:") + 4]) + lra = float(summary_list[summary_list.index("LRA:") + 1]) + lra_thresh = float(summary_list[summary_list.index("LRA:") + 4]) + lra_low = float(summary_list[summary_list.index("low:") + 1]) + lra_high = float(summary_list[summary_list.index("high:") + 1]) + stats_dict = { + "I": i_lufs, + "I Threshold": i_thresh, + "LRA": lra, + "LRA Threshold": lra_thresh, + "LRA Low": lra_low, + "LRA High": lra_high, + } + + return stats_dict + + +def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]: + """Given a path to a file, returns the start time offset and codec of + the first audio stream. + """ + ff = ffmpy.FFprobe( + inputs={path: None}, + global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet", + ) + streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"] + seconds_offset = 0.0 + codec = None + + # Get the offset and codec of the first audio stream we find + # and return its start time, if it has one. + for stream in streams: + if stream["codec_type"] == "audio": + seconds_offset = stream.get("start_time", 0.0) + codec = stream.get("codec_name") + break + return float(seconds_offset), codec + + +class FFMPEGMixin: + _loudness = None + + def ffmpeg_loudness(self, quiet: bool=True): + """Computes loudness of audio file using FFMPEG. + + Parameters + ---------- + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + paddle.Tensor + Loudness of every item in the batch, computed via + FFMPEG. 
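+
+        Examples
+        --------
+        A minimal sketch (the file name below is hypothetical)::
+
+            >>> signal = AudioSignal("speech.wav")
+            >>> signal.ffmpeg_loudness()  # integrated loudness ("I") per batch item, via ffmpeg's ebur128 filter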
+        """
+        loudness = []
+
+        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+            for i in range(self.batch_size):
+                self[i].write(f.name)
+                loudness_stats = r128stats(f.name, quiet=quiet)
+                loudness.append(loudness_stats["I"])
+
+        self._loudness = paddle.to_tensor(np.array(loudness)).astype("float32")
+        return self.loudness()
diff --git a/audio/audiotools/core/loudness.py b/audio/audiotools/core/loudness.py
new file mode 100644
index 000000000..8009369a1
--- /dev/null
+++ b/audio/audiotools/core/loudness.py
@@ -0,0 +1,338 @@
+import copy
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+import scipy
+
+from . import _julius
+
+
+class Meter(paddle.nn.Layer):
+    """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
+
+    Parameters
+    ----------
+    rate : int
+        Sample rate of audio.
+    filter_class : str, optional
+        Class of weighting filter used:
+        'K-weighting' (default), 'Fenton/Lee 1',
+        'Fenton/Lee 2', 'Dash et al.',
+        by default "K-weighting"
+    block_size : float, optional
+        Gating block size in seconds, by default 0.400
+    zeros : int, optional
+        Number of zeros to use in FIR approximation of
+        IIR filters, by default 512
+    use_fir : bool, optional
+        Whether to use FIR approximation or exact IIR formulation.
+        If computing on GPU, ``use_fir=True`` will be used, as it is
+        much faster, by default False
+    """
+
+    def __init__(
+            self,
+            rate: int,
+            filter_class: str="K-weighting",
+            block_size: float=0.400,
+            zeros: int=512,
+            use_fir: bool=False, ):
+        super().__init__()
+
+        self.rate = rate
+        self.filter_class = filter_class
+        self.block_size = block_size
+        self.use_fir = use_fir
+
+        G = paddle.to_tensor(
+            np.array([1.0, 1.0, 1.0, 1.41, 1.41]), stop_gradient=True)
+        self.register_buffer("G", G)
+
+        # Compute impulse responses so that filtering is fast via
+        # a convolution at runtime, on GPU, unlike lfilter.
+        impulse = np.zeros((zeros, ))
+        impulse[..., 0] = 1.0
+
+        firs = np.zeros((len(self._filters), 1, zeros))
+        # passband_gain = torch.zeros(len(self._filters))
+        passband_gain = paddle.zeros([len(self._filters)], dtype="float32")
+
+        for i, (_, filter_stage) in enumerate(self._filters.items()):
+            firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a,
+                                           impulse)
+            passband_gain[i] = filter_stage.passband_gain
+
+        firs = paddle.to_tensor(
+            firs[..., ::-1].copy(), dtype="float32", stop_gradient=True)
+
+        self.register_buffer("firs", firs)
+        self.register_buffer("passband_gain", passband_gain)
+
+    def apply_filter_gpu(self, data: paddle.Tensor):
+        """Performs FIR approximation of loudness computation.
+
+        Parameters
+        ----------
+        data : paddle.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        paddle.Tensor
+            Filtered audio data.
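+
+        Notes
+        -----
+        This FIR path is the one :py:meth:`apply_filter` selects when the
+        audio lives on the GPU or when ``use_fir=True`` was passed to the
+        ``Meter`` constructor; otherwise the exact IIR path in
+        :py:meth:`apply_filter_cpu` is used.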
+        """
+        # Data is of shape (nb, nt, nch)
+        # Reshape to (nb*nch, 1, nt)
+        nb, nt, nch = data.shape
+        data = data.transpose([0, 2, 1])
+        data = data.reshape([nb * nch, 1, nt])
+
+        # Apply padding
+        pad_length = self.firs.shape[-1]
+
+        # Apply filtering in sequence
+        for i in range(self.firs.shape[0]):
+            data = F.pad(data, (pad_length, pad_length), data_format="NCL")
+            data = _julius.fft_conv1d(data, self.firs[i, None, ...])
+            data = self.passband_gain[i] * data
+            data = data[..., 1:nt + 1]
+
+        data = data.transpose([0, 2, 1])
+        data = data[:, :nt, :]
+        return data
+
+    @staticmethod
+    def scipy_lfilter(waveform, a_coeffs, b_coeffs, clamp: bool=True):
+        # Filter with scipy.signal.lfilter (operates on 3-D data, one
+        # batch item and channel at a time).
+        output = np.zeros_like(waveform)
+        for batch_idx in range(waveform.shape[0]):
+            for channel_idx in range(waveform.shape[2]):
+                output[batch_idx, :, channel_idx] = scipy.signal.lfilter(
+                    b_coeffs, a_coeffs, waveform[batch_idx, :, channel_idx])
+        return output
+
+    def apply_filter_cpu(self, data: paddle.Tensor):
+        """Performs IIR formulation of loudness computation.
+
+        Parameters
+        ----------
+        data : paddle.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        paddle.Tensor
+            Filtered audio data.
+        """
+        _data = data.cpu().numpy().copy()
+        for _, filter_stage in self._filters.items():
+            passband_gain = filter_stage.passband_gain
+
+            a_coeffs = filter_stage.a
+            b_coeffs = filter_stage.b
+
+            filtered = self.scipy_lfilter(_data, a_coeffs, b_coeffs)
+            _data[:] = passband_gain * filtered
+        data = paddle.to_tensor(_data)
+        return data
+
+    def apply_filter(self, data: paddle.Tensor):
+        """Applies the filter on either GPU or CPU, depending on whether
+        the audio is on the GPU or on the CPU, or if ``self.use_fir``
+        is True.
+
+        Parameters
+        ----------
+        data : paddle.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        paddle.Tensor
+            Filtered audio data.
+        """
+        if data.place.is_gpu_place() or self.use_fir:
+            data = self.apply_filter_gpu(data)
+        else:
+            data = self.apply_filter_cpu(data)
+        return data
+
+    def forward(self, data: paddle.Tensor):
+        """Computes integrated loudness of data.
+
+        Parameters
+        ----------
+        data : paddle.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        paddle.Tensor
+            Integrated loudness of the input, in LUFS.
+        """
+        return self.integrated_loudness(data)
+
+    def _unfold(self, input_data):
+        T_g = self.block_size
+        overlap = 0.75  # overlap of 75% of the block duration
+        step = 1.0 - overlap  # step size by percentage
+
+        kernel_size = int(T_g * self.rate)
+        stride = int(T_g * self.rate * step)
+        unfolded = _julius.unfold(
+            input_data.transpose([0, 2, 1]), kernel_size, stride)
+        unfolded = unfolded.transpose([0, 1, 3, 2])
+
+        return unfolded
+
+    def integrated_loudness(self, data: paddle.Tensor):
+        """Computes integrated loudness of data.
+
+        Parameters
+        ----------
+        data : paddle.Tensor
+            Audio data of shape (nb, nt, nch).
+
+        Returns
+        -------
+        paddle.Tensor
+            Integrated loudness of the input, in LUFS.
+        """
+        if not paddle.is_tensor(data):
+            data = paddle.to_tensor(data, dtype="float32")
+        else:
+            data = data.astype("float32")
+
+        input_data = data.clone()
+        # Data always has a batch and channel dimension.
+        # Is of shape (nb, nt, nch)
+        if input_data.ndim < 2:
+            input_data = input_data.unsqueeze(-1)
+        if input_data.ndim < 3:
+            input_data = input_data.unsqueeze(0)
+
+        nb, nt, nch = input_data.shape
+
+        # Apply frequency weighting filters - account
+        # for the acoustic response of the head and auditory system
+        input_data = self.apply_filter(input_data)
+
+        G = self.G  # channel gains
+        T_g = self.block_size  # 400 ms gating block standard
+        Gamma_a = -70.0  # -70 LKFS = absolute loudness threshold
+
+        unfolded = self._unfold(input_data)
+
+        z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
+        l = -0.691 + 10.0 * paddle.log10(
+            (G[None, :nch, None] * z).sum(1, keepdim=True))
+        l = l.expand_as(z)
+
+        # find gating block indices above absolute threshold
+        z_avg_gated = z
+        z_avg_gated[l <= Gamma_a] = 0
+        masked = l > Gamma_a
+        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
+
+        # calculate the relative threshold value (see eq. 6)
+        Gamma_r = -0.691 + 10.0 * paddle.log10(
+            (z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
+        Gamma_r = Gamma_r[:, None, None]
+        Gamma_r = Gamma_r.expand([nb, nch, l.shape[-1]])
+
+        # find gating block indices above relative and absolute thresholds (end of eq. 7)
+        z_avg_gated = z
+        z_avg_gated[l <= Gamma_a] = 0
+        z_avg_gated[l <= Gamma_r] = 0
+        masked = (l > Gamma_a) * (l > Gamma_r)
+        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
+
+        # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
+        # z_avg_gated = torch.nan_to_num(z_avg_gated)
+        z_avg_gated = paddle.where(
+            paddle.isnan(z_avg_gated),
+            paddle.zeros_like(z_avg_gated), z_avg_gated)
+        z_avg_gated[z_avg_gated == float("inf")] = float(
+            np.finfo(np.float32).max)
+        z_avg_gated[z_avg_gated == -float("inf")] = float(
+            np.finfo(np.float32).min)
+
+        LUFS = -0.691 + 10.0 * paddle.log10(
+            (G[None, :nch] * z_avg_gated).sum(1))
+        return LUFS.astype("float32")
+
+    @property
+    def filter_class(self):
+        return self._filter_class
+
+    @filter_class.setter
+    def filter_class(self, value):
+        from pyloudnorm import Meter
+
+        meter = Meter(self.rate)
+        meter.filter_class = value
+        self._filter_class = value
+        self._filters = meter._filters
+
+
+class LoudnessMixin:
+    _loudness = None
+    MIN_LOUDNESS = -70
+    """Minimum loudness possible."""
+
+    def loudness(self,
+                 filter_class: str="K-weighting",
+                 block_size: float=0.400,
+                 **kwargs):
+        """Calculates loudness using an implementation of ITU-R BS.1770-4.
+        Allows control over the gating block size and the frequency
+        weighting filters. Measures the integrated gated loudness of a signal.
+
+        The API is derived from PyLoudnorm, but this implementation is ported to
+        PaddlePaddle and is tensorized across batches. When on GPU, an FIR
+        approximation of the IIR filters is used to compute loudness for speed.
+
+        Uses the weighting filters and block size defined by the meter;
+        the integrated loudness is measured based upon the gating algorithm
+        defined in the ITU-R BS.1770-4 specification.
+
+        Parameters
+        ----------
+        filter_class : str, optional
+            Class of weighting filter used:
+            'K-weighting' (default), 'Fenton/Lee 1',
+            'Fenton/Lee 2', 'Dash et al.',
+            by default "K-weighting"
+        block_size : float, optional
+            Gating block size in seconds, by default 0.400
+        kwargs : dict, optional
+            Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
+
+        Returns
+        -------
+        paddle.Tensor
+            Loudness of audio data.
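+
+        Examples
+        --------
+        A minimal sketch (the file name below is hypothetical)::
+
+            >>> signal = AudioSignal("speech.wav")
+            >>> signal.loudness()  # integrated loudness in LUFS, one value per batch item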
+ """ + if self._loudness is not None: + return self._loudness # .to(self.device) + original_length = self.signal_length + if self.signal_duration < 0.5: + pad_len = int((0.5 - self.signal_duration) * self.sample_rate) + self.zero_pad(0, pad_len) + + # create BS.1770 meter + meter = Meter( + self.sample_rate, + filter_class=filter_class, + block_size=block_size, + **kwargs) + # meter = meter.to(self.device) + # measure loudness + loudness = meter.integrated_loudness( + self.audio_data.transpose([0, 2, 1])) + self.truncate_samples(original_length) + min_loudness = paddle.ones_like(loudness) * self.MIN_LOUDNESS + self._loudness = paddle.maximum(loudness, min_loudness) + + return self._loudness # .to(self.device) diff --git a/audio/audiotools/core/resample.py b/audio/audiotools/core/resample.py index 2e0268734..46ee5b222 100644 --- a/audio/audiotools/core/resample.py +++ b/audio/audiotools/core/resample.py @@ -168,15 +168,17 @@ class ResampleFrac(paddle.nn.Layer): if self.old_sr == self.new_sr: return x shape = x.shape + _dtype = x.dtype length = x.shape[-1] x = x.reshape([-1, length]) x = F.pad( x.unsqueeze(1), [self._width, self._width + self.old_sr], mode="replicate", - data_format="NCL", ) + data_format="NCL", ).astype(self.kernel.dtype) ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL") - y = ys.transpose([0, 2, 1]).reshape(list(shape[:-1]) + [-1]) + y = ys.transpose( + [0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype) float_output_length = paddle.to_tensor( self.new_sr * length / self.old_sr, dtype="float32") diff --git a/audio/audiotools/core/util.py b/audio/audiotools/core/util.py index abba8ee67..1ea2e0956 100644 --- a/audio/audiotools/core/util.py +++ b/audio/audiotools/core/util.py @@ -1,3 +1,4 @@ +import collections import csv import glob import math @@ -8,19 +9,27 @@ import typing from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path +from typing import Any +from typing import Callable from typing import Dict +from typing import Iterable from typing import List -from typing import Optional, Union, Type, Any, Callable, Tuple, NamedTuple, Iterable -import collections +from typing import NamedTuple +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + import librosa import numpy as np import paddle import soundfile -from audio_signal import AudioSignal from flatten_dict import flatten from flatten_dict import unflatten -from ..data.preprocess import create_csv +from .audio_signal import AudioSignal + +# from ..data.preprocess import create_csv @dataclass @@ -49,10 +58,9 @@ def info(audio_path: str): def ensure_tensor( - x: typing.Union[np.ndarray, paddle.Tensor, float, int], - ndim: int = None, - batch_size: int = None, -): + x: typing.Union[np.ndarray, paddle.Tensor, float, int], + ndim: int=None, + batch_size: int=None, ): """✅Ensures that the input ``x`` is a tensor of specified dimensions and batch size. @@ -86,8 +94,7 @@ def ensure_tensor( def _get_value(other): # ✅ - # from . import AudioSignal - from audio_signal import AudioSignal + from . 
import AudioSignal if isinstance(other, AudioSignal): return other.audio_data @@ -123,10 +130,11 @@ def random_state(seed: typing.Union[int, np.random.RandomState]): elif isinstance(seed, np.random.RandomState): return seed else: - raise ValueError("%r cannot be used to seed a numpy.random.RandomState" " instance" % seed) + raise ValueError("%r cannot be used to seed a numpy.random.RandomState" + " instance" % seed) -def seed(random_seed): +def seed(random_seed, **kwargs): """✅ Seeds all random states with the same random seed for reproducibility. Seeds ``numpy``, ``random`` and ``paddle`` @@ -176,7 +184,7 @@ def _close_temp_files(tmpfiles: list): AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] -def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): +def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS): """✅Finds all audio files in a directory recursively. Returns a list. @@ -206,11 +214,10 @@ def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): def read_sources( - sources: List[str], - remove_empty: bool = True, - relative_path: str = "", - ext: List[str] = AUDIO_EXTENSIONS, -): + sources: List[str], + remove_empty: bool=True, + relative_path: str="", + ext: List[str]=AUDIO_EXTENSIONS, ): """✅Reads audio sources that can either be folders full of audio files, or CSV files that contain paths to audio files. CSV files that adhere to the expected @@ -253,7 +260,9 @@ def read_sources( return files -def choose_from_list_of_lists(state: np.random.RandomState, list_of_lists: list, p: float = None): +def choose_from_list_of_lists(state: np.random.RandomState, + list_of_lists: list, + p: float=None): """✅Choose a single item from a list of lists. Parameters @@ -295,7 +304,8 @@ def chdir(newdir: typing.Union[Path, str]): os.chdir(curdir) -def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], device: str = "cpu"): +def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], + device: str="cpu"): """✅Moves items in a batch (typically generated by a DataLoader as a list or a dict) to the specified device. This works even if dictionaries are nested. @@ -333,7 +343,7 @@ def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], device: str = return batch -def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None): +def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None): """✅Samples from a distribution defined by a tuple. The first item in the tuple is the distribution type, and the rest of the items are arguments to that distribution. The distribution function @@ -381,13 +391,12 @@ DEFAULT_FIG_SIZE = (9, 3) def format_figure( - fig_size: tuple = None, - title: str = None, - fig=None, - format_axes: bool = True, - format: bool = True, - font_color: str = "white", -): + fig_size: tuple=None, + title: str=None, + fig=None, + format_axes: bool=True, + format: bool=True, + font_color: str="white", ): """✅Prettifies the spectrogram and waveform plots. 
A title can be inset into the top right corner, and the axes can be inset into the figure, allowing the data to take up the entire @@ -447,8 +456,7 @@ def format_figure( va="top", color=font_color, fontsize=12 * font_scale, - alpha=0.75, - ) + alpha=0.75, ) ticks = ax.get_xticks()[2:] for t in ticks[:-1]: @@ -462,15 +470,15 @@ def format_figure( va="bottom", color=font_color, fontsize=12 * font_scale, - alpha=0.75, - ) + alpha=0.75, ) ax.margins(0, 0) ax.set_axis_off() ax.xaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator()) - plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.subplots_adjust( + top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) if title is not None: t = axs[0].annotate( @@ -482,17 +490,61 @@ def format_figure( textcoords="offset points", ha="right", va="top", - color="white", - ) + color="white", ) t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) _default_collate_err_msg_format = ( - "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts or lists; found {}" -) + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}") + + +def collate_tensor_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[type, Tuple[type, ...]], + Callable]]=None, ): + out = paddle.stack(batch, axis=0) + return out + + +def collate_float_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None, ): + return paddle.to_tensor(batch, dtype=paddle.float64) + +def collate_int_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None, ): + return paddle.to_tensor(batch) -def default_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): + +def collate_str_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None, ): + return batch + + +default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = { + paddle.Tensor: collate_tensor_fn +} +default_collate_fn_map[float] = collate_float_fn +default_collate_fn_map[int] = collate_int_fn +default_collate_fn_map[str] = collate_str_fn +default_collate_fn_map[bytes] = collate_str_fn + + +def default_collate(batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None): r""" General collate function that handles collection type of element within each batch. @@ -514,43 +566,63 @@ def default_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Ty if collate_fn_map is not None: if elem_type in collate_fn_map: - return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) + return collate_fn_map[elem_type]( + batch, collate_fn_map=collate_fn_map) for collate_type in collate_fn_map: if isinstance(elem, collate_type): - return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map) + return collate_fn_map[collate_type]( + batch, collate_fn_map=collate_fn_map) if isinstance(elem, collections.abc.Mapping): try: - return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}) + return elem_type({ + key: default_collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + }) except TypeError: # The mapping type may not support `__init__(iterable)`. 
- return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem} + return { + key: default_collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + } elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple - return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch))) + return elem_type(*(default_collate( + samples, collate_fn_map=collate_fn_map) for samples in zip(*batch))) elif isinstance(elem, collections.abc.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): - raise RuntimeError("each element in list of batch should be of equal size") - transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. + raise RuntimeError( + "each element in list of batch should be of equal size") + transposed = list(zip( + *batch)) # It may be accessed twice, so we use a list. if isinstance(elem, tuple): return [ - collate(samples, collate_fn_map=collate_fn_map) for samples in transposed + default_collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed ] # Backwards compatibility. else: try: - return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]) + return elem_type([ + default_collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ]) except TypeError: # The sequence type may not support `__init__(iterable)` (e.g., `range`). - return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] + return [ + default_collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] raise TypeError(_default_collate_err_msg_format.format(elem_type)) -def collate(list_of_dicts: list, n_splits: int = None): +def collate(list_of_dicts: list, n_splits: int=None): """Collates a list of dictionaries (e.g. as returned by a dataloader) into a dictionary with batched values. This routine uses the default torch collate function for everything @@ -585,8 +657,11 @@ def collate(list_of_dicts: list, n_splits: int = None): for i in range(0, list_len, n_items): # Flatten the dictionaries to avoid recursion. - list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]] - dict_of_lists = {k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]} + list_of_dicts_ = [flatten(d) for d in list_of_dicts[i:i + n_items]] + dict_of_lists = { + k: [dic[k] for dic in list_of_dicts_] + for k in list_of_dicts_[0] + } batch = {} for k, v in dict_of_lists.items(): @@ -594,9 +669,117 @@ def collate(list_of_dicts: list, n_splits: int = None): if all(isinstance(s, AudioSignal) for s in v): batch[k] = AudioSignal.batch(v, pad_signals=True) else: - # Borrow the default collate fn from torch. - batch[k] = default_collate(v) + batch[k] = default_collate( + v, collate_fn_map=default_collate_fn_map) batches.append(unflatten(batch)) batches = batches[0] if not return_list else batches return batches + + +def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int): + """Closest frequency bin given a frequency, number + of bins, and a sampling rate. + + Parameters + ---------- + hz : paddle.Tensor + Tensor of frequencies in Hz. + n_fft : int + Number of FFT bins. + sample_rate : int + Sample rate of audio. + + Returns + ------- + paddle.Tensor + Closest bins to the data. 
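+
+    Examples
+    --------
+    A quick sanity check: with ``n_fft=512`` and ``sample_rate=16000``,
+    1 kHz falls closest to bin 32::
+
+        >>> hz_to_bin(paddle.to_tensor([1000.0]), n_fft=512, sample_rate=16000)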
+ """ + shape = hz.shape + hz = hz.reshape([-1]) + freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2) + hz = paddle.clip(hz, max=sample_rate / 2) + + closest = (hz[None, :] - freqs[:, None]).abs() + closest_bins = closest.argmin(axis=0) + + return closest_bins.reshape(shape) + + +def generate_chord_dataset( + max_voices: int=8, + sample_rate: int=44100, + num_items: int=5, + duration: float=1.0, + min_note: str="C2", + max_note: str="C6", + output_dir: Path="chords", ): + """ + Generates a toy multitrack dataset of chords, synthesized from sine waves. + + + Parameters + ---------- + max_voices : int, optional + Maximum number of voices in a chord, by default 8 + sample_rate : int, optional + Sample rate of audio, by default 44100 + num_items : int, optional + Number of items to generate, by default 5 + duration : float, optional + Duration of each item, by default 1.0 + min_note : str, optional + Minimum note in the dataset, by default "C2" + max_note : str, optional + Maximum note in the dataset, by default "C6" + output_dir : Path, optional + Directory to save the dataset, by default "chords" + + """ + import librosa + from . import AudioSignal + from ..data.preprocess import create_csv + + min_midi = librosa.note_to_midi(min_note) + max_midi = librosa.note_to_midi(max_note) + + tracks = [] + for idx in range(num_items): + track = {} + # figure out how many voices to put in this track + num_voices = random.randint(1, max_voices) + for voice_idx in range(num_voices): + # choose some random params + midinote = random.randint(min_midi, max_midi) + dur = random.uniform(0.85 * duration, duration) + + sig = AudioSignal.wave( + frequency=librosa.midi_to_hz(midinote), + duration=dur, + sample_rate=sample_rate, + shape="sine", ) + track[f"voice_{voice_idx}"] = sig + tracks.append(track) + + # save the tracks to disk + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) + for idx, track in enumerate(tracks): + track_dir = output_dir / f"track_{idx}" + track_dir.mkdir(exist_ok=True) + for voice_name, sig in track.items(): + sig.write(track_dir / f"{voice_name}.wav") + + all_voices = list(set([k for track in tracks for k in track.keys()])) + voice_lists = {voice: [] for voice in all_voices} + for track in tracks: + for voice_name in all_voices: + if voice_name in track: + voice_lists[voice_name].append(track[voice_name].path_to_file) + else: + voice_lists[voice_name].append("") + + for voice_name, paths in voice_lists.items(): + create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) + + return output_dir diff --git a/audio/audiotools/data/datasets.py b/audio/audiotools/data/datasets.py index 276e00d54..60697bf74 100644 --- a/audio/audiotools/data/datasets.py +++ b/audio/audiotools/data/datasets.py @@ -5,10 +5,12 @@ from typing import List from typing import Union import numpy as np -from audio_signal import AudioSignal -import util import paddle -from paddle.io import SequenceSampler, DistributedBatchSampler +from paddle.io import DistributedBatchSampler +from paddle.io import SequenceSampler + +from ..core import AudioSignal +from ..core import util class AudioLoader: @@ -41,20 +43,20 @@ class AudioLoader: """ def __init__( - self, - sources: List[str] = None, - weights: List[float] = None, - transform: Callable = None, - relative_path: str = "", - ext: List[str] = util.AUDIO_EXTENSIONS, - shuffle: bool = True, - shuffle_state: int = 0, - ): - self.audio_lists = util.read_sources(sources, relative_path=relative_path, ext=ext) - - self.audio_indices = [ - (src_idx, 
item_idx) for src_idx, src in enumerate(self.audio_lists) for item_idx in range(len(src)) - ] + self, + sources: List[str]=None, + weights: List[float]=None, + transform: Callable=None, + relative_path: str="", + ext: List[str]=util.AUDIO_EXTENSIONS, + shuffle: bool=True, + shuffle_state: int=0, ): + self.audio_lists = util.read_sources( + sources, relative_path=relative_path, ext=ext) + + self.audio_indices = [(src_idx, item_idx) + for src_idx, src in enumerate(self.audio_lists) + for item_idx in range(len(src))] if shuffle: state = util.random_state(shuffle_state) state.shuffle(self.audio_indices) @@ -64,27 +66,28 @@ class AudioLoader: self.transform = transform def __call__( - self, - state, - sample_rate: int, - duration: float, - loudness_cutoff: float = -40, - num_channels: int = 1, - offset: float = None, - source_idx: int = None, - item_idx: int = None, - global_idx: int = None, - ): + self, + state, + sample_rate: int, + duration: float, + loudness_cutoff: float=-40, + num_channels: int=1, + offset: float=None, + source_idx: int=None, + item_idx: int=None, + global_idx: int=None, ): if source_idx is not None and item_idx is not None: try: audio_info = self.audio_lists[source_idx][item_idx] except: audio_info = {"path": "none"} elif global_idx is not None: - source_idx, item_idx = self.audio_indices[global_idx % len(self.audio_indices)] + source_idx, item_idx = self.audio_indices[global_idx % + len(self.audio_indices)] audio_info = self.audio_lists[source_idx][item_idx] else: - audio_info, source_idx, item_idx = util.choose_from_list_of_lists(state, self.audio_lists, p=self.weights) + audio_info, source_idx, item_idx = util.choose_from_list_of_lists( + state, self.audio_lists, p=self.weights) path = audio_info["path"] signal = AudioSignal.zeros(duration, sample_rate, num_channels) @@ -95,14 +98,12 @@ class AudioLoader: path, duration=duration, state=state, - loudness_cutoff=loudness_cutoff, - ) + loudness_cutoff=loudness_cutoff, ) else: signal = AudioSignal( path, offset=offset, - duration=duration, - ) + duration=duration, ) if num_channels == 1: signal = signal.to_mono() @@ -122,7 +123,8 @@ class AudioLoader: "path": str(path), } if self.transform is not None: - item["transform_args"] = self.transform.instantiate(state, signal=signal) + item["transform_args"] = self.transform.instantiate( + state, signal=signal) return item @@ -130,7 +132,7 @@ def default_matcher(x, y): return Path(x).parent == Path(y).parent -def align_lists(lists, matcher: Callable = default_matcher): +def align_lists(lists, matcher: Callable=default_matcher): longest_list = lists[np.argmax([len(l) for l in lists])] for i, x in enumerate(longest_list): for l in lists: @@ -347,20 +349,20 @@ class AudioDataset: """ def __init__( - self, - loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]], - sample_rate: int, - n_examples: int = 1000, - duration: float = 0.5, - offset: float = None, - loudness_cutoff: float = -40, - num_channels: int = 1, - transform: Callable = None, - aligned: bool = False, - shuffle_loaders: bool = False, - matcher: Callable = default_matcher, - without_replacement: bool = True, - ): + self, + loaders: Union[AudioLoader, List[AudioLoader], Dict[str, + AudioLoader]], + sample_rate: int, + n_examples: int=1000, + duration: float=0.5, + offset: float=None, + loudness_cutoff: float=-40, + num_channels: int=1, + transform: Callable=None, + aligned: bool=False, + shuffle_loaders: bool=False, + matcher: Callable=default_matcher, + without_replacement: bool=True, ): # Internally we 
convert loaders to a dictionary if isinstance(loaders, list): loaders = {i: l for i, l in enumerate(loaders)} @@ -415,13 +417,11 @@ class AudioDataset: # Path mapper takes the current loader + everything # returned by the first loader. offset = item[keys[0]]["signal"].metadata["offset"] - loader_kwargs.update( - { - "offset": offset, - "source_idx": item[keys[0]]["source_idx"], - "item_idx": item[keys[0]]["item_idx"], - } - ) + loader_kwargs.update({ + "offset": offset, + "source_idx": item[keys[0]]["source_idx"], + "item_idx": item[keys[0]]["item_idx"], + }) item[key] = loader(**loader_kwargs) # Sort dictionary back into original order @@ -430,7 +430,8 @@ class AudioDataset: item["idx"] = idx if self.transform is not None: - item["transform_args"] = self.transform.instantiate(state=state, signal=item[keys[0]]["signal"]) + item["transform_args"] = self.transform.instantiate( + state=state, signal=item[keys[0]]["signal"]) # If there's only one loader, pop it up # to the main dictionary, instead of keeping it @@ -444,7 +445,7 @@ class AudioDataset: return self.length @staticmethod - def collate(list_of_dicts: Union[list, dict], n_splits: int = None): + def collate(list_of_dicts: Union[list, dict], n_splits: int=None): """Collates items drawn from this dataset. Uses :py:func:`audiotools.core.util.collate`. @@ -495,24 +496,29 @@ class ConcatDataset(AudioDataset): class ResumableDistributedSampler(DistributedBatchSampler): # pragma: no cover """Distributed sampler that can be resumed from a given start index.""" - def __init__( - self, dataset, batch_size, start_idx: int = None, num_replicas=None, rank=None, shuffle=False, drop_last=False - ): + def __init__(self, + dataset, + batch_size, + start_idx: int=None, + num_replicas=None, + rank=None, + shuffle=False, + drop_last=False): super().__init__( dataset=dataset, batch_size=batch_size, num_replicas=num_replicas, rank=rank, shuffle=shuffle, - drop_last=drop_last, - ) + drop_last=drop_last, ) # Start index, allows to resume an experiment at the index it was if start_idx is not None: self.start_idx = start_idx // self.num_replicas else: self.start_idx = 0 # 重新计算样本总数,因为 DistributedBatchSampler 的 __len__ 方法是基于 shuffle 后的样本总数计算的 - self.total_size = len(self.dataset) if not shuffle else len(self.indices) + self.total_size = len(self.dataset) if not shuffle else len( + self.indices) def __iter__(self): # 由于 Paddle 的 DistributedBatchSampler 直接返回 batch,我们需要将其展开为单个索引 @@ -536,7 +542,7 @@ class ResumableDistributedSampler(DistributedBatchSampler): # pragma: no cover class ResumableSequentialSampler(SequenceSampler): # pragma: no cover """Sequential sampler that can be resumed from a given start index.""" - def __init__(self, dataset, start_idx: int = None, **kwargs): + def __init__(self, dataset, start_idx: int=None, **kwargs): super().__init__(dataset, **kwargs) # Start index, allows to resume an experiment at the index it was self.start_idx = start_idx if start_idx is not None else 0 diff --git a/audio/audiotools/data/transforms.py b/audio/audiotools/data/transforms.py index fffc66e89..c6976fd64 100644 --- a/audio/audiotools/data/transforms.py +++ b/audio/audiotools/data/transforms.py @@ -5,18 +5,15 @@ from typing import List import numpy as np import paddle -# import torch from flatten_dict import flatten from flatten_dict import unflatten from numpy.random import RandomState -# from .. import ml +from .. 
import ml from ..core import AudioSignal from ..core import util from .datasets import AudioLoader -tt = paddle.to_tensor - class BaseTransform: """✅This is the base class for all transforms that are implemented @@ -79,7 +76,7 @@ class BaseTransform: """ - def __init__(self, keys: list = [], name: str = None, prob: float = 1.0): + def __init__(self, keys: list=[], name: str=None, prob: float=1.0): # Get keys from the _transform signature. tfm_keys = list(signature(self._transform).parameters.keys()) @@ -108,18 +105,18 @@ class BaseTransform: def _transform(self, signal): return signal - def _instantiate(self, state: RandomState, signal: AudioSignal = None): + def _instantiate(self, state: RandomState, signal: AudioSignal=None): return {} @staticmethod - def apply_mask(batch: dict, mask: torch.Tensor): + def apply_mask(batch: dict, mask: paddle.Tensor): """Applies a mask to the batch. Parameters ---------- batch : dict Batch whose values will be masked in the ``transform`` pass. - mask : torch.Tensor + mask : paddle.Tensor Mask to apply to batch. Returns @@ -127,7 +124,16 @@ class BaseTransform: dict A dictionary that contains values only where ``mask = True``. """ - masked_batch = {k: v[mask] for k, v in flatten(batch).items()} + # masked_batch = {k: v[mask] for k, v in flatten(batch).items()} + masked_batch = {} + for k, v in flatten(batch).items(): + if 0 == mask.dim() and 0 == v.dim(): + if mask: # 0d 的 True + masked_batch[k] = v[None] + else: + masked_batch[k] = paddle.to_tensor([], dtype=v.dtype) + else: + masked_batch[k] = v[mask] return unflatten(masked_batch) def transform(self, signal: AudioSignal, **kwargs): @@ -158,7 +164,7 @@ class BaseTransform: tfm_kwargs = self._prepare(kwargs) mask = tfm_kwargs["mask"] - if torch.any(mask): + if paddle.any(mask): tfm_kwargs = self.apply_mask(tfm_kwargs, mask) tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"} signal[mask] = self._transform(signal[mask], **tfm_kwargs) @@ -169,10 +175,9 @@ class BaseTransform: return self.transform(*args, **kwargs) def instantiate( - self, - state: RandomState = None, - signal: AudioSignal = None, - ): + self, + state: RandomState=None, + signal: AudioSignal=None, ): """Instantiates parameters for the transform. Parameters @@ -202,7 +207,8 @@ class BaseTransform: # is needed before passing it in, so that the end-user # doesn't need to have variables they're not using flowing # into their function. - needs_signal = "signal" in set(signature(self._instantiate).parameters.keys()) + needs_signal = "signal" in set( + signature(self._instantiate).parameters.keys()) kwargs = {} if needs_signal: kwargs = {"signal": signal} @@ -211,12 +217,12 @@ class BaseTransform: params = self._instantiate(state, **kwargs) for k in list(params.keys()): v = params[k] - if isinstance(v, (AudioSignal, torch.Tensor, dict)): + if isinstance(v, (AudioSignal, paddle.Tensor, dict)): params[k] = v else: - params[k] = tt(v) + params[k] = paddle.to_tensor(v) mask = state.rand() <= self.prob - params[f"mask"] = tt(mask) + params[f"mask"] = paddle.to_tensor(mask) # Put the params into a nested dictionary that will be # used later when calling the transform. This is to avoid @@ -226,10 +232,9 @@ class BaseTransform: return params def batch_instantiate( - self, - states: list = None, - signal: AudioSignal = None, - ): + self, + states: list=None, + signal: AudioSignal=None, ): """Instantiates arguments for every item in a batch, given a list of states. Each state in the list corresponds to one item in the batch. 
@@ -343,7 +348,7 @@ class Compose(BaseTransform): Probability of applying this transform, by default 1.0 """ - def __init__(self, *transforms: list, name: str = None, prob: float = 1.0): + def __init__(self, *transforms: list, name: str=None, prob: float=1.0): if isinstance(transforms[0], list): transforms = transforms[0] @@ -407,7 +412,7 @@ class Compose(BaseTransform): signal = transform(signal, **kwargs) return signal - def _instantiate(self, state: RandomState, signal: AudioSignal = None): + def _instantiate(self, state: RandomState, signal: AudioSignal=None): parameters = {} for transform in self.transforms: parameters.update(transform.instantiate(state, signal=signal)) @@ -448,12 +453,11 @@ class Choose(Compose): """ def __init__( - self, - *transforms: list, - weights: list = None, - name: str = None, - prob: float = 1.0, - ): + self, + *transforms: list, + weights: list=None, + name: str=None, + prob: float=1.0, ): super().__init__(*transforms, name=name, prob=prob) if weights is None: @@ -461,7 +465,7 @@ class Choose(Compose): weights = [1 / _len for _ in range(_len)] self.weights = np.array(weights) - def _instantiate(self, state: RandomState, signal: AudioSignal = None): + def _instantiate(self, state: RandomState, signal: AudioSignal=None): kwargs = super()._instantiate(state, signal) tfm_idx = list(range(len(self.transforms))) tfm_idx = state.choice(tfm_idx, p=self.weights) @@ -487,12 +491,11 @@ class Repeat(Compose): """ def __init__( - self, - transform, - n_repeat: int = 1, - name: str = None, - prob: float = 1.0, - ): + self, + transform, + n_repeat: int=1, + name: str=None, + prob: float=1.0, ): transforms = [copy.copy(transform) for _ in range(n_repeat)] super().__init__(transforms, name=name, prob=prob) @@ -513,13 +516,12 @@ class RepeatUpTo(Choose): """ def __init__( - self, - transform, - max_repeat: int = 5, - weights: list = None, - name: str = None, - prob: float = 1.0, - ): + self, + transform, + max_repeat: int=5, + weights: list=None, + name: str=None, + prob: float=1.0, ): transforms = [] for n in range(1, max_repeat): transforms.append(Repeat(transform, n_repeat=n)) @@ -545,11 +547,10 @@ class ClippingDistortion(BaseTransform): """ def __init__( - self, - perc: tuple = ("uniform", 0.0, 0.1), - name: str = None, - prob: float = 1.0, - ): + self, + perc: tuple=("uniform", 0.0, 0.1), + name: str=None, + prob: float=1.0, ): super().__init__(name=name, prob=prob) self.perc = perc @@ -561,43 +562,42 @@ class ClippingDistortion(BaseTransform): return signal.clip_distortion(perc) -# class Equalizer(BaseTransform): -# """❌Applies an equalization curve to the audio signal. Corresponds -# to :py:func:`audiotools.core.effects.EffectMixin.equalizer`. +class Equalizer(BaseTransform): + """Applies an equalization curve to the audio signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.equalizer`. 
-# Parameters -# ---------- -# eq_amount : tuple, optional -# The maximum dB cut to apply to the audio in any band, -# by default ("const", 1.0 dB) -# n_bands : int, optional -# Number of bands in EQ, by default 6 -# name : str, optional -# Name of this transform, used to identify it in the dictionary -# produced by ``self.instantiate``, by default None -# prob : float, optional -# Probability of applying this transform, by default 1.0 -# """ + Parameters + ---------- + eq_amount : tuple, optional + The maximum dB cut to apply to the audio in any band, + by default ("const", 1.0 dB) + n_bands : int, optional + Number of bands in EQ, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ -# def __init__( -# self, -# eq_amount: tuple = ("const", 1.0), -# n_bands: int = 6, -# name: str = None, -# prob: float = 1.0, -# ): -# super().__init__(name=name, prob=prob) + def __init__( + self, + eq_amount: tuple=("const", 1.0), + n_bands: int=6, + name: str=None, + prob: float=1.0, ): + super().__init__(name=name, prob=prob) -# self.eq_amount = eq_amount -# self.n_bands = n_bands + self.eq_amount = eq_amount + self.n_bands = n_bands -# def _instantiate(self, state: RandomState): -# eq_amount = util.sample_from_dist(self.eq_amount, state) -# eq = -eq_amount * state.rand(self.n_bands) -# return {"eq": eq} + def _instantiate(self, state: RandomState): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + return {"eq": eq} -# def _transform(self, signal, eq): -# return signal.equalizer(eq) + def _transform(self, signal, eq): + return signal.equalizer(eq) # class Quantization(BaseTransform): @@ -632,7 +632,6 @@ class ClippingDistortion(BaseTransform): # def _transform(self, signal, channels): # return signal.quantization(channels) - # class MuLawQuantization(BaseTransform): # """Applies mu-law quantization to the input waveform. Corresponds # to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`. @@ -665,7 +664,6 @@ class ClippingDistortion(BaseTransform): # def _transform(self, signal, channels): # return signal.mulaw_quantization(channels) - # class NoiseFloor(BaseTransform): # """Adds a noise floor of Gaussian noise to the signal at a specified # dB. @@ -704,92 +702,90 @@ class ClippingDistortion(BaseTransform): # return signal + nz_signal -# class BackgroundNoise(BaseTransform): -# """Adds background noise from audio specified by a set of CSV files. -# A valid CSV file looks like, and is typically generated by -# :py:func:`audiotools.data.preprocess.create_csv`: +class BackgroundNoise(BaseTransform): + """Adds background noise from audio specified by a set of CSV files. + A valid CSV file looks like, and is typically generated by + :py:func:`audiotools.data.preprocess.create_csv`: -# .. csv-table:: -# :header: path - -# room_tone/m6_script2_clean.wav -# room_tone/m6_script2_cleanraw.wav -# room_tone/m6_script2_ipad_balcony1.wav -# room_tone/m6_script2_ipad_bedroom1.wav -# room_tone/m6_script2_ipad_confroom1.wav -# room_tone/m6_script2_ipad_confroom2.wav -# room_tone/m6_script2_ipad_livingroom1.wav -# room_tone/m6_script2_ipad_office1.wav - -# .. note:: -# All paths are relative to an environment variable called ``PATH_TO_DATA``, -# so that CSV files are portable across machines where data may be -# located in different places. + .. 
csv-table:: + :header: path -# This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` -# and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the -# hood. + room_tone/m6_script2_clean.wav + room_tone/m6_script2_cleanraw.wav + room_tone/m6_script2_ipad_balcony1.wav + room_tone/m6_script2_ipad_bedroom1.wav + room_tone/m6_script2_ipad_confroom1.wav + room_tone/m6_script2_ipad_confroom2.wav + room_tone/m6_script2_ipad_livingroom1.wav + room_tone/m6_script2_ipad_office1.wav -# Parameters -# ---------- -# snr : tuple, optional -# Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) -# sources : List[str], optional -# Sources containing folders, or CSVs with paths to audio files, -# by default None -# weights : List[float], optional -# Weights to sample audio files from each source, by default None -# eq_amount : tuple, optional -# Amount of equalization to apply, by default ("const", 1.0) -# n_bands : int, optional -# Number of bands in equalizer, by default 3 -# name : str, optional -# Name of this transform, used to identify it in the dictionary -# produced by ``self.instantiate``, by default None -# prob : float, optional -# Probability of applying this transform, by default 1.0 -# loudness_cutoff : float, optional -# Loudness cutoff when loading from audio files, by default None -# """ + .. note:: + All paths are relative to an environment variable called ``PATH_TO_DATA``, + so that CSV files are portable across machines where data may be + located in different places. -# def __init__( -# self, -# snr: tuple = ("uniform", 10.0, 30.0), -# sources: List[str] = None, -# weights: List[float] = None, -# eq_amount: tuple = ("const", 1.0), -# n_bands: int = 3, -# name: str = None, -# prob: float = 1.0, -# loudness_cutoff: float = None, -# ): -# super().__init__(name=name, prob=prob) + This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` + and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the + hood. 
-# self.snr = snr -# self.eq_amount = eq_amount -# self.n_bands = n_bands -# self.loader = AudioLoader(sources, weights) -# self.loudness_cutoff = loudness_cutoff + Parameters + ---------- + snr : tuple, optional + Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 3 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + loudness_cutoff : float, optional + Loudness cutoff when loading from audio files, by default None + """ -# def _instantiate(self, state: RandomState, signal: AudioSignal): -# eq_amount = util.sample_from_dist(self.eq_amount, state) -# eq = -eq_amount * state.rand(self.n_bands) -# snr = util.sample_from_dist(self.snr, state) + def __init__( + self, + snr: tuple=("uniform", 10.0, 30.0), + sources: List[str]=None, + weights: List[float]=None, + eq_amount: tuple=("const", 1.0), + n_bands: int=3, + name: str=None, + prob: float=1.0, + loudness_cutoff: float=None, ): + super().__init__(name=name, prob=prob) -# bg_signal = self.loader( -# state, -# signal.sample_rate, -# duration=signal.signal_duration, -# loudness_cutoff=self.loudness_cutoff, -# num_channels=signal.num_channels, -# )["signal"] + self.snr = snr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.loader = AudioLoader(sources, weights) + self.loudness_cutoff = loudness_cutoff -# return {"eq": eq, "bg_signal": bg_signal, "snr": snr} + def _instantiate(self, state: RandomState, signal: AudioSignal): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + snr = util.sample_from_dist(self.snr, state) -# def _transform(self, signal, bg_signal, snr, eq): -# # Clone bg_signal so that transform can be repeatedly applied -# # to different signals with the same effect. -# return signal.mix(bg_signal.clone(), snr, eq) + bg_signal = self.loader( + state, + signal.sample_rate, + duration=signal.signal_duration, + loudness_cutoff=self.loudness_cutoff, + num_channels=signal.num_channels, )["signal"] + + return {"eq": eq, "bg_signal": bg_signal, "snr": snr} + + def _transform(self, signal, bg_signal, snr, eq): + # Clone bg_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.mix(bg_signal.clone(), snr, eq) # class CrossTalk(BaseTransform): @@ -854,88 +850,88 @@ class ClippingDistortion(BaseTransform): # return mix -# class RoomImpulseResponse(BaseTransform): -# """Convolves signal with a room impulse response, at a specified -# direct-to-reverberant ratio, with equalization applied. Room impulse -# response data is drawn from a CSV file that was produced via -# :py:func:`audiotools.data.preprocess.create_csv`. +class RoomImpulseResponse(BaseTransform): + """Convolves signal with a room impulse response, at a specified + direct-to-reverberant ratio, with equalization applied. Room impulse + response data is drawn from a CSV file that was produced via + :py:func:`audiotools.data.preprocess.create_csv`. 
-# This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir` -# under the hood. + This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir` + under the hood. -# Parameters -# ---------- -# drr : tuple, optional -# _description_, by default ("uniform", 0.0, 30.0) -# sources : List[str], optional -# Sources containing folders, or CSVs with paths to audio files, -# by default None -# weights : List[float], optional -# Weights to sample audio files from each source, by default None -# eq_amount : tuple, optional -# Amount of equalization to apply, by default ("const", 1.0) -# n_bands : int, optional -# Number of bands in equalizer, by default 6 -# name : str, optional -# Name of this transform, used to identify it in the dictionary -# produced by ``self.instantiate``, by default None -# prob : float, optional -# Probability of applying this transform, by default 1.0 -# use_original_phase : bool, optional -# Whether or not to use the original phase, by default False -# offset : float, optional -# Offset from each impulse response file to use, by default 0.0 -# duration : float, optional -# Duration of each impulse response, by default 1.0 -# """ - -# def __init__( -# self, -# drr: tuple = ("uniform", 0.0, 30.0), -# sources: List[str] = None, -# weights: List[float] = None, -# eq_amount: tuple = ("const", 1.0), -# n_bands: int = 6, -# name: str = None, -# prob: float = 1.0, -# use_original_phase: bool = False, -# offset: float = 0.0, -# duration: float = 1.0, -# ): -# super().__init__(name=name, prob=prob) - -# self.drr = drr -# self.eq_amount = eq_amount -# self.n_bands = n_bands -# self.use_original_phase = use_original_phase - -# self.loader = AudioLoader(sources, weights) -# self.offset = offset -# self.duration = duration - -# def _instantiate(self, state: RandomState, signal: AudioSignal = None): -# eq_amount = util.sample_from_dist(self.eq_amount, state) -# eq = -eq_amount * state.rand(self.n_bands) -# drr = util.sample_from_dist(self.drr, state) - -# ir_signal = self.loader( -# state, -# signal.sample_rate, -# offset=self.offset, -# duration=self.duration, -# loudness_cutoff=None, -# num_channels=signal.num_channels, -# )["signal"] -# ir_signal.zero_pad_to(signal.sample_rate) + Parameters + ---------- + drr : tuple, optional + _description_, by default ("uniform", 0.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + use_original_phase : bool, optional + Whether or not to use the original phase, by default False + offset : float, optional + Offset from each impulse response file to use, by default 0.0 + duration : float, optional + Duration of each impulse response, by default 1.0 + """ -# return {"eq": eq, "ir_signal": ir_signal, "drr": drr} + def __init__( + self, + drr: tuple=("uniform", 0.0, 30.0), + sources: List[str]=None, + weights: List[float]=None, + eq_amount: tuple=("const", 1.0), + n_bands: int=6, + name: str=None, + prob: float=1.0, + use_original_phase: bool=False, + offset: float=0.0, + 
duration: float=1.0, ): + super().__init__(name=name, prob=prob) -# def _transform(self, signal, ir_signal, drr, eq): -# # Clone ir_signal so that transform can be repeatedly applied -# # to different signals with the same effect. -# return signal.apply_ir( -# ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase -# ) + self.drr = drr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.use_original_phase = use_original_phase + + self.loader = AudioLoader(sources, weights) + self.offset = offset + self.duration = duration + + def _instantiate(self, state: RandomState, signal: AudioSignal=None): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + drr = util.sample_from_dist(self.drr, state) + + ir_signal = self.loader( + state, + signal.sample_rate, + offset=self.offset, + duration=self.duration, + loudness_cutoff=None, + num_channels=signal.num_channels, )["signal"] + ir_signal.zero_pad_to(signal.sample_rate) + + return {"eq": eq, "ir_signal": ir_signal, "drr": drr} + + def _transform(self, signal, ir_signal, drr, eq): + # Clone ir_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.apply_ir( + ir_signal.clone(), + drr, + eq, + use_original_phase=self.use_original_phase) # class VolumeChange(BaseTransform): @@ -970,37 +966,36 @@ class ClippingDistortion(BaseTransform): # return signal.volume_change(db) -# class VolumeNorm(BaseTransform): -# """Normalizes the volume of the excerpt to a specified decibel. +class VolumeNorm(BaseTransform): + """Normalizes the volume of the excerpt to a specified decibel. -# Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`. + Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`. -# Parameters -# ---------- -# db : tuple, optional -# dB to normalize signal to, by default ("const", -24) -# name : str, optional -# Name of this transform, used to identify it in the dictionary -# produced by ``self.instantiate``, by default None -# prob : float, optional -# Probability of applying this transform, by default 1.0 -# """ + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ -# def __init__( -# self, -# db: tuple = ("const", -24), -# name: str = None, -# prob: float = 1.0, -# ): -# super().__init__(name=name, prob=prob) + def __init__( + self, + db: tuple=("const", -24), + name: str=None, + prob: float=1.0, ): + super().__init__(name=name, prob=prob) -# self.db = db + self.db = db -# def _instantiate(self, state: RandomState): -# return {"db": util.sample_from_dist(self.db, state)} + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} -# def _transform(self, signal, db): -# return signal.normalize(db) + def _transform(self, signal, db): + return signal.normalize(db) # class GlobalVolumeNorm(BaseTransform): @@ -1063,111 +1058,108 @@ class ClippingDistortion(BaseTransform): # return signal.volume_change(db) -# class Silence(BaseTransform): -# """Zeros out the signal with some probability. +class Silence(BaseTransform): + """Zeros out the signal with some probability. 
-#     Parameters
-#     ----------
-#     name : str, optional
-#         Name of this transform, used to identify it in the dictionary
-#         produced by ``self.instantiate``, by default None
-#     prob : float, optional
-#         Probability of applying this transform, by default 0.1
-#     """
+    Parameters
+    ----------
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 0.1
+    """
-#     def __init__(self, name: str = None, prob: float = 0.1):
-#         super().__init__(name=name, prob=prob)
+    def __init__(self, name: str=None, prob: float=0.1):
+        super().__init__(name=name, prob=prob)
-#     def _transform(self, signal):
-#         _loudness = signal._loudness
-#         signal = AudioSignal(
-#             torch.zeros_like(signal.audio_data),
-#             sample_rate=signal.sample_rate,
-#             stft_params=signal.stft_params,
-#         )
-#         # So that the amound of noise added is as if it wasn't silenced.
-#         # TODO: improve this hack
-#         signal._loudness = _loudness
+    def _transform(self, signal):
+        _loudness = signal._loudness
+        signal = AudioSignal(
+            paddle.zeros_like(signal.audio_data),
+            sample_rate=signal.sample_rate,
+            stft_params=signal.stft_params, )
+        # So that the amount of noise added is as if it wasn't silenced.
+        # TODO: improve this hack
+        signal._loudness = _loudness
-#         return signal
+        return signal
-# class LowPass(BaseTransform):
-#     """Applies a LowPass filter.
+class LowPass(BaseTransform):
+    """Applies a LowPass filter.
-#     Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
-#     Parameters
-#     ----------
-#     cutoff : tuple, optional
-#         Cutoff frequency distribution,
-#         by default ``("choice", [4000, 8000, 16000])``
-#     zeros : int, optional
-#         Number of zero-crossings in filter, argument to
-#         ``julius.LowPassFilters``, by default 51
-#     name : str, optional
-#         Name of this transform, used to identify it in the dictionary
-#         produced by ``self.instantiate``, by default None
-#     prob : float, optional
-#         Probability of applying this transform, by default 1.0
-#     """
+    Parameters
+    ----------
+    cutoff : tuple, optional
+        Cutoff frequency distribution,
+        by default ``("choice", [4000, 8000, 16000])``
+    zeros : int, optional
+        Number of zero-crossings in filter, argument to
+        ``julius.LowPassFilters``, by default 51
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
-#     def __init__(
-#         self,
-#         cutoff: tuple = ("choice", [4000, 8000, 16000]),
-#         zeros: int = 51,
-#         name: str = None,
-#         prob: float = 1,
-#     ):
-#         super().__init__(name=name, prob=prob)
+    def __init__(
+            self,
+            cutoff: tuple=("choice", [4000, 8000, 16000]),
+            zeros: int=51,
+            name: str=None,
+            prob: float=1, ):
+        super().__init__(name=name, prob=prob)
-#         self.cutoff = cutoff
-#         self.zeros = zeros
+        self.cutoff = cutoff
+        self.zeros = zeros
-#     def _instantiate(self, state: RandomState):
-#         return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+    def _instantiate(self, state: RandomState):
+        return {"cutoff": util.sample_from_dist(self.cutoff, state)}
-#     def _transform(self, signal, cutoff):
-#         return signal.low_pass(cutoff, zeros=self.zeros)
+    def _transform(self, signal, cutoff):
+        return signal.low_pass(cutoff, zeros=self.zeros)
-# class HighPass(BaseTransform):
-#     """Applies a HighPass filter.
+class HighPass(BaseTransform):
+    """Applies a HighPass filter.
-#     Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
-#     Parameters
-#     ----------
-#     cutoff : tuple, optional
-#         Cutoff frequency distribution,
-#         by default ``("choice", [50, 100, 250, 500, 1000])``
-#     zeros : int, optional
-#         Number of zero-crossings in filter, argument to
-#         ``julius.LowPassFilters``, by default 51
-#     name : str, optional
-#         Name of this transform, used to identify it in the dictionary
-#         produced by ``self.instantiate``, by default None
-#     prob : float, optional
-#         Probability of applying this transform, by default 1.0
-#     """
+    Parameters
+    ----------
+    cutoff : tuple, optional
+        Cutoff frequency distribution,
+        by default ``("choice", [50, 100, 250, 500, 1000])``
+    zeros : int, optional
+        Number of zero-crossings in filter, argument to
+        ``highpass_filters``, by default 51
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
-#     def __init__(
-#         self,
-#         cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
-#         zeros: int = 51,
-#         name: str = None,
-#         prob: float = 1,
-#     ):
-#         super().__init__(name=name, prob=prob)
+    def __init__(
+            self,
+            cutoff: tuple=("choice", [50, 100, 250, 500, 1000]),
+            zeros: int=51,
+            name: str=None,
+            prob: float=1, ):
+        super().__init__(name=name, prob=prob)
-#         self.cutoff = cutoff
-#         self.zeros = zeros
+        self.cutoff = cutoff
+        self.zeros = zeros
-#     def _instantiate(self, state: RandomState):
-#         return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+    def _instantiate(self, state: RandomState):
+        return {"cutoff": util.sample_from_dist(self.cutoff, state)}
-#     def _transform(self, signal, cutoff):
-#         return signal.high_pass(cutoff, zeros=self.zeros)
+    def _transform(self, signal, cutoff):
+        return signal.high_pass(cutoff, zeros=self.zeros)
 # class RescaleAudio(BaseTransform):
@@ -1196,7 +1188,6 @@ class ClippingDistortion(BaseTransform):
 #     def _transform(self, signal):
 #         return signal.ensure_max_of_audio(self.val)
-
 # class ShiftPhase(SpectralTransform):
 #     """Shifts the phase of the audio.
@@ -1228,7 +1219,6 @@ class ClippingDistortion(BaseTransform):
 #     def _transform(self, signal, shift):
 #         return signal.shift_phase(shift)
-
 # class InvertPhase(ShiftPhase):
 #     """Inverts the phase of the audio.
@@ -1246,7 +1236,6 @@ class ClippingDistortion(BaseTransform):
 #     def __init__(self, name: str = None, prob: float = 1):
 #         super().__init__(shift=("const", np.pi), name=name, prob=prob)
-
 # class CorruptPhase(SpectralTransform):
 #     """Corrupts the phase of the audio.
@@ -1277,7 +1266,6 @@ class ClippingDistortion(BaseTransform):
 #     def _transform(self, signal, corruption):
 #         return signal.shift_phase(shift=corruption)
-
 # class FrequencyMask(SpectralTransform):
 #     """Masks a band of frequencies at a center frequency
 #     from the audio.
@@ -1323,7 +1311,6 @@ class ClippingDistortion(BaseTransform):
 #     def _transform(self, signal, fmin_hz: float, fmax_hz: float):
 #         return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
-
 # class TimeMask(SpectralTransform):
 #     """Masks out contiguous time-steps from signal.
@@ -1368,7 +1355,6 @@ class ClippingDistortion(BaseTransform):
 #     def _transform(self, signal, tmin_s: float, tmax_s: float):
 #         return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
-
 # class MaskLowMagnitudes(SpectralTransform):
 #     """Masks low magnitude regions out of signal.
@@ -1401,7 +1387,6 @@ class ClippingDistortion(BaseTransform):
 #     def _transform(self, signal, db_cutoff: float):
 #         return signal.mask_low_magnitudes(db_cutoff)
-
 # class Smoothing(BaseTransform):
 #     """Convolves the signal with a smoothing window.
@@ -1452,7 +1437,6 @@ class ClippingDistortion(BaseTransform):
 #         out = out * (sscale / oscale)
 #         return out
-
 # class TimeNoise(TimeMask):
 #     """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
 #     replaces with noise instead of zeros.
@@ -1494,7 +1478,6 @@ class ClippingDistortion(BaseTransform):
 #         signal.phase = phase
 #         return signal
-
 # class FrequencyNoise(FrequencyMask):
 #     """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
 #     replaces with noise instead of zeros.
@@ -1535,7 +1518,6 @@ class ClippingDistortion(BaseTransform):
 #         signal.phase = phase
 #         return signal
-
 # class SpectralDenoising(Equalizer):
 #     """Applies denoising algorithm detailed in
 #     :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
diff --git a/audio/audiotools/metrics/quality.py b/audio/audiotools/metrics/quality.py
index eec3014bc..6390fd9a3 100644
--- a/audio/audiotools/metrics/quality.py
+++ b/audio/audiotools/metrics/quality.py
@@ -2,7 +2,8 @@ import os
 import numpy as np
 import paddle
-from audio_signal import AudioSignal
+
+from ..core import AudioSignal
 def visqol(
diff --git a/audio/audiotools/ml/accelerator.py b/audio/audiotools/ml/accelerator.py
index 00974d2cb..1b8636581 100644
--- a/audio/audiotools/ml/accelerator.py
+++ b/audio/audiotools/ml/accelerator.py
@@ -3,13 +3,15 @@ import typing
 import paddle
 import paddle.distributed as dist
-from paddle.io import DataLoader, DistributedBatchSampler, SequentialSampler
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.io import SequenceSampler
 class ResumableDistributedSampler(DistributedBatchSampler):  # pragma: no cover
     """Distributed sampler that can be resumed from a given start index."""
-    def __init__(self, dataset, start_idx: int = None, **kwargs):
+    def __init__(self, dataset, start_idx: int=None, **kwargs):
         super().__init__(dataset, **kwargs)
         # Start index, allows to resume an experiment at the index it was
         self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
@@ -21,10 +23,10 @@ class ResumableDistributedSampler(DistributedBatchSampler):  # pragma: no cover
         self.start_idx = 0  # set the index back to 0 so for the next epoch
-class ResumableSequentialSampler(SequentialSampler):  # pragma: no cover
+class ResumableSequentialSampler(SequenceSampler):  # pragma: no cover
     """Sequential sampler that can be resumed from a given start index."""
-    def __init__(self, dataset, start_idx: int = None, **kwargs):
+    def __init__(self, dataset, start_idx: int=None, **kwargs):
         super().__init__(dataset, **kwargs)
         # Start index, allows to resume an experiment at the index it was
         self.start_idx = start_idx if start_idx is not None else 0
@@ -57,7 +59,7 @@ class Accelerator:  # pragma: no cover
         (Note: This is a placeholder as PaddlePaddle doesn't have native support for AMP as of now)
     """
-    def __init__(self, amp: bool = False):
+    def __init__(self, amp: bool=False):
         trainer_id = os.getenv("PADDLE_TRAINER_ID", None)
         self.world_size = paddle.distributed.get_world_size()
@@ -139,13 +141,16 @@ class Accelerator:  # pragma: no cover
         optimizer : paddle.optimizer.Optimizer
             Optimizer to step forward.
         """
-        self.scaler.step(optimizer)
+        self.scaler.step(optimizer)
     def update(self):
         # https://www.paddlepaddle.org.cn/documentation/docs/zh/2.6/api/paddle/amp/GradScaler_cn.html#step-optimizer
-        self.scaler.update()
+        self.scaler.update()
-    def prepare_dataloader(self, dataset: typing.Iterable, start_idx: int = None, **kwargs):
+    def prepare_dataloader(self,
+                           dataset: typing.Iterable,
+                           start_idx: int=None,
+                           **kwargs):
         """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
         enabled.
@@ -171,10 +176,10 @@ class Accelerator:  # pragma: no cover
                 shuffle=kwargs.get("shuffle", True),
                 drop_last=kwargs.get("drop_last", False),
                 num_replicas=self.world_size,
-                rank=self.local_rank,
-            )
+                rank=self.local_rank, )
             if "num_workers" in kwargs:
-                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
+                kwargs["num_workers"] = max(kwargs["num_workers"] //
+                                            self.world_size, 1)
         else:
             sampler = ResumableSequentialSampler(dataset, start_idx)
@@ -182,8 +187,7 @@ class Accelerator:  # pragma: no cover
             dataset,
             batch_sampler=sampler if self.use_ddp else None,
             sampler=sampler if not self.use_ddp else None,
-            **kwargs,
-        )
+            **kwargs, )
         return dataloader
     @staticmethod
diff --git a/audio/audiotools/ml/basemodel.py b/audio/audiotools/ml/basemodel.py
index 633b50430..19f33a82b 100644
--- a/audio/audiotools/ml/basemodel.py
+++ b/audio/audiotools/ml/basemodel.py
@@ -106,7 +106,7 @@ class BaseModel(nn.Layer):
         if not package:
             state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
-            paddle.save(state_dict, path)
+            paddle.save(state_dict, str(path))
         else:
             self._save_package(path, intern=intern, extern=extern, mock=mock)
@@ -118,7 +118,7 @@ class BaseModel(nn.Layer):
             the first parameter. May not be valid
             if model is split across multiple devices.
         """
-        return list(self.parameters())[0].device
+        return list(self.parameters())[0].place
     @classmethod
     def load(
@@ -152,7 +152,7 @@ class BaseModel(nn.Layer):
         try:
             model = cls._load_package(location, package_name=package_name)
         except:
-            model_dict = paddle.load(location, "cpu")
+            model_dict = paddle.load(location)
             metadata = model_dict["metadata"]
             metadata["kwargs"].update(kwargs)
@@ -163,7 +163,7 @@ class BaseModel(nn.Layer):
                     metadata["kwargs"].pop(k)
             model = cls(*args, **metadata["kwargs"])
-            model.load_state_dict(model_dict["state_dict"], strict=strict)
+            model.set_state_dict(model_dict["state_dict"])
             model.metadata = metadata
         return model
@@ -220,7 +220,7 @@ class BaseModel(nn.Layer):
         self.save(weights_path, package=False)
         for path, obj in extra_data.items():
-            paddle.save(obj, target_base / path)
+            paddle.save(obj, str(target_base / path))
         return target_base
@@ -257,7 +257,7 @@ class BaseModel(nn.Layer):
         model_pth = "package.pth" if package else "weights.pth"
         model_pth = folder / model_pth
-        model = cls.load(model_pth, strict=strict)
+        model = cls.load(str(model_pth))
         extra_data = {}
         excluded = ["package.pth", "weights.pth"]
         files = [
             if x.is_file() and x.name not in excluded
         ]
         for f in files:
-            extra_data[f.name] = paddle.load(f, **kwargs)
+            extra_data[f.name] = paddle.load(str(f), **kwargs)
         return model, extra_data
diff --git a/audio/audiotools/requirements.txt b/audio/audiotools/requirements.txt
deleted file mode 100644
index 925b740fe..000000000
--- a/audio/audiotools/requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-soundfile
-librosa
-scipy
-rich
-flatten_dict
\ No newline at end of file