pull/3900/head
drryanhuang 9 months ago
parent af73cc42b8
commit 47e81105f7

@ -169,7 +169,7 @@ class AudioSignal(
offset: float=0, offset: float=0,
duration: float=None, duration: float=None,
device: str=None, ): device: str=None, ):
# #
audio_path = None audio_path = None
audio_array = None audio_array = None
@ -208,7 +208,7 @@ class AudioSignal(
@property @property
def path_to_input_file( def path_to_input_file(
self, ): self, ):
""" """
Path to input file, if it exists. Path to input file, if it exists.
Alias to ``path_to_file`` for backwards compatibility Alias to ``path_to_file`` for backwards compatibility
""" """
@ -222,7 +222,7 @@ class AudioSignal(
duration: float=None, duration: float=None,
state: typing.Union[np.random.RandomState, int]=None, state: typing.Union[np.random.RandomState, int]=None,
**kwargs, ): **kwargs, ):
"""Randomly draw an excerpt of ``duration`` seconds from an """Randomly draw an excerpt of ``duration`` seconds from an
audio file specified at ``audio_path``, between ``offset`` seconds audio file specified at ``audio_path``, between ``offset`` seconds
and end of file. ``state`` can be used to seed the random draw. and end of file. ``state`` can be used to seed the random draw.
@ -329,7 +329,7 @@ class AudioSignal(
num_channels: int=1, num_channels: int=1,
batch_size: int=1, batch_size: int=1,
**kwargs, ): **kwargs, ):
"""Helper function to create an AudioSignal of all zeros. """Helper function to create an AudioSignal of all zeros.
Parameters Parameters
---------- ----------
@ -368,7 +368,7 @@ class AudioSignal(
num_channels: int=1, num_channels: int=1,
shape: str="sine", shape: str="sine",
**kwargs, ): **kwargs, ):
""" """
Generate a waveform of a given frequency and shape. Generate a waveform of a given frequency and shape.
Parameters Parameters
@ -420,7 +420,7 @@ class AudioSignal(
truncate_signals: bool=False, truncate_signals: bool=False,
resample: bool=False, resample: bool=False,
dim: int=0, ): dim: int=0, ):
"""Creates a batched AudioSignal from a list of AudioSignals. """Creates a batched AudioSignal from a list of AudioSignals.
Parameters Parameters
---------- ----------
@ -509,7 +509,7 @@ class AudioSignal(
offset: float, offset: float,
duration: float, duration: float,
device: str="cpu", ): device: str="cpu", ):
"""Loads data from file. Used internally when AudioSignal """Loads data from file. Used internally when AudioSignal
is instantiated with a path to a file. is instantiated with a path to a file.
Parameters Parameters
@ -558,7 +558,7 @@ class AudioSignal(
audio_array: typing.Union[paddle.Tensor, np.ndarray], audio_array: typing.Union[paddle.Tensor, np.ndarray],
sample_rate: int, sample_rate: int,
device: str="cpu", ): device: str="cpu", ):
"""Loads data from array, reshaping it to be exactly 3 """Loads data from array, reshaping it to be exactly 3
dimensions. Used internally when AudioSignal is called dimensions. Used internally when AudioSignal is called
with a tensor or an array. with a tensor or an array.
@ -594,7 +594,7 @@ class AudioSignal(
return self return self
def write(self, audio_path: typing.Union[str, Path]): def write(self, audio_path: typing.Union[str, Path]):
"""Writes audio to a file. Only writes the audio """Writes audio to a file. Only writes the audio
that is in the very first item of the batch. To write other items that is in the very first item of the batch. To write other items
in the batch, index the signal along the batch dimension in the batch, index the signal along the batch dimension
before writing. After writing, the signal's ``path_to_file`` before writing. After writing, the signal's ``path_to_file``
@ -636,7 +636,7 @@ class AudioSignal(
return self return self
def deepcopy(self): def deepcopy(self):
"""Copies the signal and all of its attributes. """Copies the signal and all of its attributes.
Returns Returns
------- -------
@ -646,7 +646,7 @@ class AudioSignal(
return copy.deepcopy(self) return copy.deepcopy(self)
def copy(self): def copy(self):
"""Shallow copy of signal. """Shallow copy of signal.
Returns Returns
------- -------
@ -656,7 +656,7 @@ class AudioSignal(
return copy.copy(self) return copy.copy(self)
def clone(self): def clone(self):
"""Clones all tensors contained in the AudioSignal, """Clones all tensors contained in the AudioSignal,
and returns a copy of the signal with everything and returns a copy of the signal with everything
cloned. Useful when using AudioSignal within autograd cloned. Useful when using AudioSignal within autograd
computation graphs. computation graphs.
@ -682,7 +682,7 @@ class AudioSignal(
return clone return clone
def detach(self): def detach(self):
"""Detaches tensors contained in AudioSignal. """Detaches tensors contained in AudioSignal.
Relevant attributes are the stft data, the audio data, Relevant attributes are the stft data, the audio data,
and the loudness of the file. and the loudness of the file.
@ -701,7 +701,7 @@ class AudioSignal(
return self return self
def hash(self): def hash(self):
"""Writes the audio data to a temporary file, and then """Writes the audio data to a temporary file, and then
hashes it using hashlib. Useful for creating a file hashes it using hashlib. Useful for creating a file
name based on the audio content. name based on the audio content.
@ -732,7 +732,7 @@ class AudioSignal(
# Signal operations # Signal operations
def to_mono(self): def to_mono(self):
"""Converts audio data to mono audio, by taking the mean """Converts audio data to mono audio, by taking the mean
along the channels dimension. along the channels dimension.
Returns Returns
@ -744,7 +744,7 @@ class AudioSignal(
return self return self
def resample(self, sample_rate: int): def resample(self, sample_rate: int):
"""Resamples the audio, using sinc interpolation. This works on both """Resamples the audio, using sinc interpolation. This works on both
cpu and gpu, and is much faster on gpu. cpu and gpu, and is much faster on gpu.
Parameters Parameters
@ -779,7 +779,7 @@ class AudioSignal(
# Tensor operations # Tensor operations
def to(self, device: str): def to(self, device: str):
"""Moves all tensors contained in signal to the specified device. """Moves all tensors contained in signal to the specified device.
Parameters Parameters
---------- ----------
@ -801,7 +801,7 @@ class AudioSignal(
return self return self
def float(self): def float(self):
"""Calls ``.float()`` on ``self.audio_data``. """Calls ``.float()`` on ``self.audio_data``.
Returns Returns
------- -------
@ -811,7 +811,7 @@ class AudioSignal(
return self return self
def cpu(self): def cpu(self):
"""Moves AudioSignal to cpu. """Moves AudioSignal to cpu.
Returns Returns
------- -------
@ -820,7 +820,7 @@ class AudioSignal(
return self.to("cpu") return self.to("cpu")
def cuda(self): # pragma: no cover def cuda(self): # pragma: no cover
"""Moves AudioSignal to cuda. """Moves AudioSignal to cuda.
Returns Returns
------- -------
@ -829,7 +829,7 @@ class AudioSignal(
return self.to("gpu") return self.to("gpu")
def numpy(self): def numpy(self):
"""Detaches ``self.audio_data``, moves to cpu, and converts to numpy. """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
Returns Returns
------- -------
@ -839,7 +839,7 @@ class AudioSignal(
return self.audio_data.detach().cpu().numpy() return self.audio_data.detach().cpu().numpy()
def zero_pad(self, before: int, after: int): def zero_pad(self, before: int, after: int):
"""Zero pads the audio_data tensor before and after. """Zero pads the audio_data tensor before and after.
Parameters Parameters
---------- ----------
@ -858,7 +858,7 @@ class AudioSignal(
return self return self
def zero_pad_to(self, length: int, mode: str="after"): def zero_pad_to(self, length: int, mode: str="after"):
"""Pad with zeros to a specified length, either before or after """Pad with zeros to a specified length, either before or after
the audio data. the audio data.
Parameters Parameters
@ -880,7 +880,7 @@ class AudioSignal(
return self return self
def trim(self, before: int, after: int): def trim(self, before: int, after: int):
"""Trims the audio_data tensor before and after. """Trims the audio_data tensor before and after.
Parameters Parameters
---------- ----------
@ -901,7 +901,7 @@ class AudioSignal(
return self return self
def truncate_samples(self, length_in_samples: int): def truncate_samples(self, length_in_samples: int):
"""Truncate signal to specified length. """Truncate signal to specified length.
Parameters Parameters
---------- ----------
@ -918,7 +918,7 @@ class AudioSignal(
@property @property
def device(self): def device(self):
"""Get device that AudioSignal is on. """Get device that AudioSignal is on.
Returns Returns
------- -------
@ -934,7 +934,7 @@ class AudioSignal(
# Properties # Properties
@property @property
def audio_data(self): def audio_data(self):
"""Returns the audio data tensor in the object. """Returns the audio data tensor in the object.
Audio data is always of the shape Audio data is always of the shape
(batch_size, num_channels, num_samples). If value has less (batch_size, num_channels, num_samples). If value has less
@ -968,7 +968,7 @@ class AudioSignal(
@property @property
def stft_data(self): def stft_data(self):
"""Returns the STFT data inside the signal. Shape is """Returns the STFT data inside the signal. Shape is
(batch, channels, frequencies, time). (batch, channels, frequencies, time).
Returns Returns
@ -989,7 +989,7 @@ class AudioSignal(
@property @property
def batch_size(self): def batch_size(self):
"""Batch size of audio signal. """Batch size of audio signal.
Returns Returns
------- -------
@ -1000,7 +1000,7 @@ class AudioSignal(
@property @property
def signal_length(self): def signal_length(self):
"""Length of audio signal. """Length of audio signal.
Returns Returns
------- -------
@ -1014,7 +1014,7 @@ class AudioSignal(
@property @property
def shape(self): def shape(self):
"""Shape of audio data. """Shape of audio data.
Returns Returns
------- -------
@ -1025,7 +1025,7 @@ class AudioSignal(
@property @property
def signal_duration(self): def signal_duration(self):
"""Length of audio signal in seconds. """Length of audio signal in seconds.
Returns Returns
------- -------
@ -1039,7 +1039,7 @@ class AudioSignal(
@property @property
def num_channels(self): def num_channels(self):
"""Number of audio channels. """Number of audio channels.
Returns Returns
------- -------
@ -1052,7 +1052,7 @@ class AudioSignal(
@staticmethod @staticmethod
@functools.lru_cache(None) @functools.lru_cache(None)
def get_window(window_type: str, window_length: int, device: str=None): def get_window(window_type: str, window_length: int, device: str=None):
"""Wrapper around scipy.signal.get_window so one can also get the """Wrapper around scipy.signal.get_window so one can also get the
popular sqrt-hann window. This function caches for efficiency popular sqrt-hann window. This function caches for efficiency
using functools.lru\_cache. using functools.lru\_cache.
@ -1083,7 +1083,7 @@ class AudioSignal(
@property @property
def stft_params(self): def stft_params(self):
"""Returns STFTParams object, which can be re-used to other """Returns STFTParams object, which can be re-used to other
AudioSignals. AudioSignals.
This property can be set as well. If values are not defined in STFTParams, This property can be set as well. If values are not defined in STFTParams,
@ -1106,7 +1106,7 @@ class AudioSignal(
@stft_params.setter @stft_params.setter
def stft_params(self, value: STFTParams): def stft_params(self, value: STFTParams):
# #
default_win_len = int(2**(np.ceil(np.log2(0.032 * self.sample_rate)))) default_win_len = int(2**(np.ceil(np.log2(0.032 * self.sample_rate))))
default_hop_len = default_win_len // 4 default_hop_len = default_win_len // 4
default_win_type = "hann" default_win_type = "hann"
@ -1133,7 +1133,7 @@ class AudioSignal(
window_length: int, window_length: int,
hop_length: int, hop_length: int,
match_stride: bool): match_stride: bool):
"""Compute how the STFT should be padded, based on match\_stride. """Compute how the STFT should be padded, based on match\_stride.
Parameters Parameters
---------- ----------
@ -1169,7 +1169,7 @@ class AudioSignal(
window_type: str=None, window_type: str=None,
match_stride: bool=None, match_stride: bool=None,
padding_type: str=None, ): padding_type: str=None, ):
"""Computes the short-time Fourier transform of the audio data, """Computes the short-time Fourier transform of the audio data,
with specified STFT parameters. with specified STFT parameters.
Parameters Parameters
@ -1250,7 +1250,7 @@ class AudioSignal(
window_type: str=None, window_type: str=None,
match_stride: bool=None, match_stride: bool=None,
length: int=None, ): length: int=None, ):
"""Computes inverse STFT and sets it to audio\_data. """Computes inverse STFT and sets it to audio\_data.
Parameters Parameters
---------- ----------
@ -1325,7 +1325,7 @@ class AudioSignal(
n_mels: int, n_mels: int,
fmin: float=0.0, fmin: float=0.0,
fmax: float=None): fmax: float=None):
"""Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
Parameters Parameters
---------- ----------
@ -1360,7 +1360,7 @@ class AudioSignal(
mel_fmin: float=0.0, mel_fmin: float=0.0,
mel_fmax: float=None, mel_fmax: float=None,
**kwargs, ): **kwargs, ):
"""Computes a Mel spectrogram. """Computes a Mel spectrogram.
Parameters Parameters
---------- ----------
@ -1397,7 +1397,7 @@ class AudioSignal(
@staticmethod @staticmethod
@functools.lru_cache(None) @functools.lru_cache(None)
def get_dct(n_mfcc: int, n_mels: int, norm: str="ortho", device: str=None): def get_dct(n_mfcc: int, n_mels: int, norm: str="ortho", device: str=None):
"""Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
it can be normalized depending on norm. For more information about dct: it can be normalized depending on norm. For more information about dct:
http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
@ -1426,7 +1426,7 @@ class AudioSignal(
n_mels: int=80, n_mels: int=80,
log_offset: float=1e-6, log_offset: float=1e-6,
**kwargs, ): **kwargs, ):
"""Computes mel-frequency cepstral coefficients (MFCCs). """Computes mel-frequency cepstral coefficients (MFCCs).
Parameters Parameters
---------- ----------
@ -1455,7 +1455,7 @@ class AudioSignal(
@property @property
def magnitude(self): def magnitude(self):
"""Computes and returns the absolute value of the STFT, which """Computes and returns the absolute value of the STFT, which
is the magnitude. This value can also be set to some tensor. is the magnitude. This value can also be set to some tensor.
When set, ``self.stft_data`` is manipulated so that its magnitude When set, ``self.stft_data`` is manipulated so that its magnitude
matches what this is set to, and modulated by the phase. matches what this is set to, and modulated by the phase.
@ -1486,7 +1486,7 @@ class AudioSignal(
ref_value: float=1.0, ref_value: float=1.0,
amin: float=1e-5, amin: float=1e-5,
top_db: float=80.0): top_db: float=80.0):
"""Computes the log-magnitude of the spectrogram. """Computes the log-magnitude of the spectrogram.
Parameters Parameters
---------- ----------
@ -1519,7 +1519,7 @@ class AudioSignal(
@property @property
def phase(self): def phase(self):
"""Computes and returns the phase of the STFT. """Computes and returns the phase of the STFT.
This value can also be set to some tensor. This value can also be set to some tensor.
When set, ``self.stft_data`` is manipulated so that its phase When set, ``self.stft_data`` is manipulated so that its phase
matches what this is set to, with the original magnitude. matches what this is set to, with the original magnitude.
@ -1543,7 +1543,7 @@ class AudioSignal(
@phase.setter @phase.setter
def phase(self, value): def phase(self, value):
# #
self.stft_data = self.magnitude * paddle.exp(1j * value) self.stft_data = self.magnitude * paddle.exp(1j * value)
return return
@ -1583,7 +1583,7 @@ class AudioSignal(
# Representation # Representation
def _info(self): def _info(self):
# #
dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]" dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
info = { info = {
"duration": "duration":
@ -1607,7 +1607,7 @@ class AudioSignal(
return info return info
def markdown(self): def markdown(self):
"""Produces a markdown representation of AudioSignal, in a markdown table. """Produces a markdown representation of AudioSignal, in a markdown table.
Returns Returns
------- -------

@ -44,7 +44,7 @@ class Info:
def info(audio_path: str): def info(audio_path: str):
""" """
Parameters Parameters
---------- ----------
@ -61,7 +61,7 @@ def ensure_tensor(
x: typing.Union[np.ndarray, paddle.Tensor, float, int], x: typing.Union[np.ndarray, paddle.Tensor, float, int],
ndim: int=None, ndim: int=None,
batch_size: int=None, ): batch_size: int=None, ):
"""Ensures that the input ``x`` is a tensor of specified """Ensures that the input ``x`` is a tensor of specified
dimensions and batch size. dimensions and batch size.
Parameters Parameters
@ -93,7 +93,7 @@ def ensure_tensor(
def _get_value(other): def _get_value(other):
# #
from . import AudioSignal from . import AudioSignal
if isinstance(other, AudioSignal): if isinstance(other, AudioSignal):
@ -102,7 +102,7 @@ def _get_value(other):
def random_state(seed: typing.Union[int, np.random.RandomState]): def random_state(seed: typing.Union[int, np.random.RandomState]):
""" """
Turn seed into a np.random.RandomState instance. Turn seed into a np.random.RandomState instance.
Parameters Parameters
@ -135,7 +135,7 @@ def random_state(seed: typing.Union[int, np.random.RandomState]):
def seed(random_seed, **kwargs): def seed(random_seed, **kwargs):
""" """
Seeds all random states with the same random seed Seeds all random states with the same random seed
for reproducibility. Seeds ``numpy``, ``random`` and ``paddle`` for reproducibility. Seeds ``numpy``, ``random`` and ``paddle``
random generators. random generators.
@ -152,7 +152,7 @@ def seed(random_seed, **kwargs):
@contextmanager @contextmanager
def _close_temp_files(tmpfiles: list): def _close_temp_files(tmpfiles: list):
"""Utility function for creating a context and closing all temporary files """Utility function for creating a context and closing all temporary files
once the context is exited. For correct functionality, all temporary file once the context is exited. For correct functionality, all temporary file
handles created inside the context must be appended to the ``tmpfiles`` handles created inside the context must be appended to the ``tmpfiles``
list. list.
@ -185,7 +185,7 @@ AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS): def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS):
"""Finds all audio files in a directory recursively. """Finds all audio files in a directory recursively.
Returns a list. Returns a list.
Parameters Parameters
@ -218,7 +218,7 @@ def read_sources(
remove_empty: bool=True, remove_empty: bool=True,
relative_path: str="", relative_path: str="",
ext: List[str]=AUDIO_EXTENSIONS, ): ext: List[str]=AUDIO_EXTENSIONS, ):
"""Reads audio sources that can either be folders """Reads audio sources that can either be folders
full of audio files, or CSV files that contain paths full of audio files, or CSV files that contain paths
to audio files. CSV files that adhere to the expected to audio files. CSV files that adhere to the expected
format can be generated by format can be generated by
@ -263,7 +263,7 @@ def read_sources(
def choose_from_list_of_lists(state: np.random.RandomState, def choose_from_list_of_lists(state: np.random.RandomState,
list_of_lists: list, list_of_lists: list,
p: float=None): p: float=None):
"""Choose a single item from a list of lists. """Choose a single item from a list of lists.
Parameters Parameters
---------- ----------
@ -286,7 +286,7 @@ def choose_from_list_of_lists(state: np.random.RandomState,
@contextmanager @contextmanager
def chdir(newdir: typing.Union[Path, str]): def chdir(newdir: typing.Union[Path, str]):
""" """
Context manager for switching directories to run a Context manager for switching directories to run a
function. Useful for when you want to use relative function. Useful for when you want to use relative
paths to different runs. paths to different runs.
@ -306,7 +306,7 @@ def chdir(newdir: typing.Union[Path, str]):
def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor],
device: str="cpu"): device: str="cpu"):
"""Moves items in a batch (typically generated by a DataLoader as a list """Moves items in a batch (typically generated by a DataLoader as a list
or a dict) to the specified device. This works even if dictionaries or a dict) to the specified device. This works even if dictionaries
are nested. are nested.
@ -344,7 +344,7 @@ def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor],
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None): def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None):
"""Samples from a distribution defined by a tuple. The first """Samples from a distribution defined by a tuple. The first
item in the tuple is the distribution type, and the rest of the item in the tuple is the distribution type, and the rest of the
items are arguments to that distribution. The distribution function items are arguments to that distribution. The distribution function
is gotten from the ``np.random.RandomState`` object. is gotten from the ``np.random.RandomState`` object.
@ -397,7 +397,7 @@ def format_figure(
format_axes: bool=True, format_axes: bool=True,
format: bool=True, format: bool=True,
font_color: str="white", ): font_color: str="white", ):
"""Prettifies the spectrogram and waveform plots. A title """Prettifies the spectrogram and waveform plots. A title
can be inset into the top right corner, and the axes can be can be inset into the top right corner, and the axes can be
inset into the figure, allowing the data to take up the entire inset into the figure, allowing the data to take up the entire
image. Used in image. Used in

@ -144,7 +144,7 @@ def align_lists(lists, matcher: Callable=default_matcher):
class AudioDataset: class AudioDataset:
"""Loads audio from multiple loaders (with associated transforms) """Loads audio from multiple loaders (with associated transforms)
for a specified number of samples. Excerpts are drawn randomly for a specified number of samples. Excerpts are drawn randomly
of the specified duration, above a specified loudness threshold of the specified duration, above a specified loudness threshold
and are resampled on the fly to the desired sample rate and are resampled on the fly to the desired sample rate
@ -466,7 +466,7 @@ class AudioDataset:
class ConcatDataset(AudioDataset): class ConcatDataset(AudioDataset):
# #
def __init__(self, datasets: list): def __init__(self, datasets: list):
self.datasets = datasets self.datasets = datasets

@ -16,7 +16,7 @@ from .datasets import AudioLoader
class BaseTransform: class BaseTransform:
"""This is the base class for all transforms that are implemented """This is the base class for all transforms that are implemented
in this library. Transforms have two main operations: ``transform`` in this library. Transforms have two main operations: ``transform``
and ``instantiate``. and ``instantiate``.
@ -272,13 +272,13 @@ class BaseTransform:
class Identity(BaseTransform): class Identity(BaseTransform):
"""This transform just returns the original signal.""" """This transform just returns the original signal."""
pass pass
class SpectralTransform(BaseTransform): class SpectralTransform(BaseTransform):
"""Spectral transforms require STFT data to exist, since manipulations """Spectral transforms require STFT data to exist, since manipulations
of the STFT require the spectrogram. This just calls ``stft`` before of the STFT require the spectrogram. This just calls ``stft`` before
the transform is called, and calls ``istft`` after the transform is the transform is called, and calls ``istft`` after the transform is
called so that the audio data is written to after the spectral called so that the audio data is written to after the spectral
@ -293,7 +293,7 @@ class SpectralTransform(BaseTransform):
class Compose(BaseTransform): class Compose(BaseTransform):
"""Compose applies transforms in sequence, one after the other. The """Compose applies transforms in sequence, one after the other. The
transforms are passed in as positional arguments or as a list like so: transforms are passed in as positional arguments or as a list like so:
>>> transform = tfm.Compose( >>> transform = tfm.Compose(
@ -431,7 +431,7 @@ class Compose(BaseTransform):
class Choose(Compose): class Choose(Compose):
"""Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`, """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
but instead of applying all the transforms in sequence, it applies just a single transform, but instead of applying all the transforms in sequence, it applies just a single transform,
which is chosen for each item in the batch. which is chosen for each item in the batch.
@ -481,7 +481,7 @@ class Choose(Compose):
class Repeat(Compose): class Repeat(Compose):
"""Repeatedly applies a given transform ``n_repeat`` times. """Repeatedly applies a given transform ``n_repeat`` times.
Parameters Parameters
---------- ----------
@ -504,7 +504,7 @@ class Repeat(Compose):
class RepeatUpTo(Choose): class RepeatUpTo(Choose):
"""Repeatedly applies a given transform up to ``max_repeat`` times. """Repeatedly applies a given transform up to ``max_repeat`` times.
Parameters Parameters
---------- ----------
@ -532,7 +532,7 @@ class RepeatUpTo(Choose):
class ClippingDistortion(BaseTransform): class ClippingDistortion(BaseTransform):
"""Adds clipping distortion to signal. Corresponds """Adds clipping distortion to signal. Corresponds
to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`. to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
Parameters Parameters

@ -29,7 +29,7 @@ def default_list():
class Mean: class Mean:
"""Keeps track of the running mean, along with the latest """Keeps track of the running mean, along with the latest
value. value.
""" """
@ -51,7 +51,7 @@ class Mean:
def when(condition): def when(condition):
"""Runs a function only when the condition is met. The condition is """Runs a function only when the condition is met. The condition is
a function that is run. a function that is run.
Parameters Parameters
@ -89,7 +89,7 @@ def when(condition):
def timer(prefix: str="time"): def timer(prefix: str="time"):
"""Adds execution time to the output dictionary of the decorated """Adds execution time to the output dictionary of the decorated
function. The function decorated by this must output a dictionary. function. The function decorated by this must output a dictionary.
The key added will follow the form "[prefix]/[name_of_function]" The key added will follow the form "[prefix]/[name_of_function]"
@ -116,7 +116,7 @@ def timer(prefix: str="time"):
class Tracker: class Tracker:
""" """
A tracker class that helps to monitor the progress of training and logging the metrics. A tracker class that helps to monitor the progress of training and logging the metrics.
Attributes Attributes

Loading…
Cancel
Save