parent f2ca05e830
commit e243128c0b

@@ -1,10 +1,11 @@
 __version__ = "0.0.1"
 from .core import AudioSignal
 from .core import STFTParams
-# from .core import Meter
+from .core import Meter
 from .core import util
+from .core import highpass_filter, highpass_filters
 from . import metrics
 from . import data
 from . import ml
 from .data import datasets
 from .data import transforms
@@ -1,4 +1,15 @@
 from . import util
+from ._julius import fft_conv1d
+from ._julius import FFTConv1d
+from ._julius import highpass_filter
+from ._julius import highpass_filters
+from ._julius import lowpass_filter
+from ._julius import LowPassFilter
+from ._julius import LowPassFilters
+from ._julius import pure_tone
+from ._julius import split_bands
+from ._julius import SplitBands
 from .audio_signal import AudioSignal
 from .audio_signal import STFTParams
 from .loudness import Meter
+from .resample import resample_frac
@@ -0,0 +1,714 @@
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020
"""
Implementation of an FFT-based 1D convolution in PaddlePaddle.
While FFT is used in some cases for small kernel sizes, it is not the default for long ones, e.g. 512.
This module implements efficient FFT based convolutions for such cases. A typical
application is for evaluating FIR filters with a long receptive field, typically
evaluated with a stride of 1.
"""
import math
import typing
from typing import Optional
from typing import Sequence

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .resample import sinc


def pad_to(tensor: paddle.Tensor,
           target_length: int,
           mode: str="constant",
           value: float=0.0):
    """
    Pad the given tensor to the given length, with 0s on the right.
    """
    return F.pad(
        tensor, (0, target_length - tensor.shape[-1]),
        mode=mode,
        value=value,
        data_format="NCL")


def pure_tone(freq: float, sr: float=128, dur: float=4, device=None):
    """
    Return a pure tone, i.e. cosine.

    Args:
        freq (float): frequency (in Hz)
        sr (float): sample rate (in Hz)
        dur (float): duration (in seconds)
    """
    # `device` is kept for API compatibility with the original julius and is unused here.
    time = paddle.arange(int(sr * dur), dtype="float32") / sr
    return paddle.cos(2 * math.pi * freq * time)


def unfold(_input, kernel_size: int, stride: int):
    """1D only unfolding similar to the one from PyTorch.
    However PyTorch unfold is extremely slow.

    Given an _input tensor of size `[*, T]` this will return
    a tensor `[*, F, K]` with `K` the kernel size, and `F` the number
    of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`.
    This will automatically pad the _input to cover at least once all entries in `_input`.

    Args:
        _input (Tensor): tensor for which to return the frames.
        kernel_size (int): size of each frame.
        stride (int): stride between each frame.

    Shape:

        - Inputs: `_input` is `[*, T]`
        - Output: `[*, F, kernel_size]` with `F = 1 + ceil((T - kernel_size) / stride)`

    .. warning::
        Unlike PyTorch unfold, this will pad the _input
        so that any position in `_input` is covered by at least one frame.
    """
    shape = list(_input.shape)
    length = shape.pop(-1)
    n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
    tgt_length = (n_frames - 1) * stride + kernel_size
    padded = F.pad(_input, (0, tgt_length - length), data_format="NCL")
    strides: typing.List[int] = []
    for dim in range(padded.dim()):
        strides.append(padded.strides[dim])
    assert strides.pop(-1) == 1, "data should be contiguous"
    strides = strides + [stride, 1]
    return padded.as_strided(shape + [n_frames, kernel_size], strides)
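

# Illustrative sketch of `unfold` (values are hypothetical; frame i is the
# view `padded[..., i * stride : i * stride + kernel_size]`):
#
#     >>> x = paddle.arange(10, dtype="float32").reshape([1, 1, 10])
#     >>> frames = unfold(x, kernel_size=4, stride=2)
#     >>> list(frames.shape)  # F = 1 + ceil((10 - 4) / 2) = 4
#     [1, 1, 4, 4]
#     >>> frames[0, 0, 1].numpy().tolist()  # frame 1 starts at sample 2
#     [2.0, 3.0, 4.0, 5.0]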


def _new_rfft(x: paddle.Tensor):
    z = paddle.fft.rfft(x, axis=-1)

    z_real = paddle.real(z)
    z_imag = paddle.imag(z)

    z_view_as_real = paddle.stack([z_real, z_imag], axis=-1)
    return z_view_as_real


def _new_irfft(x: paddle.Tensor, length: int):
    x_real = x[..., 0]
    x_imag = x[..., 1]
    x_view_as_complex = paddle.complex(x_real, x_imag)
    return paddle.fft.irfft(x_view_as_complex, n=length, axis=-1)


def _compl_mul_conjugate(a: paddle.Tensor, b: paddle.Tensor):
    """
    Given a and b, two tensors whose last dimension holds the real and
    imaginary parts, returns a multiplied by the conjugate of b, the
    multiplication (contraction) being with respect to the second dimension.

    PaddlePaddle does not have direct support for complex number operations
    using einsum in the same manner as PyTorch, but we can manually compute
    the equivalent result.
    """
    # Extract the real and imaginary parts of a and b
    real_a = a[..., 0]
    imag_a = a[..., 1]
    real_b = b[..., 0]
    imag_b = b[..., 1]

    # Compute the multiplication with respect to the second dimension manually
    real_part = paddle.einsum("bcft,dct->bdft", real_a, real_b) + paddle.einsum(
        "bcft,dct->bdft", imag_a, imag_b)
    imag_part = paddle.einsum("bcft,dct->bdft", imag_a, real_b) - paddle.einsum(
        "bcft,dct->bdft", real_a, imag_b)

    # Stack the real and imaginary parts together
    result = paddle.stack([real_part, imag_part], axis=-1)
    return result
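

# Why the einsums above equal `a * conj(b)` summed over channels: each
# elementwise product in the contraction over `c` expands as
#     (ar + i*ai)(br - i*bi) = (ar*br + ai*bi) + i*(ai*br - ar*bi),
# which is exactly the real and imaginary combinations computed in
# `_compl_mul_conjugate`. For instance,
#     (1 + 2i) * conj(3 + 4i) = (3 + 8) + i*(6 - 4) = 11 + 2i.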


def fft_conv1d(
        _input: paddle.Tensor,
        weight: paddle.Tensor,
        bias: Optional[paddle.Tensor]=None,
        stride: int=1,
        padding: int=0,
        block_ratio: float=5):
    """
    Same as `paddle.nn.functional.conv1d` but using FFT for the convolution.
    Please check PaddlePaddle documentation for more information.

    Args:
        _input (Tensor): _input signal of shape `[B, C, T]`.
        weight (Tensor): weight of the convolution `[D, C, K]` with `D` the number
            of output channels.
        bias (Tensor or None): if not None, bias term for the convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the _input.
        block_ratio (float): can be tuned for speed. The _input is split in chunks
            with a size of `int(block_ratio * kernel_size)`.

    Shape:

        - Inputs: `_input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
        - Output: `[B, D, T']`.

    .. note::
        This function is faster than `paddle.nn.functional.conv1d` only in specific cases.
        Typically, the kernel size should be of the order of 256 to see any real gain,
        for a stride of 1.

    .. warning::
        Dilation and groups are not supported at the moment. This function might use
        more memory than the default Conv1d implementation.
    """
    _input = F.pad(_input, (padding, padding), data_format="NCL")
    batch, channels, length = _input.shape
    out_channels, _, kernel_size = weight.shape

    if length < kernel_size:
        raise RuntimeError(
            f"Input should be at least as large as the kernel size {kernel_size}, "
            f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be greater than or equal to 1.")

    block_size: int = min(int(kernel_size * block_ratio), length)
    fold_stride = block_size - kernel_size + 1
    weight = pad_to(weight, block_size)
    weight_z = _new_rfft(weight)

    # We pad the _input and get the different frames, on which
    # the convolution is evaluated block by block.
    frames = unfold(_input, block_size, fold_stride)

    frames_z = _new_rfft(frames)
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _new_irfft(out_z, block_size)
    # The last bit is invalid, because FFT will do a circular convolution.
    out = out[..., :-kernel_size + 1]
    out = out.reshape([batch, out_channels, -1])
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        out += bias[:, None]
    return out
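

# A sketch of the intended equivalence with the direct convolution (values
# are illustrative; the tolerance depends on float32 FFT round-off):
#
#     >>> x = paddle.randn([2, 3, 2048])   # [B, C, T]
#     >>> w = paddle.randn([8, 3, 512])    # [D, C, K], a long FIR kernel
#     >>> y_fft = fft_conv1d(x, w)
#     >>> list(y_fft.shape)                # T - K + 1 = 1537
#     [2, 8, 1537]
#     >>> bool(paddle.allclose(y_fft, F.conv1d(x, w), atol=1e-3))
#     True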


class FFTConv1d(paddle.nn.Layer):
    """
    Same as `paddle.nn.Conv1D` but based on a custom FFT-based convolution.
    Please check PaddlePaddle documentation for more information on `paddle.nn.Conv1D`.

    Args:
        in_channels (int): number of _input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size of convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the _input.
        bias (bool): if True, use a bias term.

    .. note::
        This module is faster than `paddle.nn.Conv1D` only in specific cases.
        Typically, `kernel_size` should be of the order of 256 to see any real gain,
        for a stride of 1.

    .. warning::
        Dilation and groups are not supported at the moment. This module might use
        more memory than the default Conv1D implementation.

    >>> fftconv = FFTConv1d(12, 24, 128, 4)
    >>> x = paddle.randn([4, 12, 1024])
    >>> print(list(fftconv(x).shape))
    [4, 24, 225]
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int=1,
            padding: int=0,
            bias: bool=True):
        super(FFTConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Create a Conv1D layer to initialize weights and bias
        conv = paddle.nn.Conv1D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias_attr=bias)
        self.weight = conv.weight
        if bias:
            self.bias = conv.bias
        else:
            self.bias = None

    def forward(self, _input: paddle.Tensor):
        return fft_conv1d(_input, self.weight, self.bias, self.stride,
                          self.padding)


class LowPassFilters(nn.Layer):
    """
    Bank of low pass filters.
    """

    def __init__(self,
                 cutoffs: Sequence[float],
                 stride: int=1,
                 pad: bool=True,
                 zeros: float=8,
                 fft: Optional[bool]=None,
                 dtype="float32"):
        super(LowPassFilters, self).__init__()
        self.cutoffs = list(cutoffs)
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be non-negative.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        self.half_size = int(zeros / min([c for c in self.cutoffs if c > 0]) / 2)
        if fft is None:
            fft = self.half_size > 32
        self.fft = fft

        # Create filters
        window = paddle.audio.functional.get_window(
            "hann", 2 * self.half_size + 1, fftbins=False, dtype=dtype)
        time = paddle.arange(
            -self.half_size, self.half_size + 1, dtype="float32")
        filters = []
        for cutoff in cutoffs:
            if cutoff == 0:
                filter_ = paddle.zeros_like(time)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize filter
                filter_ /= paddle.sum(filter_)
            filters.append(filter_)
        filters = paddle.stack(filters)[:, None]
        self.filters = self.create_parameter(
            shape=filters.shape,
            default_initializer=nn.initializer.Constant(value=0.0),
            dtype="float32",
            is_bias=False,
            attr=paddle.ParamAttr(trainable=False), )
        self.filters.set_value(filters)

    def forward(self, _input):
        shape = list(_input.shape)
        _input = _input.reshape([-1, 1, shape[-1]])
        if self.pad:
            _input = F.pad(
                _input, (self.half_size, self.half_size),
                mode="replicate",
                data_format="NCL")
        if self.fft:
            out = fft_conv1d(_input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(_input, self.filters, stride=self.stride)

        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.transpose([1, 0, 2]).reshape(shape)
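

# Usage sketch: cutoffs are normalized as `f / f_s`, and the filtered bands
# are stacked along a new leading axis, one per cutoff.
#
#     >>> lowpass = LowPassFilters([1 / 8, 1 / 4])
#     >>> x = paddle.randn([4, 12, 1024])
#     >>> list(lowpass(x).shape)
#     [2, 4, 12, 1024]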


class LowPassFilter(nn.Layer):
    """
    Same as `LowPassFilters` but applies a single low pass filter.
    """

    def __init__(self,
                 cutoff: float,
                 stride: int=1,
                 pad: bool=True,
                 zeros: float=8,
                 fft: Optional[bool]=None):
        super(LowPassFilter, self).__init__()
        self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft)

    @property
    def cutoff(self):
        return self._lowpasses.cutoffs[0]

    @property
    def stride(self):
        return self._lowpasses.stride

    @property
    def pad(self):
        return self._lowpasses.pad

    @property
    def zeros(self):
        return self._lowpasses.zeros

    @property
    def fft(self):
        return self._lowpasses.fft

    def forward(self, _input):
        return self._lowpasses(_input)[0]


def lowpass_filters(
        _input: paddle.Tensor,
        cutoffs: Sequence[float],
        stride: int=1,
        pad: bool=True,
        zeros: float=8,
        fft: Optional[bool]=None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.
    """
    return LowPassFilters(cutoffs, stride, pad, zeros, fft)(_input)


def lowpass_filter(_input: paddle.Tensor,
                   cutoff: float,
                   stride: int=1,
                   pad: bool=True,
                   zeros: float=8,
                   fft: Optional[bool]=None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    return lowpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]


class HighPassFilters(paddle.nn.Layer):
    """
    Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more
    details on the implementation.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        stride (int): how much to decimate the output. Probably not a good idea
            to do so with a high pass filter though...
        pad (bool): if True, appropriately pad the _input with zero over the edge. If `stride=1`,
            the output will have the same length as the _input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            This is likely appropriate for most use. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PaddlePaddle convolutions.
            If False, uses PaddlePaddle convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        frequency provided. If you combine a lot of filters with very diverse frequencies, it might
        be more efficient to split them over multiple modules with similar frequencies.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> highpass = HighPassFilters([1/4])
    >>> x = paddle.randn([4, 12, 21, 1024])
    >>> list(highpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self,
                 cutoffs: Sequence[float],
                 stride: int=1,
                 pad: bool=True,
                 zeros: float=8,
                 fft: Optional[bool]=None):
        super().__init__()
        self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft)

    @property
    def cutoffs(self):
        return self._lowpasses.cutoffs

    @property
    def stride(self):
        return self._lowpasses.stride

    @property
    def pad(self):
        return self._lowpasses.pad

    @property
    def zeros(self):
        return self._lowpasses.zeros

    @property
    def fft(self):
        return self._lowpasses.fft

    def forward(self, _input):
        lows = self._lowpasses(_input)

        # We need to extract the right portion of the _input in case
        # pad is False or stride > 1
        if self.pad:
            start, end = 0, _input.shape[-1]
        else:
            start = self._lowpasses.half_size
            end = -start
        _input = _input[..., start:end:self.stride]
        highs = _input - lows
        return highs


class HighPassFilter(paddle.nn.Layer):
    """
    Same as `HighPassFilters` but applies a single high pass filter.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> highpass = HighPassFilter(1/4, stride=1)
    >>> x = paddle.randn([4, 124])
    >>> list(highpass(x).shape)
    [4, 124]
    """

    def __init__(self,
                 cutoff: float,
                 stride: int=1,
                 pad: bool=True,
                 zeros: float=8,
                 fft: Optional[bool]=None):
        super().__init__()
        self._highpasses = HighPassFilters([cutoff], stride, pad, zeros, fft)

    @property
    def cutoff(self):
        return self._highpasses.cutoffs[0]

    @property
    def stride(self):
        return self._highpasses.stride

    @property
    def pad(self):
        return self._highpasses.pad

    @property
    def zeros(self):
        return self._highpasses.zeros

    @property
    def fft(self):
        return self._highpasses.fft

    def forward(self, _input):
        return self._highpasses(_input)[0]


def highpass_filters(
        _input: paddle.Tensor,
        cutoffs: Sequence[float],
        stride: int=1,
        pad: bool=True,
        zeros: float=8,
        fft: Optional[bool]=None):
    """
    Functional version of `HighPassFilters`, refer to this class for more information.
    """
    return HighPassFilters(cutoffs, stride, pad, zeros, fft)(_input)


def highpass_filter(_input: paddle.Tensor,
                    cutoff: float,
                    stride: int=1,
                    pad: bool=True,
                    zeros: float=8,
                    fft: Optional[bool]=None):
    """
    Functional version of `HighPassFilter`, refer to this class for more information.
    Output will not have a dimension inserted in the front.
    """
    return highpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]
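

# Because `HighPassFilters` computes `_input - lows`, a low pass and a high
# pass at the same cutoff are exact complements when `pad=True` and
# `stride=1` (illustrative check):
#
#     >>> x = paddle.randn([1, 4096])
#     >>> low = lowpass_filter(x, 0.1)
#     >>> high = highpass_filter(x, 0.1)
#     >>> bool(paddle.allclose(low + high, x, atol=1e-6))
#     True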


def hz_to_mel(freqs: paddle.Tensor):
    """
    Converts a Tensor of frequencies in hertz to the mel scale.
    Uses the simple formula by O'Shaughnessy (1987).

    Args:
        freqs (paddle.Tensor): frequencies to convert.

    """
    return 2595 * paddle.log10(1 + freqs / 700)


def mel_to_hz(mels: paddle.Tensor):
    """
    Converts a Tensor of mel scaled frequencies to Hertz.
    Uses the simple formula by O'Shaughnessy (1987).

    Args:
        mels (paddle.Tensor): mel frequencies to convert.
    """
    return 700 * (10**(mels / 2595) - 1)


def mel_frequencies(n_mels: int, fmin: float, fmax: float):
    """
    Return frequencies that are evenly spaced in mel scale.

    Args:
        n_mels (int): number of frequencies to return.
        fmin (float): start from this frequency (in Hz).
        fmax (float): finish at this frequency (in Hz).

    """
    low = hz_to_mel(paddle.to_tensor(float(fmin))).item()
    high = hz_to_mel(paddle.to_tensor(float(fmax))).item()
    mels = paddle.linspace(low, high, n_mels)
    return mel_to_hz(mels)
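

# Numeric check of the O'Shaughnessy formulas: 1000 Hz maps to roughly
# 1000 mel, and the two conversions invert each other.
#
#     >>> mel = hz_to_mel(paddle.to_tensor(1000.0))
#     >>> round(float(mel))  # 2595 * log10(1 + 1000 / 700)
#     1000
#     >>> round(float(mel_to_hz(mel)))
#     1000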


class SplitBands(paddle.nn.Layer):
    """
    Decomposes a signal over the given frequency bands in the waveform domain using
    a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`.
    You can either specify explicitly the frequency cutoffs, or just the number of bands,
    in which case the frequency cutoffs will be spread out evenly in mel scale.

    Args:
        sample_rate (float): Sample rate of the input signal in Hz.
        n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`.
            In that case, the cutoff frequencies will be evenly spaced in mel-space.
        cutoffs (list[float] or None): list of frequency cutoffs in Hz.
        pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
            the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more information.
        fft (bool or None): See `LowPassFilters` for more info.

    .. note::
        The sum of all the bands will always be the input signal.

    .. warning::
        Unlike `julius.lowpass.LowPassFilters`, the cutoff frequencies must be provided in Hz along
        with the sample rate.

    Shape:

        - Input: `[*, T]`
        - Output: `[B, *, T']`, with `T'=T` if `pad` is True.
          If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1`

    >>> bands = SplitBands(sample_rate=128, n_bands=10)
    >>> x = paddle.randn(shape=[6, 4, 1024])
    >>> list(bands(x).shape)
    [10, 6, 4, 1024]
    """

    def __init__(
            self,
            sample_rate: float,
            n_bands: Optional[int]=None,
            cutoffs: Optional[Sequence[float]]=None,
            pad: bool=True,
            zeros: float=8,
            fft: Optional[bool]=None):
        super(SplitBands, self).__init__()
        if (cutoffs is None) + (n_bands is None) != 1:
            raise ValueError(
                "You must provide either n_bands, or cutoffs, but not both.")

        self.sample_rate = sample_rate
        self.n_bands = n_bands
        self._cutoffs = list(cutoffs) if cutoffs is not None else None
        self.pad = pad
        self.zeros = zeros
        self.fft = fft

        if cutoffs is None:
            if n_bands is None:
                raise ValueError("You must provide one of n_bands or cutoffs.")
            if not n_bands >= 1:
                raise ValueError(
                    f"n_bands must be at least one (got {n_bands})")
            cutoffs = mel_frequencies(n_bands + 1, 0, sample_rate / 2)[1:-1]
        else:
            if max(cutoffs) > 0.5 * sample_rate:
                raise ValueError(
                    "A cutoff above sample_rate/2 does not make sense.")
        if len(cutoffs) > 0:
            self.lowpass = LowPassFilters(
                [c / sample_rate for c in cutoffs],
                pad=pad,
                zeros=zeros,
                fft=fft)
        else:
            self.lowpass = None  # type: ignore

    def forward(self, input):
        if self.lowpass is None:
            return input[None]
        lows = self.lowpass(input)
        low = lows[0]
        bands = [low]
        for low_and_band in lows[1:]:
            # Get a bandpass filter by subtracting lowpasses
            band = low_and_band - low
            bands.append(band)
            low = low_and_band
        # Last band is whatever is left in the signal
        bands.append(input - low)
        return paddle.stack(bands)

    @property
    def cutoffs(self):
        if self._cutoffs is not None:
            return self._cutoffs
        elif self.lowpass is not None:
            return [c * self.sample_rate for c in self.lowpass.cutoffs]
        else:
            return []


def split_bands(
        signal: paddle.Tensor,
        sample_rate: float,
        n_bands: Optional[int]=None,
        cutoffs: Optional[Sequence[float]]=None,
        pad: bool=True,
        zeros: float=8,
        fft: Optional[bool]=None):
    """
    Functional version of `SplitBands`, refer to this class for more information.

    >>> x = paddle.randn(shape=[6, 4, 1024])
    >>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape)
    [3, 6, 4, 1024]
    """
    return SplitBands(sample_rate, n_bands, cutoffs, pad, zeros, fft)(signal)
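

# The reconstruction property noted in `SplitBands` (the bands always sum
# back to the input) can be checked directly:
#
#     >>> x = paddle.randn(shape=[6, 4, 1024])
#     >>> bands = split_bands(x, sample_rate=64, cutoffs=[12, 24])
#     >>> bool(paddle.allclose(bands.sum(0), x, atol=1e-5))
#     True
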
@@ -0,0 +1,390 @@
import typing

import numpy as np
import paddle

from . import _julius
from . import util


class DSPMixin:
    _original_batch_size = None
    _original_num_channels = None
    _padded_signal_length = None

    # def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
    #     self._original_batch_size = self.batch_size
    #     self._original_num_channels = self.num_channels

    #     window_length = int(window_duration * self.sample_rate)
    #     hop_length = int(hop_duration * self.sample_rate)

    #     if window_length % hop_length != 0:
    #         factor = window_length // hop_length
    #         window_length = factor * hop_length

    #     self.zero_pad(hop_length, hop_length)
    #     self._padded_signal_length = self.signal_length

    #     return window_length, hop_length

    # def windows(
    #     self, window_duration: float, hop_duration: float, preprocess: bool = True
    # ):
    #     """Generator which yields windows of specified duration from signal with a specified
    #     hop length.

    #     Parameters
    #     ----------
    #     window_duration : float
    #         Duration of every window in seconds.
    #     hop_duration : float
    #         Hop between windows in seconds.
    #     preprocess : bool, optional
    #         Whether to preprocess the signal, so that the first sample is in
    #         the middle of the first window, by default True

    #     Yields
    #     ------
    #     AudioSignal
    #         Each window is returned as an AudioSignal.
    #     """
    #     if preprocess:
    #         window_length, hop_length = self._preprocess_signal_for_windowing(
    #             window_duration, hop_duration
    #         )

    #     self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)

    #     for b in range(self.batch_size):
    #         i = 0
    #         start_idx = i * hop_length
    #         while True:
    #             start_idx = i * hop_length
    #             i += 1
    #             end_idx = start_idx + window_length
    #             if end_idx > self.signal_length:
    #                 break
    #             yield self[b, ..., start_idx:end_idx]

    # def collect_windows(
    #     self, window_duration: float, hop_duration: float, preprocess: bool = True
    # ):
    #     """Reshapes signal into windows of specified duration from signal with a specified
    #     hop length. Windows are placed along the batch dimension. Use with
    #     :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
    #     original signal.

    #     Parameters
    #     ----------
    #     window_duration : float
    #         Duration of every window in seconds.
    #     hop_duration : float
    #         Hop between windows in seconds.
    #     preprocess : bool, optional
    #         Whether to preprocess the signal, so that the first sample is in
    #         the middle of the first window, by default True

    #     Returns
    #     -------
    #     AudioSignal
    #         AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
    #     """
    #     if preprocess:
    #         window_length, hop_length = self._preprocess_signal_for_windowing(
    #             window_duration, hop_duration
    #         )

    #     # self.audio_data: (nb, nch, nt).
    #     unfolded = paddle.nn.functional.unfold(
    #         self.audio_data.reshape(-1, 1, 1, self.signal_length),
    #         kernel_size=(1, window_length),
    #         stride=(1, hop_length),
    #     )
    #     # unfolded: (nb * nch, window_length, num_windows).
    #     # -> (nb * nch * num_windows, 1, window_length)
    #     unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
    #     self.audio_data = unfolded
    #     return self

    # def overlap_and_add(self, hop_duration: float):
    #     """Function which takes a list of windows and overlap adds them into a
    #     signal the same length as ``audio_signal``.

    #     Parameters
    #     ----------
    #     hop_duration : float
    #         How much to shift for each window
    #         (overlap is window_duration - hop_duration) in seconds.

    #     Returns
    #     -------
    #     AudioSignal
    #         overlap-and-added signal.
    #     """
    #     hop_length = int(hop_duration * self.sample_rate)
    #     window_length = self.signal_length

    #     nb, nch = self._original_batch_size, self._original_num_channels

    #     unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
    #     folded = paddle.nn.functional.fold(
    #         unfolded,
    #         output_size=(1, self._padded_signal_length),
    #         kernel_size=(1, window_length),
    #         stride=(1, hop_length),
    #     )

    #     norm = paddle.ones_like(unfolded, device=unfolded.device)
    #     norm = paddle.nn.functional.fold(
    #         norm,
    #         output_size=(1, self._padded_signal_length),
    #         kernel_size=(1, window_length),
    #         stride=(1, hop_length),
    #     )

    #     folded = folded / norm

    #     folded = folded.reshape(nb, nch, -1)
    #     self.audio_data = folded
    #     self.trim(hop_length, hop_length)
    #     return self

    def low_pass(self,
                 cutoffs: typing.Union[paddle.Tensor, np.ndarray, float],
                 zeros: int=51):
        """Low-passes the signal in-place. Each item in the batch
        can have a different low-pass cutoff, if the input
        to this signal is an array or tensor. If a float, all
        items are given the same low-pass filter.

        Parameters
        ----------
        cutoffs : typing.Union[paddle.Tensor, np.ndarray, float]
            Cutoff in Hz of low-pass filter.
        zeros : int, optional
            Number of taps to use in low-pass filter, by default 51

        Returns
        -------
        AudioSignal
            Low-passed AudioSignal.
        """
        cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
        cutoffs = cutoffs / self.sample_rate
        filtered = paddle.empty_like(self.audio_data)

        for i, cutoff in enumerate(cutoffs):
            lp_filter = _julius.LowPassFilter(cutoff.cpu(), zeros=zeros)
            filtered[i] = lp_filter(self.audio_data[i])

        self.audio_data = filtered
        self.stft_data = None
        return self
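
    # Usage sketch (the `AudioSignal` constructor call is illustrative):
    #
    #     >>> from audiotools import AudioSignal
    #     >>> signal = AudioSignal(paddle.randn([2, 1, 44100]), sample_rate=44100)
    #     >>> signal = signal.low_pass(4000)                          # one cutoff for all items
    #     >>> signal = signal.low_pass(paddle.to_tensor([2e3, 8e3]))  # per-item cutoffs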

    def high_pass(self,
                  cutoffs: typing.Union[paddle.Tensor, np.ndarray, float],
                  zeros: int=51):
        """High-passes the signal in-place. Each item in the batch
        can have a different high-pass cutoff, if the input
        to this signal is an array or tensor. If a float, all
        items are given the same high-pass filter.

        Parameters
        ----------
        cutoffs : typing.Union[paddle.Tensor, np.ndarray, float]
            Cutoff in Hz of high-pass filter.
        zeros : int, optional
            Number of taps to use in high-pass filter, by default 51

        Returns
        -------
        AudioSignal
            High-passed AudioSignal.
        """
        cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
        cutoffs = cutoffs / self.sample_rate
        filtered = paddle.empty_like(self.audio_data)

        for i, cutoff in enumerate(cutoffs):
            hp_filter = _julius.HighPassFilter(cutoff.cpu(), zeros=zeros)
            filtered[i] = hp_filter(self.audio_data[i])

        self.audio_data = filtered
        self.stft_data = None
        return self

    # def mask_frequencies(
    #     self,
    #     fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float],
    #     fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float],
    #     val: float = 0.0,
    # ):
    #     """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
    #     with the value specified by ``val``. Useful for implementing SpecAug.
    #     The min and max can be different for every item in the batch.

    #     Parameters
    #     ----------
    #     fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float]
    #         Lower end of band to mask out.
    #     fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float]
    #         Upper end of band to mask out.
    #     val : float, optional
    #         Value to fill in, by default 0.0

    #     Returns
    #     -------
    #     AudioSignal
    #         Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
    #         masked audio data.
    #     """
    #     # SpecAug
    #     mag, phase = self.magnitude, self.phase
    #     fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
    #     fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
    #     assert paddle.all(fmin_hz < fmax_hz)

    #     # build mask
    #     nbins = mag.shape[-2]
    #     bins_hz = paddle.linspace(0, self.sample_rate / 2, nbins, device=self.device)
    #     bins_hz = bins_hz[None, None, :, None].repeat(
    #         self.batch_size, 1, 1, mag.shape[-1]
    #     )
    #     mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
    #     mask = mask.to(self.device)

    #     mag = mag.masked_fill(mask, val)
    #     phase = phase.masked_fill(mask, val)
    #     self.stft_data = mag * paddle.exp(1j * phase)
    #     return self

    # def mask_timesteps(
    #     self,
    #     tmin_s: typing.Union[paddle.Tensor, np.ndarray, float],
    #     tmax_s: typing.Union[paddle.Tensor, np.ndarray, float],
    #     val: float = 0.0,
    # ):
    #     """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
    #     with the value specified by ``val``. Useful for implementing SpecAug.
    #     The min and max can be different for every item in the batch.

    #     Parameters
    #     ----------
    #     tmin_s : typing.Union[paddle.Tensor, np.ndarray, float]
    #         Lower end of timesteps to mask out.
    #     tmax_s : typing.Union[paddle.Tensor, np.ndarray, float]
    #         Upper end of timesteps to mask out.
    #     val : float, optional
    #         Value to fill in, by default 0.0

    #     Returns
    #     -------
    #     AudioSignal
    #         Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
    #         masked audio data.
    #     """
    #     # SpecAug
    #     mag, phase = self.magnitude, self.phase
    #     tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
    #     tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)

    #     assert paddle.all(tmin_s < tmax_s)

    #     # build mask
    #     nt = mag.shape[-1]
    #     bins_t = paddle.linspace(0, self.signal_duration, nt, device=self.device)
    #     bins_t = bins_t[None, None, None, :].repeat(
    #         self.batch_size, 1, mag.shape[-2], 1
    #     )
    #     mask = (tmin_s <= bins_t) & (bins_t < tmax_s)

    #     mag = mag.masked_fill(mask, val)
    #     phase = phase.masked_fill(mask, val)
    #     self.stft_data = mag * paddle.exp(1j * phase)
    #     return self

    # def mask_low_magnitudes(
    #     self, db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float], val: float = 0.0
    # ):
    #     """Mask away magnitudes below a specified threshold, which
    #     can be different for every item in the batch.

    #     Parameters
    #     ----------
    #     db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float]
    #         Decibel value for which things below it will be masked away.
    #     val : float, optional
    #         Value to fill in for masked portions, by default 0.0

    #     Returns
    #     -------
    #     AudioSignal
    #         Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
    #         masked audio data.
    #     """
    #     mag = self.magnitude
    #     log_mag = self.log_magnitude()

    #     db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
    #     mask = log_mag < db_cutoff
    #     mag = mag.masked_fill(mask, val)

    #     self.magnitude = mag
    #     return self

    # def shift_phase(self, shift: typing.Union[paddle.Tensor, np.ndarray, float]):
    #     """Shifts the phase by a constant value.

    #     Parameters
    #     ----------
    #     shift : typing.Union[paddle.Tensor, np.ndarray, float]
    #         What to shift the phase by.

    #     Returns
    #     -------
    #     AudioSignal
    #         Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
    #         masked audio data.
    #     """
    #     shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
    #     self.phase = self.phase + shift
    #     return self

    # def corrupt_phase(self, scale: typing.Union[paddle.Tensor, np.ndarray, float]):
    #     """Corrupts the phase randomly by some scaled value.

    #     Parameters
    #     ----------
    #     scale : typing.Union[paddle.Tensor, np.ndarray, float]
    #         Standard deviation of noise to add to the phase.

    #     Returns
    #     -------
    #     AudioSignal
    #         Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
    #         masked audio data.
    #     """
    #     scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
    #     self.phase = self.phase + scale * paddle.randn_like(self.phase)
    #     return self

    # def preemphasis(self, coef: float = 0.85):
    #     """Applies pre-emphasis to audio signal.

    #     Parameters
    #     ----------
    #     coef : float, optional
    #         How much pre-emphasis to apply, lower values do less. 0 does nothing.
    #         by default 0.85

    #     Returns
    #     -------
    #     AudioSignal
    #         Pre-emphasized signal.
    #     """
    #     kernel = paddle.to_tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
    #     x = self.audio_data.reshape(-1, 1, self.signal_length)
    #     x = paddle.nn.functional.conv1d(x, kernel, padding=1)
    #     self.audio_data = x.reshape(*self.audio_data.shape)
    #     return self
@@ -0,0 +1,681 @@
import typing

import numpy as np
import paddle

from . import util
from ._julius import SplitBands

# from . import _julius


class EffectMixin:
    GAIN_FACTOR = np.log(10) / 20
    """Gain factor for converting between amplitude and decibels."""
    CODEC_PRESETS = {
        "8-bit": {
            "format": "wav",
            "encoding": "ULAW",
            "bits_per_sample": 8
        },
        "GSM-FR": {
            "format": "gsm"
        },
        "MP3": {
            "format": "mp3",
            "compression": -9
        },
        "Vorbis": {
            "format": "vorbis",
            "compression": -1
        },
        "Ogg": {
            "format": "ogg",
            "compression": -1,
        },
        "Amr-nb": {
            "format": "amr-nb"
        },
    }
    """Presets for applying codecs via torchaudio."""

    def mix(
            self,
            other,
            snr: typing.Union[paddle.Tensor, np.ndarray, float]=10,
            other_eq: typing.Union[paddle.Tensor, np.ndarray]=None):
        """Mixes noise with signal at specified
        signal-to-noise ratio. Optionally, the
        other signal can be equalized in-place.

        Parameters
        ----------
        other : AudioSignal
            AudioSignal object to mix with.
        snr : typing.Union[paddle.Tensor, np.ndarray, float], optional
            Signal to noise ratio, by default 10
        other_eq : typing.Union[paddle.Tensor, np.ndarray], optional
            EQ curve to apply to other signal, if any, by default None

        Returns
        -------
        AudioSignal
            In-place modification of AudioSignal.
        """
        snr = util.ensure_tensor(snr)

        pad_len = max(0, self.signal_length - other.signal_length)
        other.zero_pad(0, pad_len)
        other.truncate_samples(self.signal_length)
        if other_eq is not None:
            other = other.equalizer(other_eq)

        tgt_loudness = self.loudness() - snr
        other = other.normalize(tgt_loudness)

        self.audio_data = self.audio_data + other.audio_data
        return self
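
    # The SNR bookkeeping in `mix`: `other` is normalized to
    # `self.loudness() - snr` LUFS before being added, so it sits `snr` dB
    # below the signal. Illustrative sketch (constructor calls assume the
    # package-level `AudioSignal` API):
    #
    #     >>> clean = AudioSignal(paddle.randn([1, 1, 44100]), sample_rate=44100)
    #     >>> noise = AudioSignal(paddle.randn([1, 1, 44100]), sample_rate=44100)
    #     >>> noisy = clean.mix(noise, snr=10)  # noise now 10 dB under `clean`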

    def convolve(self, other, start_at_max: bool=True):
        """Convolves self with other.
        This function uses FFTs to do the convolution.

        Parameters
        ----------
        other : AudioSignal
            Signal to convolve with.
        start_at_max : bool, optional
            Whether to start at the max value of other signal, to
            avoid inducing delays, by default True

        Returns
        -------
        AudioSignal
            Convolved signal, in-place.
        """
        from . import AudioSignal

        pad_len = self.signal_length - other.signal_length

        if pad_len > 0:
            other.zero_pad(0, pad_len)
        else:
            other.truncate_samples(self.signal_length)

        if start_at_max:
            # Use roll to rotate over the max for every item
            # so that the impulse responses don't induce any
            # delay.
            idx = paddle.argmax(paddle.abs(other.audio_data), axis=-1)
            irs = paddle.zeros_like(other.audio_data)
            for i in range(other.batch_size):
                irs[i] = paddle.roll(
                    other.audio_data[i], shifts=-idx[i].item(), axis=-1)
            other = AudioSignal(irs, other.sample_rate)

        delta = paddle.zeros_like(other.audio_data)
        delta[..., 0] = 1

        length = self.signal_length
        delta_fft = paddle.fft.rfft(delta, n=length)
        other_fft = paddle.fft.rfft(other.audio_data, n=length)
        self_fft = paddle.fft.rfft(self.audio_data, n=length)

        convolved_fft = other_fft * self_fft
        convolved_audio = paddle.fft.irfft(convolved_fft, n=length)

        delta_convolved_fft = other_fft * delta_fft
        delta_audio = paddle.fft.irfft(delta_convolved_fft, n=length)

        # Use the delta to rescale the audio exactly as needed.
        delta_max = paddle.max(paddle.abs(delta_audio), axis=-1, keepdim=True)
        scale = 1 / paddle.clip(delta_max, min=1e-5)
        convolved_audio = convolved_audio * scale

        self.audio_data = convolved_audio

        return self
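
    # Sanity sketch: convolving with a unit impulse leaves the signal
    # unchanged, since the delta-based rescaling above normalizes the gain
    # (the `AudioSignal` calls are illustrative):
    #
    #     >>> ir_data = paddle.zeros([1, 1, 4096])
    #     >>> ir_data[..., 0] = 1.0
    #     >>> ir = AudioSignal(ir_data, sample_rate=44100)
    #     >>> signal = signal.convolve(ir)  # audio_data is (numerically) unchanged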

    def apply_ir(
            self,
            ir,
            drr: typing.Union[paddle.Tensor, np.ndarray, float]=None,
            ir_eq: typing.Union[paddle.Tensor, np.ndarray]=None,
            use_original_phase: bool=False):
        """Applies an impulse response to the signal. If ``ir_eq``
        is specified, the impulse response is equalized before
        it is applied, using the given curve.

        Parameters
        ----------
        ir : AudioSignal
            Impulse response to convolve with.
        drr : typing.Union[paddle.Tensor, np.ndarray, float], optional
            Direct-to-reverberant ratio that impulse response will be
            altered to, if specified, by default None
        ir_eq : typing.Union[paddle.Tensor, np.ndarray], optional
            Equalization that will be applied to impulse response
            if specified, by default None
        use_original_phase : bool, optional
            Whether to use the original phase, instead of the convolved
            phase, by default False

        Returns
        -------
        AudioSignal
            Signal with impulse response applied to it
        """
        if ir_eq is not None:
            ir = ir.equalizer(ir_eq)
        if drr is not None:
            ir = ir.alter_drr(drr)

        # Save the peak before
        max_spk = self.audio_data.abs().max(axis=-1, keepdim=True)

        # Augment the impulse response to simulate microphone effects
        # and with varying direct-to-reverberant ratio.
        phase = self.phase
        self.convolve(ir)

        # Use the input phase
        if use_original_phase:
            self.stft()
            self.stft_data = self.magnitude * paddle.exp(1j * phase)
            self.istft()

        # Rescale to the input's amplitude
        max_transformed = self.audio_data.abs().max(axis=-1, keepdim=True)
        scale_factor = max_spk.clip(1e-8) / max_transformed.clip(1e-8)
        self = self * scale_factor

        return self

    def ensure_max_of_audio(self, _max: float=1.0):
        """Ensures that ``abs(audio_data) <= _max``.

        Parameters
        ----------
        _max : float, optional
            Max absolute value of signal, by default 1.0

        Returns
        -------
        AudioSignal
            Signal with values scaled between -_max and _max.
        """
        peak = self.audio_data.abs().max(axis=-1, keepdim=True)
        peak_gain = paddle.ones_like(peak)
        peak_gain[peak > _max] = _max / peak[peak > _max]
        self.audio_data = self.audio_data * peak_gain
        return self

    def normalize(self,
                  db: typing.Union[paddle.Tensor, np.ndarray, float]=-24.0):
        """Normalizes the signal's volume to the specified db, in LUFS.
        This is GPU-compatible, making for very fast loudness normalization.

        Parameters
        ----------
        db : typing.Union[paddle.Tensor, np.ndarray, float], optional
            Loudness to normalize to, by default -24.0

        Returns
        -------
        AudioSignal
            Normalized audio signal.
        """
        db = util.ensure_tensor(db)
        ref_db = self.loudness()
        gain = db - ref_db
        gain = paddle.exp(gain * self.GAIN_FACTOR)

        self.audio_data = self.audio_data * gain[:, None, None]
        return self
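
    # Worked gain math: with GAIN_FACTOR = ln(10) / 20,
    #     exp((db - ref_db) * GAIN_FACTOR) == 10 ** ((db - ref_db) / 20),
    # the standard dB-to-amplitude conversion. Normalizing a -30 LUFS signal
    # to -24 LUFS therefore scales it by 10 ** (6 / 20), roughly 2x.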
# def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]):
|
||||||
|
# """Change volume of signal by some amount, in dB.
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
# ----------
|
||||||
|
# db : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
# Amount to change volume by.
|
||||||
|
|
||||||
|
# Returns
|
||||||
|
# -------
|
||||||
|
# AudioSignal
|
||||||
|
# Signal at new volume.
|
||||||
|
# """
|
||||||
|
# db = util.ensure_tensor(db, ndim=1).to(self.device)
|
||||||
|
# gain = torch.exp(db * self.GAIN_FACTOR)
|
||||||
|
# self.audio_data = self.audio_data * gain[:, None, None]
|
||||||
|
# return self
|
||||||
|
|
||||||
|
# def _to_2d(self):
|
||||||
    #     waveform = self.audio_data.reshape(-1, self.signal_length)
    #     return waveform

    # def _to_3d(self, waveform):
    #     return waveform.reshape(self.batch_size, self.num_channels, -1)

    # def pitch_shift(self, n_semitones: int, quick: bool = True):
    #     """Pitch shift the signal. All items in the batch
    #     get the same pitch shift.

    #     Parameters
    #     ----------
    #     n_semitones : int
    #         How many semitones to shift the signal by.
    #     quick : bool, optional
    #         Using quick pitch shifting, by default True

    #     Returns
    #     -------
    #     AudioSignal
    #         Pitch shifted audio signal.
    #     """
    #     device = self.device
    #     effects = [
    #         ["pitch", str(n_semitones * 100)],
    #         ["rate", str(self.sample_rate)],
    #     ]
    #     if quick:
    #         effects[0].insert(1, "-q")

    #     waveform = self._to_2d().cpu()
    #     waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
    #         waveform, self.sample_rate, effects, channels_first=True
    #     )
    #     self.sample_rate = sample_rate
    #     self.audio_data = self._to_3d(waveform)
    #     return self.to(device)

    # def time_stretch(self, factor: float, quick: bool = True):
    #     """Time stretch the audio signal.

    #     Parameters
    #     ----------
    #     factor : float
    #         Factor by which to stretch the AudioSignal. Typically
    #         between 0.8 and 1.2.
    #     quick : bool, optional
    #         Whether to use quick time stretching, by default True

    #     Returns
    #     -------
    #     AudioSignal
    #         Time-stretched AudioSignal.
    #     """
    #     device = self.device
    #     effects = [
    #         ["tempo", str(factor)],
    #         ["rate", str(self.sample_rate)],
    #     ]
    #     if quick:
    #         effects[0].insert(1, "-q")

    #     waveform = self._to_2d().cpu()
    #     waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
    #         waveform, self.sample_rate, effects, channels_first=True
    #     )
    #     self.sample_rate = sample_rate
    #     self.audio_data = self._to_3d(waveform)
    #     return self.to(device)

    # def apply_codec(
    #     self,
    #     preset: str = None,
    #     format: str = "wav",
    #     encoding: str = None,
    #     bits_per_sample: int = None,
    #     compression: int = None,
    # ):  # pragma: no cover
    #     """Applies an audio codec to the signal.

    #     Parameters
    #     ----------
    #     preset : str, optional
    #         One of the keys in ``self.CODEC_PRESETS``, by default None
    #     format : str, optional
    #         Format for audio codec, by default "wav"
    #     encoding : str, optional
    #         Encoding to use, by default None
    #     bits_per_sample : int, optional
    #         How many bits per sample, by default None
    #     compression : int, optional
    #         Compression amount of codec, by default None

    #     Returns
    #     -------
    #     AudioSignal
    #         AudioSignal with codec applied.

    #     Raises
    #     ------
    #     ValueError
    #         If preset is not in ``self.CODEC_PRESETS``, an error
    #         is thrown.
    #     """
    #     torchaudio_version_070 = "0.7" in torchaudio.__version__
    #     if torchaudio_version_070:
    #         return self

    #     kwargs = {
    #         "format": format,
    #         "encoding": encoding,
    #         "bits_per_sample": bits_per_sample,
    #         "compression": compression,
    #     }

    #     if preset is not None:
    #         if preset in self.CODEC_PRESETS:
    #             kwargs = self.CODEC_PRESETS[preset]
    #         else:
    #             raise ValueError(
    #                 f"Unknown preset: {preset}. "
    #                 f"Known presets: {list(self.CODEC_PRESETS.keys())}"
    #             )

    #     waveform = self._to_2d()
    #     if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
    #         # Apply it in a for loop
    #         augmented = torch.cat(
    #             [
    #                 torchaudio.functional.apply_codec(
    #                     waveform[i][None, :], self.sample_rate, **kwargs
    #                 )
    #                 for i in range(waveform.shape[0])
    #             ],
    #             dim=0,
    #         )
    #     else:
    #         augmented = torchaudio.functional.apply_codec(
    #             waveform, self.sample_rate, **kwargs
    #         )
    #     augmented = self._to_3d(augmented)

    #     self.audio_data = augmented
    #     return self

    def mel_filterbank(self, n_bands: int):
        """Breaks signal into mel bands.

        Parameters
        ----------
        n_bands : int
            Number of mel bands to use.

        Returns
        -------
        paddle.Tensor
            Mel-filtered bands, with last axis being the band index.
        """
        filterbank = SplitBands(self.sample_rate, n_bands).float()
        filtered = filterbank(self.audio_data)
        return filtered.transpose([1, 2, 3, 0])
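    # Usage sketch (illustrative, not part of the original sources): split a
    # signal into 6 mel bands; summing over the band axis recovers the input,
    # since the band filters are complementary.
    #
    #     signal = AudioSignal(paddle.randn([1, 1, 44100]), sample_rate=44100)
    #     bands = signal.mel_filterbank(n_bands=6)  # (nb, nch, nt, 6)
    #     recon = bands.sum(-1)  # approximately the original audio
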
    def equalizer(self, db: typing.Union[paddle.Tensor, np.ndarray]):
        """Applies a mel-spaced equalizer to the audio signal.

        Parameters
        ----------
        db : typing.Union[paddle.Tensor, np.ndarray]
            EQ curve to apply, in decibels; the last axis indexes the mel band.

        Returns
        -------
        AudioSignal
            AudioSignal with equalization applied.
        """
        db = util.ensure_tensor(db)
        n_bands = db.shape[-1]
        fbank = self.mel_filterbank(n_bands)

        # If there's a batch dimension, make sure it's the same.
        if db.ndim == 2:
            if db.shape[0] != 1:
                assert db.shape[0] == fbank.shape[0]
        else:
            db = db.unsqueeze(0)

        weights = (10**db).astype("float32")
        fbank = fbank * weights[:, None, None, :]
        eq_audio_data = fbank.sum(-1)
        self.audio_data = eq_audio_data
        return self
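    # Usage sketch (illustrative): cut each of 6 mel bands by a random amount
    # between 0 and -3 dB.
    #
    #     db = -3 * paddle.rand([6])
    #     signal = signal.equalizer(db)
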
    def clip_distortion(
            self,
            clip_percentile: typing.Union[paddle.Tensor, np.ndarray, float]):
        """Clips the signal at a given percentile. The higher it is,
        the lower the threshold for clipping.

        Parameters
        ----------
        clip_percentile : typing.Union[paddle.Tensor, np.ndarray, float]
            Values are between 0.0 and 1.0. Typical values are 0.1 or below.

        Returns
        -------
        AudioSignal
            Audio signal with clipped audio data.
        """
        clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
        clip_percentile = clip_percentile.item()
        min_thresh = paddle.quantile(
            self.audio_data, clip_percentile / 2, axis=-1)[None]
        max_thresh = paddle.quantile(
            self.audio_data, 1 - (clip_percentile / 2), axis=-1)[None]

        nc = self.audio_data.shape[1]
        min_thresh = min_thresh[:, :nc, :]
        max_thresh = max_thresh[:, :nc, :]

        self.audio_data = self.audio_data.clip(min_thresh, max_thresh)

        return self
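    # Usage sketch (illustrative): clip the loudest ~5% of samples, a common
    # perceptual distortion augmentation.
    #
    #     signal = signal.clip_distortion(clip_percentile=0.05)
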
    # def quantization(
    #     self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int]
    # ):
    #     """Applies quantization to the input waveform.

    #     Parameters
    #     ----------
    #     quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
    #         Number of evenly spaced quantization channels to quantize
    #         to.

    #     Returns
    #     -------
    #     AudioSignal
    #         Quantized AudioSignal.
    #     """
    #     quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)

    #     x = self.audio_data
    #     x = (x + 1) / 2
    #     x = x * quantization_channels
    #     x = x.floor()
    #     x = x / quantization_channels
    #     x = 2 * x - 1

    #     residual = (self.audio_data - x).detach()
    #     self.audio_data = self.audio_data - residual
    #     return self

    # def mulaw_quantization(
    #     self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int]
    # ):
    #     """Applies mu-law quantization to the input waveform.

    #     Parameters
    #     ----------
    #     quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
    #         Number of mu-law spaced quantization channels to quantize
    #         to.

    #     Returns
    #     -------
    #     AudioSignal
    #         Quantized AudioSignal.
    #     """
    #     mu = quantization_channels - 1.0
    #     mu = util.ensure_tensor(mu, ndim=3)

    #     x = self.audio_data

    #     # quantize
    #     x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    #     x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)

    #     # unquantize
    #     x = (x / mu) * 2 - 1.0
    #     x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu

    #     residual = (self.audio_data - x).detach()
    #     self.audio_data = self.audio_data - residual
    #     return self

    # def __matmul__(self, other):
    #     return self.convolve(other)

import paddle
import typing
import numpy as np

class ImpulseResponseMixin:
    """These functions are generally only used with AudioSignals that are derived
    from impulse responses, not other sources like music or speech. These methods
    are used to replicate the data augmentation described in [1].

    1.  Bryan, Nicholas J. "Impulse response data augmentation and deep
        neural networks for blind room acoustic parameter estimation."
        ICASSP 2020-2020 IEEE International Conference on Acoustics,
        Speech and Signal Processing (ICASSP). IEEE, 2020.
    """

    def decompose_ir(self):
        """Decomposes an impulse response into early and late
        field responses.
        """
        # Equations 1 and 2
        # -----------------
        # Breaking up into early
        # response + late field response.

        td = paddle.argmax(self.audio_data, axis=-1, keepdim=True)
        t0 = int(self.sample_rate * 0.0025)

        idx = paddle.arange(self.audio_data.shape[-1])[None, None, :]
        idx = idx.expand([self.batch_size, -1, -1])
        early_idx = (idx >= td - t0) * (idx <= td + t0)

        early_response = paddle.zeros_like(self.audio_data)
        # early_response[early_idx] = self.audio_data[early_idx]
        early_response = paddle.where(early_idx, self.audio_data,
                                      early_response)

        late_idx = ~early_idx
        late_field = paddle.zeros_like(self.audio_data)
        # late_field[late_idx] = self.audio_data[late_idx]
        late_field = paddle.where(late_idx, self.audio_data, late_field)

        # Equation 4
        # ----------
        # Decompose early response into windowed
        # direct path and windowed residual.

        window = paddle.zeros_like(self.audio_data)
        for idx in range(self.batch_size):
            window_idx = early_idx[idx, 0]

            # Scatter-based equivalent of the original fancy-indexed assignment:
            # window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item())
            indices = paddle.nonzero(window_idx).reshape(
                [-1])  # shape: [num_true], dtype: int64
            temp_window = self.get_window("hann", indices.shape[0])

            window_slice = window[idx, 0]
            updated_window_slice = paddle.scatter(
                window_slice, index=indices, updates=temp_window)

            window[idx, 0] = updated_window_slice

        return early_response, late_field, window
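    # Usage sketch (illustrative): `ir` is assumed to be an AudioSignal holding
    # an impulse response; each returned tensor matches `ir.audio_data` in shape.
    #
    #     early, late, window = ir.decompose_ir()
    #     direct = window * early  # windowed direct path (Equation 4)
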
    def measure_drr(self):
        """Measures the direct-to-reverberant ratio of the impulse
        response.

        Returns
        -------
        float
            Direct-to-reverberant ratio
        """
        early_response, late_field, _ = self.decompose_ir()
        num = (early_response**2).sum(axis=-1)
        den = (late_field**2).sum(axis=-1)
        drr = 10 * paddle.log10(num / den)
        return drr

    @staticmethod
    def solve_alpha(early_response, late_field, wd, target_drr):
        """Used to solve for the alpha value, which is used
        to alter the drr.
        """
        # Equation 5
        # ----------
        # Apply the good ol' quadratic formula.

        wd_sq = wd**2
        wd_sq_1 = (1 - wd)**2
        e_sq = early_response**2
        l_sq = late_field**2
        a = (wd_sq * e_sq).sum(axis=-1)
        b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1)
        c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(
            10 * paddle.ones_like(target_drr), target_drr / 10) * l_sq.sum(
                axis=-1)

        expr = ((b**2) - 4 * a * c).sqrt()
        alpha = paddle.maximum(
            (-b - expr) / (2 * a),
            (-b + expr) / (2 * a))
        return alpha

    def alter_drr(self, drr: typing.Union[paddle.Tensor, np.ndarray, float]):
        """Alters the direct-to-reverberant ratio of the impulse response.

        Parameters
        ----------
        drr : typing.Union[paddle.Tensor, np.ndarray, float]
            Direct-to-reverberant ratio that the impulse response will be
            altered to.

        Returns
        -------
        AudioSignal
            Altered impulse response.
        """
        drr = util.ensure_tensor(drr, 2, self.batch_size)

        early_response, late_field, window = self.decompose_ir()
        alpha = self.solve_alpha(early_response, late_field, window, drr)
        # paddle.max returns the values directly (not a (values, indices)
        # tuple as in torch), so no [0] indexing is needed here.
        min_alpha = (late_field.abs().max(axis=-1) /
                     early_response.abs().max(axis=-1))
        alpha = paddle.maximum(alpha, min_alpha)[..., None]

        aug_ir_data = alpha * window * early_response + (
            (1 - window) * early_response) + late_field
        self.audio_data = aug_ir_data
        self.ensure_max_of_audio()
        return self
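    # Usage sketch (illustrative): push the DRR of an impulse response toward
    # 20 dB, making the direct path dominate the reverberant tail.
    #
    #     ir = ir.alter_drr(20)
    #     print(ir.measure_drr())  # close to 20, subject to the min_alpha bound
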
@ -0,0 +1,115 @@
import json
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Tuple

import ffmpy
import numpy as np
import paddle

def r128stats(filepath: str, quiet: bool):
    """Takes a path to an audio file, returns a dict with the loudness
    stats computed by the ffmpeg ebur128 filter.

    Parameters
    ----------
    filepath : str
        Path to compute loudness stats on.
    quiet : bool
        Whether to show FFMPEG output during computation.

    Returns
    -------
    dict
        Dictionary containing loudness stats.
    """
    ffargs = [
        "ffmpeg",
        "-nostats",
        "-i",
        filepath,
        "-filter_complex",
        "ebur128",
        "-f",
        "null",
        "-",
    ]
    if quiet:
        ffargs += ["-hide_banner"]
    proc = subprocess.Popen(
        ffargs, stderr=subprocess.PIPE, universal_newlines=True)
    stats = proc.communicate()[1]
    summary_index = stats.rfind("Summary:")

    summary_list = stats[summary_index:].split()
    i_lufs = float(summary_list[summary_list.index("I:") + 1])
    i_thresh = float(summary_list[summary_list.index("I:") + 4])
    lra = float(summary_list[summary_list.index("LRA:") + 1])
    lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
    lra_low = float(summary_list[summary_list.index("low:") + 1])
    lra_high = float(summary_list[summary_list.index("high:") + 1])
    stats_dict = {
        "I": i_lufs,
        "I Threshold": i_thresh,
        "LRA": lra,
        "LRA Threshold": lra_thresh,
        "LRA Low": lra_low,
        "LRA High": lra_high,
    }

    return stats_dict
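# Usage sketch (illustrative; "out.wav" is a placeholder path):
#
#     stats = r128stats("out.wav", quiet=True)
#     print(stats["I"])  # integrated loudness, in LUFS
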
def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
    """Given a path to a file, returns the start time offset and codec of
    the first audio stream.
    """
    ff = ffmpy.FFprobe(
        inputs={path: None},
        global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
    )
    streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
    seconds_offset = 0.0
    codec = None

    # Get the offset and codec of the first audio stream we find
    # and return its start time, if it has one.
    for stream in streams:
        if stream["codec_type"] == "audio":
            seconds_offset = stream.get("start_time", 0.0)
            codec = stream.get("codec_name")
            break
    return float(seconds_offset), codec
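# Usage sketch (illustrative; "out.mp3" is a placeholder path):
#
#     offset, codec = ffprobe_offset_and_codec("out.mp3")
#     # e.g. offset = 0.025 (a typical mp3 priming delay), codec = "mp3"
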
class FFMPEGMixin:
    _loudness = None

    def ffmpeg_loudness(self, quiet: bool=True):
        """Computes loudness of audio file using FFMPEG.

        Parameters
        ----------
        quiet : bool, optional
            Whether to show FFMPEG output during computation,
            by default True

        Returns
        -------
        paddle.Tensor
            Loudness of every item in the batch, computed via
            FFMPEG.
        """
        loudness = []

        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
            for i in range(self.batch_size):
                self[i].write(f.name)
                loudness_stats = r128stats(f.name, quiet=quiet)
                loudness.append(loudness_stats["I"])

        self._loudness = paddle.to_tensor(np.array(loudness)).astype("float32")
        return self.loudness()
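    # Usage sketch (illustrative): writes each batch item to a temporary wav
    # and reads back its integrated loudness via ffmpeg's ebur128 filter.
    #
    #     lufs = signal.ffmpeg_loudness()  # paddle.Tensor, one value per item
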
@ -0,0 +1,338 @@
import copy

import numpy as np
import paddle
import paddle.nn.functional as F
import scipy

from . import _julius

class Meter(paddle.nn.Layer):
    """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.

    Parameters
    ----------
    rate : int
        Sample rate of audio.
    filter_class : str, optional
        Class of weighting filter used.
        'K-weighting' (default), 'Fenton/Lee 1',
        'Fenton/Lee 2', 'Dash et al.',
        by default "K-weighting"
    block_size : float, optional
        Gating block size in seconds, by default 0.400
    zeros : int, optional
        Number of zeros to use in FIR approximation of
        IIR filters, by default 512
    use_fir : bool, optional
        Whether to use FIR approximation or exact IIR formulation.
        If computing on GPU, ``use_fir=True`` will be used, as it's
        much faster, by default False
    """

    def __init__(
            self,
            rate: int,
            filter_class: str="K-weighting",
            block_size: float=0.400,
            zeros: int=512,
            use_fir: bool=False, ):
        super().__init__()

        self.rate = rate
        self.filter_class = filter_class
        self.block_size = block_size
        self.use_fir = use_fir

        G = paddle.to_tensor(
            np.array([1.0, 1.0, 1.0, 1.41, 1.41]), stop_gradient=True)
        self.register_buffer("G", G)

        # Compute impulse responses so that filtering is fast via
        # a convolution at runtime, on GPU, unlike lfilter.
        impulse = np.zeros((zeros, ))
        impulse[..., 0] = 1.0

        firs = np.zeros((len(self._filters), 1, zeros))
        # passband_gain = torch.zeros(len(self._filters))
        passband_gain = paddle.zeros([len(self._filters)], dtype="float32")

        for i, (_, filter_stage) in enumerate(self._filters.items()):
            firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a,
                                           impulse)
            passband_gain[i] = filter_stage.passband_gain

        firs = paddle.to_tensor(
            firs[..., ::-1].copy(), dtype="float32", stop_gradient=True)

        self.register_buffer("firs", firs)
        self.register_buffer("passband_gain", passband_gain)

    def apply_filter_gpu(self, data: paddle.Tensor):
        """Performs FIR approximation of loudness computation.

        Parameters
        ----------
        data : paddle.Tensor
            Audio data of shape (nb, nt, nch).

        Returns
        -------
        paddle.Tensor
            Filtered audio data.
        """
        # Data is of shape (nb, nt, nch).
        # Reshape to (nb*nch, 1, nt) so each channel is filtered independently.
        nb, nt, nch = data.shape
        data = data.transpose([0, 2, 1])
        data = data.reshape([nb * nch, 1, nt])

        # Apply padding
        pad_length = self.firs.shape[-1]

        # Apply filtering in sequence
        for i in range(self.firs.shape[0]):
            data = F.pad(data, (pad_length, pad_length), data_format="NCL")
            data = _julius.fft_conv1d(data, self.firs[i, None, ...])
            data = self.passband_gain[i] * data
            data = data[..., 1:nt + 1]

        data = data.transpose([0, 2, 1])
        data = data[:, :nt, :]
        return data

    @staticmethod
    def scipy_lfilter(waveform, a_coeffs, b_coeffs, clamp: bool=True):
        # Filter with scipy.signal.lfilter (handles 3-D data of shape
        # (nb, nt, nch), one batch item and channel at a time).
        output = np.zeros_like(waveform)
        for batch_idx in range(waveform.shape[0]):
            for channel_idx in range(waveform.shape[2]):
                output[batch_idx, :, channel_idx] = scipy.signal.lfilter(
                    b_coeffs, a_coeffs, waveform[batch_idx, :, channel_idx])
        return output

    def apply_filter_cpu(self, data: paddle.Tensor):
        """Performs IIR formulation of loudness computation.

        Parameters
        ----------
        data : paddle.Tensor
            Audio data of shape (nb, nt, nch).

        Returns
        -------
        paddle.Tensor
            Filtered audio data.
        """
        _data = data.cpu().numpy().copy()
        for _, filter_stage in self._filters.items():
            passband_gain = filter_stage.passband_gain

            a_coeffs = filter_stage.a
            b_coeffs = filter_stage.b

            filtered = self.scipy_lfilter(_data, a_coeffs, b_coeffs)
            _data[:] = passband_gain * filtered
        data = paddle.to_tensor(_data)
        return data

    def apply_filter(self, data: paddle.Tensor):
        """Applies the weighting filter on either CPU or GPU: the FIR
        approximation is used if the audio is on the GPU or if
        ``self.use_fir`` is True; otherwise the exact IIR formulation
        is used on CPU.

        Parameters
        ----------
        data : paddle.Tensor
            Audio data of shape (nb, nt, nch).

        Returns
        -------
        paddle.Tensor
            Filtered audio data.
        """
        if data.place.is_gpu_place() or self.use_fir:
            data = self.apply_filter_gpu(data)
        else:
            data = self.apply_filter_cpu(data)
        return data

    def forward(self, data: paddle.Tensor):
        """Computes integrated loudness of data.

        Parameters
        ----------
        data : paddle.Tensor
            Audio data of shape (nb, nt, nch).

        Returns
        -------
        paddle.Tensor
            Integrated loudness of audio data.
        """
        return self.integrated_loudness(data)

    def _unfold(self, input_data):
        T_g = self.block_size
        overlap = 0.75  # overlap of 75% of the block duration
        step = 1.0 - overlap  # step size by percentage

        kernel_size = int(T_g * self.rate)
        stride = int(T_g * self.rate * step)
        unfolded = _julius.unfold(
            input_data.transpose([0, 2, 1]), kernel_size, stride)
        unfolded = unfolded.transpose([0, 1, 3, 2])

        return unfolded
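    # Worked numbers for the block framing above (illustrative): with
    # rate = 44100 and block_size T_g = 0.400 s, kernel_size = int(0.400 *
    # 44100) = 17640 samples; with 75% overlap the step is 0.25, so
    # stride = int(0.400 * 44100 * 0.25) = 4410 samples, i.e. a new
    # gating block every 100 ms.
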
    def integrated_loudness(self, data: paddle.Tensor):
        """Computes integrated loudness of data.

        Parameters
        ----------
        data : paddle.Tensor
            Audio data of shape (nb, nt, nch).

        Returns
        -------
        paddle.Tensor
            Integrated loudness (in LUFS) of each item in the batch.
        """
        if not paddle.is_tensor(data):
            data = paddle.to_tensor(data, dtype="float32")
        else:
            data = data.astype("float32")

        input_data = data.clone()
        # Data always has a batch and channel dimension.
        # Is of shape (nb, nt, nch)
        if input_data.ndim < 2:
            input_data = input_data.unsqueeze(-1)
        if input_data.ndim < 3:
            input_data = input_data.unsqueeze(0)

        nb, nt, nch = input_data.shape

        # Apply frequency weighting filters - account
        # for the acoustic response of the head and auditory system
        input_data = self.apply_filter(input_data)

        G = self.G  # channel gains
        T_g = self.block_size  # 400 ms gating block standard
        Gamma_a = -70.0  # -70 LKFS = absolute loudness threshold

        unfolded = self._unfold(input_data)

        z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
        l = -0.691 + 10.0 * paddle.log10(
            (G[None, :nch, None] * z).sum(1, keepdim=True))
        l = l.expand_as(z)

        # find gating block indices above absolute threshold
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        masked = l > Gamma_a
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # calculate the relative threshold value (see eq. 6)
        Gamma_r = -0.691 + 10.0 * paddle.log10(
            (z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
        Gamma_r = Gamma_r[:, None, None]
        Gamma_r = Gamma_r.expand([nb, nch, l.shape[-1]])

        # find gating block indices above relative and absolute thresholds (end of eq. 7)
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        z_avg_gated[l <= Gamma_r] = 0
        masked = (l > Gamma_a) * (l > Gamma_r)
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
        # z_avg_gated = torch.nan_to_num(z_avg_gated)
        z_avg_gated = paddle.where(
            paddle.isnan(z_avg_gated),
            paddle.zeros_like(z_avg_gated), z_avg_gated)
        z_avg_gated[z_avg_gated == float("inf")] = float(
            np.finfo(np.float32).max)
        z_avg_gated[z_avg_gated == -float("inf")] = float(
            np.finfo(np.float32).min)

        LUFS = -0.691 + 10.0 * paddle.log10(
            (G[None, :nch] * z_avg_gated).sum(1))
        return LUFS.astype("float32")

    @property
    def filter_class(self):
        return self._filter_class

    @filter_class.setter
    def filter_class(self, value):
        from pyloudnorm import Meter

        meter = Meter(self.rate)
        meter.filter_class = value
        self._filter_class = value
        self._filters = meter._filters
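
# Usage sketch (illustrative): Meter expects audio shaped (nb, nt, nch).
#
#     meter = Meter(44100)
#     audio = paddle.randn([1, 44100, 1])  # one second of noise
#     lufs = meter.integrated_loudness(audio)  # paddle.Tensor of shape (1,)
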
class LoudnessMixin:
    _loudness = None
    MIN_LOUDNESS = -70
    """Minimum loudness possible."""

    def loudness(self,
                 filter_class: str="K-weighting",
                 block_size: float=0.400,
                 **kwargs):
        """Calculates loudness using an implementation of ITU-R BS.1770-4.
        Allows control over gating block size and frequency weighting filters for
        additional control. Measures the integrated gated loudness of a signal.

        API is derived from PyLoudnorm, but this implementation is ported to
        PaddlePaddle and is tensorized across batches. When on GPU, an FIR
        approximation of the IIR filters is used to compute loudness for speed.

        Uses the weighting filters and block size defined by the meter;
        the integrated loudness is measured based upon the gating algorithm
        defined in the ITU-R BS.1770-4 specification.

        Parameters
        ----------
        filter_class : str, optional
            Class of weighting filter used.
            'K-weighting' (default), 'Fenton/Lee 1',
            'Fenton/Lee 2', 'Dash et al.',
            by default "K-weighting"
        block_size : float, optional
            Gating block size in seconds, by default 0.400
        kwargs : dict, optional
            Keyword arguments to :py:class:`audiotools.core.loudness.Meter`.

        Returns
        -------
        paddle.Tensor
            Loudness of audio data.
        """
        if self._loudness is not None:
            return self._loudness  # .to(self.device)
        original_length = self.signal_length
        if self.signal_duration < 0.5:
            pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
            self.zero_pad(0, pad_len)

        # create BS.1770 meter
        meter = Meter(
            self.sample_rate,
            filter_class=filter_class,
            block_size=block_size,
            **kwargs)
        # meter = meter.to(self.device)
        # measure loudness
        loudness = meter.integrated_loudness(
            self.audio_data.transpose([0, 2, 1]))
        self.truncate_samples(original_length)
        min_loudness = paddle.ones_like(loudness) * self.MIN_LOUDNESS
        self._loudness = paddle.maximum(loudness, min_loudness)

        return self._loudness  # .to(self.device)
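    # Usage sketch (illustrative): the result is cached in ``_loudness``, so
    # repeated calls are free. ``normalize`` is assumed from the wider
    # AudioSignal API, not shown in this diff.
    #
    #     lufs = signal.loudness()  # paddle.Tensor with one LUFS value per item
    #     signal = signal.normalize(-24)  # e.g. gain each item to -24 dB LUFS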
@ -1,5 +0,0 @@
soundfile
librosa
scipy
rich
flatten_dict