fix codestyle

pull/3900/head
drryanhuang 10 months ago
parent 9e7dca2bc5
commit 1726e2fdfc

@ -8,13 +8,12 @@ import typing
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
from typing import Optional
import librosa
import numpy as np import numpy as np
import soundfile
import paddle import paddle
import librosa import soundfile
from typing import Optional
import util import util
from resample import resample_frac from resample import resample_frac
@ -48,9 +47,8 @@ def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> paddle.Tensor:
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
n = paddle.arange(float(n_mels)) n = paddle.arange(float(n_mels))
k = paddle.arange(float(n_mfcc)).unsqueeze([1]) k = paddle.arange(float(n_mfcc)).unsqueeze([1])
dct = paddle.cos( dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
math.pi / float(n_mels) * (n + 0.5) * k k) # size (n_mfcc, n_mels)
) # size (n_mfcc, n_mels)
if norm is None: if norm is None:
dct *= 2.0 dct *= 2.0
@ -68,8 +66,7 @@ STFTParams = namedtuple(
"window_type", "window_type",
"match_stride", "match_stride",
"padding_type", "padding_type",
], ], )
)
""" """
STFTParams object is a container that holds STFT parameters - window_length, STFTParams object is a container that holds STFT parameters - window_length,
hop_length, and window_type. Not all parameters need to be specified. Ones that hop_length, and window_type. Not all parameters need to be specified. Ones that
@ -92,14 +89,14 @@ STFTParams.__new__.__defaults__ = (None, None, None, None, None)
class AudioSignal( class AudioSignal(
# EffectMixin, # EffectMixin,
# LoudnessMixin, # LoudnessMixin,
# PlayMixin, # PlayMixin,
# ImpulseResponseMixin, # ImpulseResponseMixin,
# DSPMixin, # DSPMixin,
# DisplayMixin, # DisplayMixin,
# FFMPEGMixin, # FFMPEGMixin,
# WhisperMixin, # WhisperMixin,
): ):
"""This is the core object of this library. Audio is always """This is the core object of this library. Audio is always
loaded into an AudioSignal, which then enables all the features loaded into an AudioSignal, which then enables all the features
@ -161,14 +158,14 @@ class AudioSignal(
""" """
def __init__( def __init__(
self, self,
audio_path_or_array: typing.Union[paddle.Tensor, str, Path, np.ndarray], audio_path_or_array: typing.Union[paddle.Tensor, str, Path,
sample_rate: int = None, np.ndarray],
stft_params: STFTParams = None, sample_rate: int=None,
offset: float = 0, stft_params: STFTParams=None,
duration: float = None, offset: float=0,
device: str = None, duration: float=None,
): device: str=None, ):
# ✅ # ✅
audio_path = None audio_path = None
audio_array = None audio_array = None
@ -182,10 +179,8 @@ class AudioSignal(
elif paddle.is_tensor(audio_path_or_array): elif paddle.is_tensor(audio_path_or_array):
audio_array = audio_path_or_array audio_array = audio_path_or_array
else: else:
raise ValueError( raise ValueError("audio_path_or_array must be either a Path, "
"audio_path_or_array must be either a Path, " "string, numpy array, or paddle Tensor!")
"string, numpy array, or paddle Tensor!"
)
self.path_to_file = None self.path_to_file = None
@ -194,8 +189,7 @@ class AudioSignal(
self.stft_data = None self.stft_data = None
if audio_path is not None: if audio_path is not None:
self.load_from_file( self.load_from_file(
audio_path, offset=offset, duration=duration, device=device audio_path, offset=offset, duration=duration, device=device)
)
elif audio_array is not None: elif audio_array is not None:
assert sample_rate is not None, "Must set sample rate!" assert sample_rate is not None, "Must set sample rate!"
self.load_from_array(audio_array, sample_rate, device=device) self.load_from_array(audio_array, sample_rate, device=device)
@ -210,8 +204,7 @@ class AudioSignal(
@property @property
def path_to_input_file( def path_to_input_file(
self, self, ):
):
""" """
Path to input file, if it exists. Path to input file, if it exists.
Alias to ``path_to_file`` for backwards compatibility Alias to ``path_to_file`` for backwards compatibility
@ -220,13 +213,12 @@ class AudioSignal(
@classmethod @classmethod
def excerpt( def excerpt(
cls, cls,
audio_path: typing.Union[str, Path], audio_path: typing.Union[str, Path],
offset: float = None, offset: float=None,
duration: float = None, duration: float=None,
state: typing.Union[np.random.RandomState, int] = None, state: typing.Union[np.random.RandomState, int]=None,
**kwargs, **kwargs, ):
):
"""✅Randomly draw an excerpt of ``duration`` seconds from an """✅Randomly draw an excerpt of ``duration`` seconds from an
audio file specified at ``audio_path``, between ``offset`` seconds audio file specified at ``audio_path``, between ``offset`` seconds
and end of file. ``state`` can be used to seed the random draw. and end of file. ``state`` can be used to seed the random draw.
@ -268,13 +260,12 @@ class AudioSignal(
@classmethod @classmethod
def salient_excerpt( def salient_excerpt(
cls, cls,
audio_path: typing.Union[str, Path], audio_path: typing.Union[str, Path],
loudness_cutoff: float = None, loudness_cutoff: float=None,
num_tries: int = 8, num_tries: int=8,
state: typing.Union[np.random.RandomState, int] = None, state: typing.Union[np.random.RandomState, int]=None,
**kwargs, **kwargs, ):
):
"""❌Similar to AudioSignal.excerpt, except it extracts excerpts only """❌Similar to AudioSignal.excerpt, except it extracts excerpts only
if they are above a specified loudness threshold, which is computed via if they are above a specified loudness threshold, which is computed via
a fast LUFS routine. a fast LUFS routine.
@ -329,13 +320,12 @@ class AudioSignal(
@classmethod @classmethod
def zeros( def zeros(
cls, cls,
duration: float, duration: float,
sample_rate: int, sample_rate: int,
num_channels: int = 1, num_channels: int=1,
batch_size: int = 1, batch_size: int=1,
**kwargs, **kwargs, ):
):
"""✅Helper function create an AudioSignal of all zeros. """✅Helper function create an AudioSignal of all zeros.
Parameters Parameters
@ -364,19 +354,17 @@ class AudioSignal(
return cls( return cls(
paddle.zeros([batch_size, num_channels, n_samples]), paddle.zeros([batch_size, num_channels, n_samples]),
sample_rate, sample_rate,
**kwargs, **kwargs, )
)
@classmethod @classmethod
def wave( def wave(
cls, cls,
frequency: float, frequency: float,
duration: float, duration: float,
sample_rate: int, sample_rate: int,
num_channels: int = 1, num_channels: int=1,
shape: str = "sine", shape: str="sine",
**kwargs, **kwargs, ):
):
""" """
Generate a waveform of a given frequency and shape. Generate a waveform of a given frequency and shape.
@ -423,13 +411,12 @@ class AudioSignal(
@classmethod @classmethod
def batch( def batch(
cls, cls,
audio_signals: list, audio_signals: list,
pad_signals: bool = False, pad_signals: bool=False,
truncate_signals: bool = False, truncate_signals: bool=False,
resample: bool = False, resample: bool=False,
dim: int = 0, dim: int=0, ):
):
"""✅Creates a batched AudioSignal from a list of AudioSignals. """✅Creates a batched AudioSignal from a list of AudioSignals.
Parameters Parameters
@ -500,29 +487,25 @@ class AudioSignal(
raise RuntimeError( raise RuntimeError(
f"Not all signals had the same length! Got {signal_lengths}. " f"Not all signals had the same length! Got {signal_lengths}. "
f"All signals must be the same length, or pad_signals/truncate_signals " f"All signals must be the same length, or pad_signals/truncate_signals "
f"must be True. " f"must be True. ")
)
# Concatenate along the specified dimension (default 0) # Concatenate along the specified dimension (default 0)
audio_data = paddle.concat( audio_data = paddle.concat(
[x.audio_data for x in audio_signals], axis=dim [x.audio_data for x in audio_signals], axis=dim)
)
audio_paths = [x.path_to_file for x in audio_signals] audio_paths = [x.path_to_file for x in audio_signals]
batched_signal = cls( batched_signal = cls(
audio_data, audio_data,
sample_rate=audio_signals[0].sample_rate, sample_rate=audio_signals[0].sample_rate, )
)
batched_signal.path_to_file = audio_paths batched_signal.path_to_file = audio_paths
return batched_signal return batched_signal
# I/O # I/O
def load_from_file( def load_from_file(
self, self,
audio_path: typing.Union[str, Path], audio_path: typing.Union[str, Path],
offset: float, offset: float,
duration: float, duration: float,
device: str = "cpu", device: str="cpu", ):
):
"""✅Loads data from file. Used internally when AudioSignal """✅Loads data from file. Used internally when AudioSignal
is instantiated with a path to a file. is instantiated with a path to a file.
@ -548,8 +531,7 @@ class AudioSignal(
offset=offset, offset=offset,
duration=duration, duration=duration,
sr=None, sr=None,
mono=False, mono=False, )
)
data = util.ensure_tensor(data) data = util.ensure_tensor(data)
if data.shape[-1] == 0: if data.shape[-1] == 0:
raise RuntimeError( raise RuntimeError(
@ -569,11 +551,10 @@ class AudioSignal(
return self.to(device) return self.to(device)
def load_from_array( def load_from_array(
self, self,
audio_array: typing.Union[paddle.Tensor, np.ndarray], audio_array: typing.Union[paddle.Tensor, np.ndarray],
sample_rate: int, sample_rate: int,
device: str = "cpu", device: str="cpu", ):
):
"""✅Loads data from array, reshaping it to be exactly 3 """✅Loads data from array, reshaping it to be exactly 3
dimensions. Used internally when AudioSignal is called dimensions. Used internally when AudioSignal is called
with a tensor or an array. with a tensor or an array.
@ -646,8 +627,7 @@ class AudioSignal(
if self.audio_data[0].abs().max() > 1: if self.audio_data[0].abs().max() > 1:
warnings.warn("Audio amplitude > 1 clipped when saving") warnings.warn("Audio amplitude > 1 clipped when saving")
soundfile.write( soundfile.write(
str(audio_path), self.audio_data[0].numpy().T, self.sample_rate str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
)
self.path_to_file = audio_path self.path_to_file = audio_path
return self return self
@ -689,8 +669,7 @@ class AudioSignal(
clone = type(self)( clone = type(self)(
self.audio_data.clone(), self.audio_data.clone(),
self.sample_rate, self.sample_rate,
stft_params=self.stft_params, stft_params=self.stft_params, )
)
if self.stft_data is not None: if self.stft_data is not None:
clone.stft_data = self.stft_data.clone() clone.stft_data = self.stft_data.clone()
if self._loudness is not None: if self._loudness is not None:
@ -777,9 +756,8 @@ class AudioSignal(
""" """
if sample_rate == self.sample_rate: if sample_rate == self.sample_rate:
return self return self
self.audio_data = resample_frac( self.audio_data = resample_frac(self.audio_data, self.sample_rate,
self.audio_data, self.sample_rate, sample_rate sample_rate)
)
self.sample_rate = sample_rate self.sample_rate = sample_rate
return self return self
@ -861,11 +839,10 @@ class AudioSignal(
AudioSignal with padding applied. AudioSignal with padding applied.
""" """
self.audio_data = paddle.nn.functional.pad( self.audio_data = paddle.nn.functional.pad(
self.audio_data, (before, after), data_format="NCL" self.audio_data, (before, after), data_format="NCL")
)
return self return self
def zero_pad_to(self, length: int, mode: str = "after"): def zero_pad_to(self, length: int, mode: str="after"):
"""✅Pad with zeros to a specified length, either before or after """✅Pad with zeros to a specified length, either before or after
the audio data. the audio data.
@ -990,10 +967,8 @@ class AudioSignal(
def stft_data(self, data: typing.Union[paddle.Tensor, np.ndarray]): def stft_data(self, data: typing.Union[paddle.Tensor, np.ndarray]):
if data is not None: if data is not None:
assert paddle.is_tensor(data) and paddle.is_complex(data) assert paddle.is_tensor(data) and paddle.is_complex(data)
if ( if (self.stft_data is not None and
self.stft_data is not None self.stft_data.shape != data.shape):
and self.stft_data.shape != data.shape
):
warnings.warn("stft_data changed shape") warnings.warn("stft_data changed shape")
self._stft_data = data self._stft_data = data
return return
@ -1062,7 +1037,7 @@ class AudioSignal(
# STFT # STFT
@staticmethod @staticmethod
@functools.lru_cache(None) @functools.lru_cache(None)
def get_window(window_type: str, window_length: int, device: str = None): def get_window(window_type: str, window_length: int, device: str=None):
"""✅Wrapper around scipy.signal.get_window so one can also get the """✅Wrapper around scipy.signal.get_window so one can also get the
popular sqrt-hann window. This function caches for efficiency popular sqrt-hann window. This function caches for efficiency
using functools.lru\_cache. using functools.lru\_cache.
@ -1118,7 +1093,7 @@ class AudioSignal(
@stft_params.setter @stft_params.setter
def stft_params(self, value: STFTParams): def stft_params(self, value: STFTParams):
# ✅ # ✅
default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate)))) default_win_len = int(2**(np.ceil(np.log2(0.032 * self.sample_rate))))
default_hop_len = default_win_len // 4 default_hop_len = default_win_len // 4
default_win_type = "hann" default_win_type = "hann"
default_match_stride = False default_match_stride = False
@ -1129,8 +1104,7 @@ class AudioSignal(
hop_length=default_hop_len, hop_length=default_hop_len,
window_type=default_win_type, window_type=default_win_type,
match_stride=default_match_stride, match_stride=default_match_stride,
padding_type=default_padding_type, padding_type=default_padding_type, )._asdict()
)._asdict()
value = value._asdict() if value else default_stft_params value = value._asdict() if value else default_stft_params
@ -1141,9 +1115,10 @@ class AudioSignal(
self._stft_params = STFTParams(**value) self._stft_params = STFTParams(**value)
self.stft_data = None self.stft_data = None
def compute_stft_padding( def compute_stft_padding(self,
self, window_length: int, hop_length: int, match_stride: bool window_length: int,
): hop_length: int,
match_stride: bool):
"""✅Compute how the STFT should be padded, based on match\_stride. """✅Compute how the STFT should be padded, based on match\_stride.
Parameters Parameters
@ -1164,9 +1139,8 @@ class AudioSignal(
length = self.signal_length length = self.signal_length
if match_stride: if match_stride:
assert ( assert (hop_length == window_length //
hop_length == window_length // 4 4), "For match_stride, hop must equal n_fft // 4"
), "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(length / hop_length) * hop_length - length right_pad = math.ceil(length / hop_length) * hop_length - length
pad = (window_length - hop_length) // 2 pad = (window_length - hop_length) // 2
else: else:
@ -1176,13 +1150,12 @@ class AudioSignal(
return right_pad, pad return right_pad, pad
def stft( def stft(
self, self,
window_length: int = None, window_length: int=None,
hop_length: int = None, hop_length: int=None,
window_type: str = None, window_type: str=None,
match_stride: bool = None, match_stride: bool=None,
padding_type: str = None, padding_type: str=None, ):
):
"""✅Computes the short-time Fourier transform of the audio data, """✅Computes the short-time Fourier transform of the audio data,
with specified STFT parameters. with specified STFT parameters.
@ -1219,55 +1192,38 @@ class AudioSignal(
>>> signal.stft() >>> signal.stft()
""" """
window_length = ( window_length = (self.stft_params.window_length
self.stft_params.window_length if window_length is None else int(window_length))
if window_length is None hop_length = (self.stft_params.hop_length
else int(window_length) if hop_length is None else int(hop_length))
) window_type = (self.stft_params.window_type
hop_length = ( if window_type is None else window_type)
self.stft_params.hop_length match_stride = (self.stft_params.match_stride
if hop_length is None if match_stride is None else match_stride)
else int(hop_length) padding_type = (self.stft_params.padding_type
) if padding_type is None else padding_type)
window_type = (
self.stft_params.window_type if window_type is None else window_type
)
match_stride = (
self.stft_params.match_stride
if match_stride is None
else match_stride
)
padding_type = (
self.stft_params.padding_type
if padding_type is None
else padding_type
)
window = self.get_window(window_type, window_length) window = self.get_window(window_type, window_length)
# window = window.to(self.audio_data.device) # window = window.to(self.audio_data.device)
audio_data = self.audio_data audio_data = self.audio_data
right_pad, pad = self.compute_stft_padding( right_pad, pad = self.compute_stft_padding(window_length, hop_length,
window_length, hop_length, match_stride match_stride)
)
audio_data = paddle.nn.functional.pad( audio_data = paddle.nn.functional.pad(
x=audio_data, x=audio_data,
pad=[pad, pad + right_pad], pad=[pad, pad + right_pad],
mode="reflect", mode="reflect",
data_format="NCL", data_format="NCL", )
)
stft_data = paddle.signal.stft( stft_data = paddle.signal.stft(
audio_data.reshape([-1, audio_data.shape[-1]]), audio_data.reshape([-1, audio_data.shape[-1]]),
n_fft=window_length, n_fft=window_length,
hop_length=hop_length, hop_length=hop_length,
window=window, window=window,
# return_complex=True, # return_complex=True,
center=True, center=True, )
)
_, nf, nt = stft_data.shape _, nf, nt = stft_data.shape
stft_data = stft_data.reshape( stft_data = stft_data.reshape(
[self.batch_size, self.num_channels, nf, nt] [self.batch_size, self.num_channels, nf, nt])
)
if match_stride: if match_stride:
# Drop first two and last two frames, which are added # Drop first two and last two frames, which are added
@ -1278,13 +1234,12 @@ class AudioSignal(
return stft_data return stft_data
def istft( def istft(
self, self,
window_length: int = None, window_length: int=None,
hop_length: int = None, hop_length: int=None,
window_type: str = None, window_type: str=None,
match_stride: bool = None, match_stride: bool=None,
length: int = None, length: int=None, ):
):
"""✅Computes inverse STFT and sets it to audio\_data. """✅Computes inverse STFT and sets it to audio\_data.
Parameters Parameters
@ -1314,34 +1269,22 @@ class AudioSignal(
if self.stft_data is None: if self.stft_data is None:
raise RuntimeError("Cannot do inverse STFT without self.stft_data!") raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
window_length = ( window_length = (self.stft_params.window_length
self.stft_params.window_length if window_length is None else int(window_length))
if window_length is None hop_length = (self.stft_params.hop_length
else int(window_length) if hop_length is None else int(hop_length))
) window_type = (self.stft_params.window_type
hop_length = ( if window_type is None else window_type)
self.stft_params.hop_length match_stride = (self.stft_params.match_stride
if hop_length is None if match_stride is None else match_stride)
else int(hop_length)
) window = self.get_window(window_type, window_length,
window_type = ( self.stft_data.place)
self.stft_params.window_type if window_type is None else window_type
)
match_stride = (
self.stft_params.match_stride
if match_stride is None
else match_stride
)
window = self.get_window(
window_type, window_length, self.stft_data.place
)
nb, nch, nf, nt = self.stft_data.shape nb, nch, nf, nt = self.stft_data.shape
stft_data = self.stft_data.reshape([nb * nch, nf, nt]) stft_data = self.stft_data.reshape([nb * nch, nf, nt])
right_pad, pad = self.compute_stft_padding( right_pad, pad = self.compute_stft_padding(window_length, hop_length,
window_length, hop_length, match_stride match_stride)
)
if length is None: if length is None:
length = self.original_signal_length length = self.original_signal_length
@ -1351,8 +1294,7 @@ class AudioSignal(
# Zero-pad the STFT on either side, putting back the frames that were # Zero-pad the STFT on either side, putting back the frames that were
# dropped in stft(). # dropped in stft().
stft_data = paddle.nn.functional.pad( stft_data = paddle.nn.functional.pad(
stft_data, pad=(2, 2), data_format="NCL" stft_data, pad=(2, 2), data_format="NCL")
)
audio_data = paddle.signal.istft( audio_data = paddle.signal.istft(
stft_data, stft_data,
@ -1360,20 +1302,21 @@ class AudioSignal(
hop_length=hop_length, hop_length=hop_length,
window=window, window=window,
length=length, length=length,
center=True, center=True, )
)
audio_data = audio_data.reshape([nb, nch, -1]) audio_data = audio_data.reshape([nb, nch, -1])
if match_stride: if match_stride:
audio_data = audio_data[..., pad : -(pad + right_pad)] audio_data = audio_data[..., pad:-(pad + right_pad)]
self.audio_data = audio_data self.audio_data = audio_data
return self return self
@staticmethod @staticmethod
@functools.lru_cache(None) @functools.lru_cache(None)
def get_mel_filters( def get_mel_filters(sr: int,
sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None n_fft: int,
): n_mels: int,
fmin: float=0.0,
fmax: float=None):
"""✅Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. """✅Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
Parameters Parameters
@ -1401,16 +1344,14 @@ class AudioSignal(
n_fft=n_fft, n_fft=n_fft,
n_mels=n_mels, n_mels=n_mels,
fmin=fmin, fmin=fmin,
fmax=fmax, fmax=fmax, )
)
def mel_spectrogram( def mel_spectrogram(
self, self,
n_mels: int = 80, n_mels: int=80,
mel_fmin: float = 0.0, mel_fmin: float=0.0,
mel_fmax: float = None, mel_fmax: float=None,
**kwargs, **kwargs, ):
):
"""✅Computes a Mel spectrogram. """✅Computes a Mel spectrogram.
Parameters Parameters
@ -1438,8 +1379,7 @@ class AudioSignal(
n_fft=2 * (nf - 1), n_fft=2 * (nf - 1),
n_mels=n_mels, n_mels=n_mels,
fmin=mel_fmin, fmin=mel_fmin,
fmax=mel_fmax, fmax=mel_fmax, )
)
mel_basis = paddle.to_tensor(mel_basis) mel_basis = paddle.to_tensor(mel_basis)
mel_spectrogram = magnitude.transpose([0, 1, 3, 2]) @ mel_basis.T mel_spectrogram = magnitude.transpose([0, 1, 3, 2]) @ mel_basis.T
@ -1448,9 +1388,7 @@ class AudioSignal(
@staticmethod @staticmethod
@functools.lru_cache(None) @functools.lru_cache(None)
def get_dct( def get_dct(n_mfcc: int, n_mels: int, norm: str="ortho", device: str=None):
n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None
):
"""✅Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), """✅Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
it can be normalized depending on norm. For more information about dct: it can be normalized depending on norm. For more information about dct:
http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
@ -1476,12 +1414,11 @@ class AudioSignal(
return create_dct(n_mfcc, n_mels, norm) return create_dct(n_mfcc, n_mels, norm)
def mfcc( def mfcc(
self, self,
n_mfcc: int = 40, n_mfcc: int=40,
n_mels: int = 80, n_mels: int=80,
log_offset: float = 1e-6, log_offset: float=1e-6,
**kwargs, **kwargs, ):
):
"""✅Computes mel-frequency cepstral coefficients (MFCCs). """✅Computes mel-frequency cepstral coefficients (MFCCs).
Parameters Parameters
@ -1538,9 +1475,10 @@ class AudioSignal(
self.stft_data = value * paddle.exp(1j * self.phase) self.stft_data = value * paddle.exp(1j * self.phase)
return return
def log_magnitude( def log_magnitude(self,
self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0 ref_value: float=1.0,
): amin: float=1e-5,
top_db: float=80.0):
"""✅Computes the log-magnitude of the spectrogram. """✅Computes the log-magnitude of the spectrogram.
Parameters Parameters
@ -1637,22 +1575,25 @@ class AudioSignal(
# Representation # Representation
def _info(self): def _info(self):
# ✅ # ✅
dur = ( dur = (f"{self.signal_duration:0.3f}"
f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]")
if self.signal_duration
else "[unknown]"
)
info = { info = {
"duration": f"{dur} seconds", "duration":
"batch_size": self.batch_size, f"{dur} seconds",
"path": self.path_to_file if self.path_to_file else "path unknown", "batch_size":
"sample_rate": self.sample_rate, self.batch_size,
"num_channels": ( "path":
self.num_channels if self.num_channels else "[unknown]" self.path_to_file if self.path_to_file else "path unknown",
), "sample_rate":
"audio_data.shape": self.audio_data.shape, self.sample_rate,
"stft_params": self.stft_params, "num_channels": (self.num_channels
"device": self.device, if self.num_channels else "[unknown]"),
"audio_data.shape":
self.audio_data.shape,
"stft_params":
self.stft_params,
"device":
self.device,
} }
return info return info
@ -1728,25 +1669,21 @@ class AudioSignal(
stft_data = self.stft_data stft_data = self.stft_data
elif isinstance(key, (bool, int, list, slice, tuple)) or ( elif isinstance(key, (bool, int, list, slice, tuple)) or (
paddle.is_tensor(key) and key.ndim <= 1 paddle.is_tensor(key) and key.ndim <= 1):
):
# Indexing only on the batch dimension. # Indexing only on the batch dimension.
# Then let's copy over relevant stuff. # Then let's copy over relevant stuff.
# Future work: make this work for time-indexing # Future work: make this work for time-indexing
# as well, using the hop length. # as well, using the hop length.
audio_data = self.audio_data[key] audio_data = self.audio_data[key]
_loudness = ( _loudness = (self._loudness[key]
self._loudness[key] if self._loudness is not None else None if self._loudness is not None else None)
) stft_data = (self.stft_data[key]
stft_data = ( if self.stft_data is not None else None)
self.stft_data[key] if self.stft_data is not None else None
)
sources = None sources = None
copy = type(self)( copy = type(self)(
audio_data, self.sample_rate, stft_params=self.stft_params audio_data, self.sample_rate, stft_params=self.stft_params)
)
copy._loudness = _loudness copy._loudness = _loudness
copy._stft_data = stft_data copy._stft_data = stft_data
copy.sources = sources copy.sources = sources
@ -1766,8 +1703,7 @@ class AudioSignal(
return return
elif isinstance(key, (bool, int, list, slice, tuple)) or ( elif isinstance(key, (bool, int, list, slice, tuple)) or (
paddle.is_tensor(key) and key.ndim <= 1 paddle.is_tensor(key) and key.ndim <= 1):
):
if self.audio_data is not None and value.audio_data is not None: if self.audio_data is not None and value.audio_data is not None:
self.audio_data[key] = value.audio_data self.audio_data[key] = value.audio_data
if self._loudness is not None and value._loudness is not None: if self._loudness is not None and value._loudness is not None:

@ -1,13 +1,13 @@
import inspect import inspect
from typing import Optional, Sequence import math
from typing import Optional
from typing import Sequence
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
import math
def simple_repr( def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}):
obj, attrs: Optional[Sequence[str]] = None, overrides: dict = {}
):
""" """
Return a simple representation string for `obj`. Return a simple representation string for `obj`.
If `attrs` is not None, it should be a list of attributes to include. If `attrs` is not None, it should be a list of attributes to include.
@ -45,8 +45,7 @@ def sinc(x: paddle.Tensor):
return paddle.where( return paddle.where(
x == 0, x == 0,
paddle.to_tensor(1.0, dtype=x.dtype, place=x.place), paddle.to_tensor(1.0, dtype=x.dtype, place=x.place),
paddle.sin(x) / x, paddle.sin(x) / x, )
)
class ResampleFrac(paddle.nn.Layer): class ResampleFrac(paddle.nn.Layer):
@ -54,9 +53,11 @@ class ResampleFrac(paddle.nn.Layer):
Resampling from the sample rate `old_sr` to `new_sr`. Resampling from the sample rate `old_sr` to `new_sr`.
""" """
def __init__( def __init__(self,
self, old_sr: int, new_sr: int, zeros: int = 24, rolloff: float = 0.945 old_sr: int,
): new_sr: int,
zeros: int=24,
rolloff: float=0.945):
""" """
Args: Args:
old_sr (int): sample rate of the input signal x. old_sr (int): sample rate of the input signal x.
@ -129,13 +130,12 @@ class ResampleFrac(paddle.nn.Layer):
# There is probably a way to evaluate those filters more efficiently, but this is kept for # There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work. # future work.
idx = paddle.arange( idx = paddle.arange(
-self._width, self._width + self.old_sr, dtype="float32" -self._width, self._width + self.old_sr, dtype="float32")
)
for i in range(self.new_sr): for i in range(self.new_sr):
t = (-i / self.new_sr + idx / self.old_sr) * sr t = (-i / self.new_sr + idx / self.old_sr) * sr
t = paddle.clip(t, -self.zeros, self.zeros) t = paddle.clip(t, -self.zeros, self.zeros)
t *= math.pi t *= math.pi
window = paddle.cos(t / self.zeros / 2) ** 2 window = paddle.cos(t / self.zeros / 2)**2
kernel = sinc(t) * window kernel = sinc(t) * window
# Renormalize kernel to ensure a constant signal is preserved. # Renormalize kernel to ensure a constant signal is preserved.
kernel = kernel / kernel.sum() kernel = kernel / kernel.sum()
@ -144,16 +144,14 @@ class ResampleFrac(paddle.nn.Layer):
_kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1]) _kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1])
self.kernel = self.create_parameter( self.kernel = self.create_parameter(
shape=_kernel.shape, shape=_kernel.shape,
dtype=_kernel.dtype, dtype=_kernel.dtype, )
)
self.kernel.set_value(_kernel) self.kernel.set_value(_kernel)
def forward( def forward(
self, self,
x: paddle.Tensor, x: paddle.Tensor,
output_length: Optional[int] = None, output_length: Optional[int]=None,
full: bool = False, full: bool=False, ):
):
""" """
Resample x. Resample x.
Args: Args:
@ -176,35 +174,29 @@ class ResampleFrac(paddle.nn.Layer):
x.unsqueeze(1), x.unsqueeze(1),
[self._width, self._width + self.old_sr], [self._width, self._width + self.old_sr],
mode="replicate", mode="replicate",
data_format="NCL", data_format="NCL", )
)
ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL") ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL")
y = ys.transpose([0, 2, 1]).reshape(list(shape[:-1]) + [-1]) y = ys.transpose([0, 2, 1]).reshape(list(shape[:-1]) + [-1])
float_output_length = paddle.to_tensor( float_output_length = paddle.to_tensor(
self.new_sr * length / self.old_sr, dtype="float32" self.new_sr * length / self.old_sr, dtype="float32")
)
max_output_length = paddle.ceil(float_output_length).astype("int64") max_output_length = paddle.ceil(float_output_length).astype("int64")
default_output_length = paddle.floor(float_output_length).astype( default_output_length = paddle.floor(float_output_length).astype(
"int64" "int64")
)
if output_length is None: if output_length is None:
applied_output_length = ( applied_output_length = (max_output_length
max_output_length if full else default_output_length if full else default_output_length)
)
elif output_length < 0 or output_length > max_output_length: elif output_length < 0 or output_length > max_output_length:
raise ValueError( raise ValueError(
f"output_length must be between 0 and {max_output_length.numpy()}" f"output_length must be between 0 and {max_output_length.numpy()}"
) )
else: else:
applied_output_length = paddle.to_tensor( applied_output_length = paddle.to_tensor(
output_length, dtype="int64" output_length, dtype="int64")
)
if full: if full:
raise ValueError( raise ValueError(
"You cannot pass both full=True and output_length" "You cannot pass both full=True and output_length")
)
return y[..., :applied_output_length] return y[..., :applied_output_length]
def __repr__(self): def __repr__(self):
@ -212,14 +204,13 @@ class ResampleFrac(paddle.nn.Layer):
def resample_frac( def resample_frac(
x: paddle.Tensor, x: paddle.Tensor,
old_sr: int, old_sr: int,
new_sr: int, new_sr: int,
zeros: int = 24, zeros: int=24,
rolloff: float = 0.945, rolloff: float=0.945,
output_length: Optional[int] = None, output_length: Optional[int]=None,
full: bool = False, full: bool=False, ):
):
""" """
Functional version of `ResampleFrac`, refer to its documentation for more information. Functional version of `ResampleFrac`, refer to its documentation for more information.
@ -228,9 +219,7 @@ def resample_frac(
resampling kernel will be recomputed everytime. For best performance, you should use resampling kernel will be recomputed everytime. For best performance, you should use
and cache an instance of `ResampleFrac`. and cache an instance of `ResampleFrac`.
""" """
return ResampleFrac(old_sr, new_sr, zeros, rolloff)( return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full)
x, output_length, full
)
if __name__ == "__main__": if __name__ == "__main__":

@ -5,14 +5,16 @@ import numbers
import os import os
import random import random
import typing import typing
import soundfile
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, List from typing import Dict
from typing import List
from typing import Optional
import numpy as np import numpy as np
import paddle import paddle
import soundfile
from flatten_dict import flatten from flatten_dict import flatten
from flatten_dict import unflatten from flatten_dict import unflatten
@ -43,10 +45,9 @@ def info(audio_path: str):
def ensure_tensor( def ensure_tensor(
x: typing.Union[np.ndarray, paddle.Tensor, float, int], x: typing.Union[np.ndarray, paddle.Tensor, float, int],
ndim: int = None, ndim: int=None,
batch_size: int = None, batch_size: int=None, ):
):
"""✅Ensures that the input ``x`` is a tensor of specified """✅Ensures that the input ``x`` is a tensor of specified
dimensions and batch size. dimensions and batch size.
@ -146,10 +147,8 @@ def random_state(seed: typing.Union[int, np.random.RandomState]):
elif isinstance(seed, np.random.RandomState): elif isinstance(seed, np.random.RandomState):
return seed return seed
else: else:
raise ValueError( raise ValueError("%r cannot be used to seed a numpy.random.RandomState"
"%r cannot be used to seed a numpy.random.RandomState" " instance" % seed)
" instance" % seed
)
def seed(random_seed, set_cudnn=False): def seed(random_seed, set_cudnn=False):
@ -214,7 +213,7 @@ def _close_temp_files(tmpfiles: list):
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS):
"""Finds all audio files in a directory recursively. """Finds all audio files in a directory recursively.
Returns a list. Returns a list.
@ -244,11 +243,10 @@ def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
def read_sources( def read_sources(
sources: List[str], sources: List[str],
remove_empty: bool = True, remove_empty: bool=True,
relative_path: str = "", relative_path: str="",
ext: List[str] = AUDIO_EXTENSIONS, ext: List[str]=AUDIO_EXTENSIONS, ):
):
"""Reads audio sources that can either be folders """Reads audio sources that can either be folders
full of audio files, or CSV files that contain paths full of audio files, or CSV files that contain paths
to audio files. CSV files that adhere to the expected to audio files. CSV files that adhere to the expected
@ -291,9 +289,9 @@ def read_sources(
return files return files
def choose_from_list_of_lists( def choose_from_list_of_lists(state: np.random.RandomState,
state: np.random.RandomState, list_of_lists: list, p: float = None list_of_lists: list,
): p: float=None):
"""Choose a single item from a list of lists. """Choose a single item from a list of lists.
Parameters Parameters
@ -335,9 +333,8 @@ def chdir(newdir: typing.Union[Path, str]):
os.chdir(curdir) os.chdir(curdir)
def prepare_batch( def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor],
batch: typing.Union[dict, list, paddle.Tensor], device: str = "cpu" device: str="cpu"):
):
"""Moves items in a batch (typically generated by a DataLoader as a list """Moves items in a batch (typically generated by a DataLoader as a list
or a dict) to the specified device. This works even if dictionaries or a dict) to the specified device. This works even if dictionaries
are nested. are nested.
@ -374,7 +371,7 @@ def prepare_batch(
return batch return batch
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None): def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None):
"""Samples from a distribution defined by a tuple. The first """Samples from a distribution defined by a tuple. The first
item in the tuple is the distribution type, and the rest of the item in the tuple is the distribution type, and the rest of the
items are arguments to that distribution. The distribution function items are arguments to that distribution. The distribution function
@ -417,7 +414,7 @@ def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
return dist_fn(*dist_tuple[1:]) return dist_fn(*dist_tuple[1:])
def collate(list_of_dicts: list, n_splits: int = None): def collate(list_of_dicts: list, n_splits: int=None):
"""Collates a list of dictionaries (e.g. as returned by a """Collates a list of dictionaries (e.g. as returned by a
dataloader) into a dictionary with batched values. This routine dataloader) into a dictionary with batched values. This routine
uses the default paddle collate function for everything uses the default paddle collate function for everything
@ -454,9 +451,10 @@ def collate(list_of_dicts: list, n_splits: int = None):
for i in range(0, list_len, n_items): for i in range(0, list_len, n_items):
# Flatten the dictionaries to avoid recursion. # Flatten the dictionaries to avoid recursion.
list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]] list_of_dicts_ = [flatten(d) for d in list_of_dicts[i:i + n_items]]
dict_of_lists = { dict_of_lists = {
k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0] k: [dic[k] for dic in list_of_dicts_]
for k in list_of_dicts_[0]
} }
batch = {} batch = {}
@ -467,8 +465,7 @@ def collate(list_of_dicts: list, n_splits: int = None):
else: else:
# Borrow the default collate fn from paddle. # Borrow the default collate fn from paddle.
batch[k] = paddle.utils.data._utils.collate.default_collate( batch[k] = paddle.utils.data._utils.collate.default_collate(
v v)
)
batches.append(unflatten(batch)) batches.append(unflatten(batch))
batches = batches[0] if not return_list else batches batches = batches[0] if not return_list else batches
@ -480,13 +477,12 @@ DEFAULT_FIG_SIZE = (9, 3)
def format_figure( def format_figure(
fig_size: tuple = None, fig_size: tuple=None,
title: str = None, title: str=None,
fig=None, fig=None,
format_axes: bool = True, format_axes: bool=True,
format: bool = True, format: bool=True,
font_color: str = "white", font_color: str="white", ):
):
"""Prettifies the spectrogram and waveform plots. A title """Prettifies the spectrogram and waveform plots. A title
can be inset into the top right corner, and the axes can be can be inset into the top right corner, and the axes can be
inset into the figure, allowing the data to take up the entire inset into the figure, allowing the data to take up the entire
@ -546,8 +542,7 @@ def format_figure(
va="top", va="top",
color=font_color, color=font_color,
fontsize=12 * font_scale, fontsize=12 * font_scale,
alpha=0.75, alpha=0.75, )
)
ticks = ax.get_xticks()[2:] ticks = ax.get_xticks()[2:]
for t in ticks[:-1]: for t in ticks[:-1]:
@ -561,8 +556,7 @@ def format_figure(
va="bottom", va="bottom",
color=font_color, color=font_color,
fontsize=12 * font_scale, fontsize=12 * font_scale,
alpha=0.75, alpha=0.75, )
)
ax.margins(0, 0) ax.margins(0, 0)
ax.set_axis_off() ax.set_axis_off()
@ -570,8 +564,7 @@ def format_figure(
ax.yaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator())
plt.subplots_adjust( plt.subplots_adjust(
top=1, bottom=0, right=1, left=0, hspace=0, wspace=0 top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
)
if title is not None: if title is not None:
t = axs[0].annotate( t = axs[0].annotate(
@ -583,20 +576,18 @@ def format_figure(
textcoords="offset points", textcoords="offset points",
ha="right", ha="right",
va="top", va="top",
color="white", color="white", )
)
t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
def generate_chord_dataset( def generate_chord_dataset(
max_voices: int = 8, max_voices: int=8,
sample_rate: int = 44100, sample_rate: int=44100,
num_items: int = 5, num_items: int=5,
duration: float = 1.0, duration: float=1.0,
min_note: str = "C2", min_note: str="C2",
max_note: str = "C6", max_note: str="C6",
output_dir: Path = "chords", output_dir: Path="chords", ):
):
""" """
Generates a toy multitrack dataset of chords, synthesized from sine waves. Generates a toy multitrack dataset of chords, synthesized from sine waves.
@ -640,8 +631,7 @@ def generate_chord_dataset(
frequency=librosa.midi_to_hz(midinote), frequency=librosa.midi_to_hz(midinote),
duration=dur, duration=dur,
sample_rate=sample_rate, sample_rate=sample_rate,
shape="sine", shape="sine", )
)
track[f"voice_{voice_idx}"] = sig track[f"voice_{voice_idx}"] = sig
tracks.append(track) tracks.append(track)

Loading…
Cancel
Save