Fix CUDA error

pull/3900/head
drryanhuang 9 months ago
parent 0ceaa145f0
commit f0b557648e

@ -0,0 +1,23 @@
# PaddleAudio
Installation: pip install paddleaudio
Supported platforms: Linux, Mac, Windows
## Environment
## Build wheel
cmd: python setup.py bdist_wheel
Linux wheel build/test environment:
* os - Ubuntu 16.04.7 LTS
* gcc/g++ - 8.2.0
* cmake - 3.18.0 (needs manual install)
Mac wheel build/test environment:
* os
* gcc/g++ - 12.2.0
* cpu - Intel Xeon E5 x86_64
Windows:
* The paddleaudio C++ extension lib (sox io, kaldi native fbank) is not supported.

@ -7,5 +7,6 @@ from .core import highpass_filter, highpass_filters
from . import metrics
from . import data
from . import ml
from . import post
from .data import datasets
from .data import transforms

@ -552,49 +552,6 @@ def highpass_filter(_input: paddle.Tensor,
return highpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]
import paddle
from typing import Optional, Sequence
def hz_to_mel(freqs: paddle.Tensor):
"""
Converts a Tensor of frequencies in hertz to the mel scale.
Uses the simple formula by O'Shaughnessy (1987).
Args:
freqs (paddle.Tensor): frequencies to convert.
"""
return 2595 * paddle.log10(1 + freqs / 700)
def mel_to_hz(mels: paddle.Tensor):
"""
Converts a Tensor of mel scaled frequencies to Hertz.
Uses the simple formula by O'Shaughnessy (1987).
Args:
mels (paddle.Tensor): mel frequencies to convert.
"""
return 700 * (10**(mels / 2595) - 1)
def mel_frequencies(n_mels: int, fmin: float, fmax: float):
"""
Return frequencies that are evenly spaced in mel scale.
Args:
n_mels (int): number of frequencies to return.
fmin (float): start from this frequency (in Hz).
fmax (float): finish at this frequency (in Hz).
"""
low = hz_to_mel(paddle.to_tensor(float(fmin))).item()
high = hz_to_mel(paddle.to_tensor(float(fmax))).item()
mels = paddle.linspace(low, high, n_mels)
return mel_to_hz(mels)
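A quick sanity check of the conversions above (a hedged sketch using only the definitions in this hunk; the test frequencies are illustrative):

import paddle
f = paddle.to_tensor([440.0, 1000.0, 4000.0])
# hz -> mel -> hz should recover the input
assert paddle.allclose(mel_to_hz(hz_to_mel(f)), f)
cutoffs = mel_frequencies(n_mels=10, fmin=0.0, fmax=8000.0)
# evenly spaced on the mel scale, hence strictly increasing in Hz
assert bool(paddle.all(cutoffs[1:] > cutoffs[:-1]))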
class SplitBands(paddle.nn.Layer):
"""
Decomposes a signal over the given frequency bands in the waveform domain using
@ -657,7 +614,8 @@ class SplitBands(paddle.nn.Layer):
if not n_bands >= 1:
raise ValueError(
f"n_bands must be greater than one (got {n_bands})")
cutoffs = mel_frequencies(n_bands + 1, 0, sample_rate / 2)[1:-1]
cutoffs = paddle.audio.functional.mel_frequencies(
n_bands + 1, 0, sample_rate / 2)[1:-1]
else:
if max(cutoffs) > 0.5 * sample_rate:
raise ValueError(

@ -214,95 +214,103 @@ class DSPMixin:
self.stft_data = None
return self
# def mask_frequencies(
# self,
# fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float],
# fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float],
# val: float = 0.0,
# ):
# """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
# with the value specified by ``val``. Useful for implementing SpecAug.
# The min and max can be different for every item in the batch.
# Parameters
# ----------
# fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float]
# Lower end of band to mask out.
# fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float]
# Upper end of band to mask out.
# val : float, optional
# Value to fill in, by default 0.0
def mask_frequencies(
self,
fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float],
fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float],
val: float=0.0, ):
"""Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
with the value specified by ``val``. Useful for implementing SpecAug.
The min and max can be different for every item in the batch.
# Returns
# -------
# AudioSignal
# Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
# masked audio data.
# """
# # SpecAug
# mag, phase = self.magnitude, self.phase
# fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
# fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
# assert paddle.all(fmin_hz < fmax_hz)
# # build mask
# nbins = mag.shape[-2]
# bins_hz = paddle.linspace(0, self.sample_rate / 2, nbins, device=self.device)
# bins_hz = bins_hz[None, None, :, None].repeat(
# self.batch_size, 1, 1, mag.shape[-1]
# )
# mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
# mask = mask.to(self.device)
Parameters
----------
fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float]
Lower end of band to mask out.
fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float]
Upper end of band to mask out.
val : float, optional
Value to fill in, by default 0.0
# mag = mag.masked_fill(mask, val)
# phase = phase.masked_fill(mask, val)
# self.stft_data = mag * paddle.exp(1j * phase)
# return self
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
# SpecAug
mag, phase = self.magnitude, self.phase
fmin_hz = util.ensure_tensor(
fmin_hz,
ndim=mag.ndim, )
fmax_hz = util.ensure_tensor(
fmax_hz,
ndim=mag.ndim, )
assert paddle.all(fmin_hz < fmax_hz)
# build mask
nbins = mag.shape[-2]
bins_hz = paddle.linspace(
0,
self.sample_rate / 2,
nbins, )
bins_hz = bins_hz[None, None, :, None].tile(
[self.batch_size, 1, 1, mag.shape[-1]])
mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
mag = paddle.where(mask, paddle.full_like(mag, val), mag)
phase = paddle.where(mask, paddle.full_like(phase, val), phase)
self.stft_data = mag * paddle.exp(1j * phase)
return self
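Paddle has no `masked_fill` on `Tensor` here, which is why the port rewrites it with `paddle.where`; a minimal sketch of the equivalence, with illustrative values:

import paddle
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
mask = x > 2.5
val = 0.0
# torch: x.masked_fill(mask, val); the paddle equivalent used above:
filled = paddle.where(mask, paddle.full_like(x, val), x)
# filled == [[1., 2.], [0., 0.]]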
# def mask_timesteps(
# self,
# tmin_s: typing.Union[paddle.Tensor, np.ndarray, float],
# tmax_s: typing.Union[paddle.Tensor, np.ndarray, float],
# val: float = 0.0,
# ):
# """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
# with the value specified by ``val``. Useful for implementing SpecAug.
# The min and max can be different for every item in the batch.
def mask_timesteps(
self,
tmin_s: typing.Union[paddle.Tensor, np.ndarray, float],
tmax_s: typing.Union[paddle.Tensor, np.ndarray, float],
val: float=0.0, ):
"""Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
with the value specified by ``val``. Useful for implementing SpecAug.
The min and max can be different for every item in the batch.
# Parameters
# ----------
# tmin_s : typing.Union[paddle.Tensor, np.ndarray, float]
# Lower end of timesteps to mask out.
# tmax_s : typing.Union[paddle.Tensor, np.ndarray, float]
# Upper end of timesteps to mask out.
# val : float, optional
# Value to fill in, by default 0.0
Parameters
----------
tmin_s : typing.Union[paddle.Tensor, np.ndarray, float]
Lower end of timesteps to mask out.
tmax_s : typing.Union[paddle.Tensor, np.ndarray, float]
Upper end of timesteps to mask out.
val : float, optional
Value to fill in, by default 0.0
# Returns
# -------
# AudioSignal
# Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
# masked audio data.
# """
# # SpecAug
# mag, phase = self.magnitude, self.phase
# tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
# tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
# assert paddle.all(tmin_s < tmax_s)
# # build mask
# nt = mag.shape[-1]
# bins_t = paddle.linspace(0, self.signal_duration, nt, device=self.device)
# bins_t = bins_t[None, None, None, :].repeat(
# self.batch_size, 1, mag.shape[-2], 1
# )
# mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
# SpecAug
mag, phase = self.magnitude, self.phase
tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
assert paddle.all(tmin_s < tmax_s)
# build mask
nt = mag.shape[-1]
bins_t = paddle.linspace(
0,
self.signal_duration,
nt, )
bins_t = bins_t[None, None, None, :].tile(
[self.batch_size, 1, mag.shape[-2], 1])
mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
# mag = mag.masked_fill(mask, val)
# phase = phase.masked_fill(mask, val)
# self.stft_data = mag * paddle.exp(1j * phase)
# return self
mag = paddle.where(mask, paddle.full_like(mag, val), mag)
phase = paddle.where(mask, paddle.full_like(phase, val), phase)
self.stft_data = mag * paddle.exp(1j * phase)
return self
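A hedged usage sketch of the two masking methods for SpecAug-style augmentation (the path and ranges are illustrative, assuming `AudioSignal` as used elsewhere in this diff):

from audiotools import AudioSignal
signal = AudioSignal("tests/audiotools/audio/spk/f10_script4_produced.wav")
signal.stft()
signal.mask_frequencies(fmin_hz=500.0, fmax_hz=1500.0)  # zero out a 500-1500 Hz band
signal.mask_timesteps(tmin_s=0.5, tmax_s=0.75)          # zero out 0.5-0.75 s
masked = signal.istft()                                  # back to the waveform domain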
# def mask_low_magnitudes(
# self, db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float], val: float = 0.0

@ -234,23 +234,23 @@ class EffectMixin:
self.audio_data = self.audio_data * gain[:, None, None]
return self
# def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]):
# """Change volume of signal by some amount, in dB.
def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]):
"""Change volume of signal by some amount, in dB.
# Parameters
# ----------
# db : typing.Union[paddle.Tensor, np.ndarray, float]
# Amount to change volume by.
Parameters
----------
db : typing.Union[paddle.Tensor, np.ndarray, float]
Amount to change volume by.
# Returns
# -------
# AudioSignal
# Signal at new volume.
# """
# db = util.ensure_tensor(db, ndim=1).to(self.device)
# gain = torch.exp(db * self.GAIN_FACTOR)
# self.audio_data = self.audio_data * gain[:, None, None]
# return self
Returns
-------
AudioSignal
Signal at new volume.
"""
db = util.ensure_tensor(db, ndim=1)
gain = paddle.exp(db * self.GAIN_FACTOR)
self.audio_data = self.audio_data * gain[:, None, None]
return self
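For reference, a minimal sketch of the dB-to-linear mapping used above, assuming the conventional `GAIN_FACTOR = ln(10) / 20` (so that `exp(db * GAIN_FACTOR)` equals `10 ** (db / 20)`):

import numpy as np
import paddle
GAIN_FACTOR = np.log(10) / 20          # assumption: the mixin's constant
db = paddle.to_tensor([-6.0])
gain = paddle.exp(db * GAIN_FACTOR)    # ~0.501, i.e. roughly half amplitude
assert np.isclose(gain.item(), 10 ** (-6.0 / 20))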
# def _to_2d(self):
# waveform = self.audio_data.reshape(-1, self.signal_length)
@ -411,7 +411,7 @@ class EffectMixin:
paddle.Tensor
Mel-filtered bands, with last axis being the band index.
"""
filterbank = SplitBands(self.sample_rate, n_bands).float()
filterbank = SplitBands(self.sample_rate, n_bands)
filtered = filterbank(self.audio_data)
return filtered.transpose([1, 2, 3, 0])
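A hedged sketch of what `SplitBands` produces (the band-first output shape is inferred from the `transpose([1, 2, 3, 0])` above; the reconstruction property is approximate):

import paddle
x = paddle.randn([1, 1, 44100])           # (batch, channels, time)
bands = SplitBands(44100, n_bands=6)(x)   # (n_bands, batch, channels, time)
recon = bands.sum(0)                      # bands come from complementary filters
# paddle.allclose(recon, x, atol=1e-3) should hold up to filter error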
@ -462,11 +462,11 @@ class EffectMixin:
Audio signal with clipped audio data.
"""
clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
clip_percentile = clip_percentile.item()
clip_percentile = clip_percentile.cpu().numpy()
min_thresh = paddle.quantile(
self.audio_data, clip_percentile / 2, axis=-1)[None]
self.audio_data, (clip_percentile / 2).tolist(), axis=-1)[None]
max_thresh = paddle.quantile(
self.audio_data, 1 - (clip_percentile / 2), axis=-1)[None]
self.audio_data, (1 - clip_percentile / 2).tolist(), axis=-1)[None]
nc = self.audio_data.shape[1]
min_thresh = min_thresh[:, :nc, :]
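The `.cpu().numpy()` / `.tolist()` change exists because `paddle.quantile` takes `q` as a Python float or list rather than a `Tensor`; a minimal sketch of the clipping idea under that assumption:

import paddle
x = paddle.randn([2, 1, 1000])
q = 0.02                                  # clip the outer 2% of samples
lo = paddle.quantile(x, q / 2, axis=-1)   # q must be a plain float or list
hi = paddle.quantile(x, 1 - q / 2, axis=-1)
clipped = paddle.minimum(paddle.maximum(x, lo.unsqueeze(-1)), hi.unsqueeze(-1))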

@ -152,9 +152,10 @@ class Meter(paddle.nn.Layer):
paddle.Tensor
Filtered audio data.
"""
if data.place.is_gpu_place() or self.use_fir:
data = self.apply_filter_gpu(data)
else:
# if data.place.is_gpu_place() or self.use_fir:
# data = self.apply_filter_gpu(data)
# else:
# data = self.apply_filter_cpu(data)
data = self.apply_filter_cpu(data)
return data
@ -246,13 +247,13 @@ class Meter(paddle.nn.Layer):
z_avg_gated[l <= Gamma_a] = 0
z_avg_gated[l <= Gamma_r] = 0
masked = (l > Gamma_a) * (l > Gamma_r)
z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
z_avg_gated = z_avg_gated.sum(2) / (masked.sum(2) + 10e-6)
# # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
# z_avg_gated = torch.nan_to_num(z_avg_gated)
z_avg_gated = paddle.where(
paddle.isnan(z_avg_gated),
paddle.zeros_like(z_avg_gated), z_avg_gated)
# TODO Currently, paddle has a segmentation fault bug in this section of the code
# z_avg_gated = paddle.nan_to_num(z_avg_gated)
# z_avg_gated = paddle.where(
# paddle.isnan(z_avg_gated),
# paddle.zeros_like(z_avg_gated), z_avg_gated)
z_avg_gated[z_avg_gated == float("inf")] = float(
np.finfo(np.float32).max)
z_avg_gated[z_avg_gated == -float("inf")] = float(

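The epsilon in the denominator above sidesteps the NaN (and the segfaulting `paddle.where` cleanup) by never producing it in the first place; a minimal sketch of the idea:

import paddle
num = paddle.to_tensor([0.0, 2.0])
den = paddle.to_tensor([0.0, 4.0])   # a fully gated block yields a zero count
safe = num / (den + 10e-6)           # [0.0, ~0.5] instead of [nan, 0.5]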
@ -200,7 +200,7 @@ class AudioDataset:
>>>
>>> loaders = [
>>> AudioLoader(
>>> sources=[f"tests/audio/spk"],
>>> sources=[f"tests/audiotools/audio/spk"],
>>> transform=tfm.Equalizer(),
>>> ext=["wav"],
>>> )

@ -127,7 +127,8 @@ class BaseTransform:
# masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
masked_batch = {}
for k, v in flatten(batch).items():
if 0 == mask.dim() and 0 == v.dim():
# `v` may be `Tensor` or `AudioSignal`
if 0 == len(v.shape) and 0 == mask.dim():
if mask:  # a 0-d mask that is True
masked_batch[k] = v[None]
else:
@ -998,64 +999,63 @@ class VolumeNorm(BaseTransform):
return signal.normalize(db)
# class GlobalVolumeNorm(BaseTransform):
# """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
# transform also normalizes the volume of a signal, but it uses
# the volume of the entire audio file the loaded excerpt comes from,
# rather than the volume of just the excerpt. The volume of the
# entire audio file is expected in ``signal.metadata["loudness"]``.
# If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
# with ``loudness = True``, like the following:
class GlobalVolumeNorm(BaseTransform):
"""Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
transform also normalizes the volume of a signal, but it uses
the volume of the entire audio file the loaded excerpt comes from,
rather than the volume of just the excerpt. The volume of the
entire audio file is expected in ``signal.metadata["loudness"]``.
If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
with ``loudness = True``, like the following:
# .. csv-table::
# :header: path,loudness
.. csv-table::
:header: path,loudness
# daps/produced/f1_script1_produced.wav,-16.299999237060547
# daps/produced/f1_script2_produced.wav,-16.600000381469727
# daps/produced/f1_script3_produced.wav,-17.299999237060547
# daps/produced/f1_script4_produced.wav,-16.100000381469727
# daps/produced/f1_script5_produced.wav,-16.700000762939453
# daps/produced/f3_script1_produced.wav,-16.5
daps/produced/f1_script1_produced.wav,-16.299999237060547
daps/produced/f1_script2_produced.wav,-16.600000381469727
daps/produced/f1_script3_produced.wav,-17.299999237060547
daps/produced/f1_script4_produced.wav,-16.100000381469727
daps/produced/f1_script5_produced.wav,-16.700000762939453
daps/produced/f3_script1_produced.wav,-16.5
# The ``AudioLoader`` will automatically load the loudness column into
# the metadata of the signal.
The ``AudioLoader`` will automatically load the loudness column into
the metadata of the signal.
# Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
# Parameters
# ----------
# db : tuple, optional
# dB to normalize signal to, by default ("const", -24)
# name : str, optional
# Name of this transform, used to identify it in the dictionary
# produced by ``self.instantiate``, by default None
# prob : float, optional
# Probability of applying this transform, by default 1.0
# """
Parameters
----------
db : tuple, optional
dB to normalize signal to, by default ("const", -24)
name : str, optional
Name of this transform, used to identify it in the dictionary
produced by ``self.instantiate``, by default None
prob : float, optional
Probability of applying this transform, by default 1.0
"""
# def __init__(
# self,
# db: tuple = ("const", -24),
# name: str = None,
# prob: float = 1.0,
# ):
# super().__init__(name=name, prob=prob)
def __init__(
self,
db: tuple=("const", -24),
name: str=None,
prob: float=1.0, ):
super().__init__(name=name, prob=prob)
# self.db = db
self.db = db
# def _instantiate(self, state: RandomState, signal: AudioSignal):
# if "loudness" not in signal.metadata:
# db_change = 0.0
# elif float(signal.metadata["loudness"]) == float("-inf"):
# db_change = 0.0
# else:
# db = util.sample_from_dist(self.db, state)
# db_change = db - float(signal.metadata["loudness"])
def _instantiate(self, state: RandomState, signal: AudioSignal):
if "loudness" not in signal.metadata:
db_change = 0.0
elif float(signal.metadata["loudness"]) == float("-inf"):
db_change = 0.0
else:
db = util.sample_from_dist(self.db, state)
db_change = db - float(signal.metadata["loudness"])
# return {"db": db_change}
return {"db": db_change}
# def _transform(self, signal, db):
# return signal.volume_change(db)
def _transform(self, signal, db):
return signal.volume_change(db)
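A hedged usage sketch mirroring the test style in this diff (the metadata value is illustrative; it is normally filled by `AudioLoader` from the CSV's loudness column):

from audiotools import AudioSignal
from audiotools.data import transforms as tfm
signal = AudioSignal("tests/audiotools/audio/spk/f10_script4_produced.wav")
signal.metadata["loudness"] = -16.5             # stand-in for the CSV value
transform = tfm.GlobalVolumeNorm(db=("const", -24))
kwargs = transform.instantiate(0, signal)       # db_change = -24 - (-16.5) = -7.5 dB
output = transform(signal.clone(), **kwargs)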
class Silence(BaseTransform):
@ -1266,94 +1266,95 @@ class HighPass(BaseTransform):
# def _transform(self, signal, corruption):
# return signal.shift_phase(shift=corruption)
# class FrequencyMask(SpectralTransform):
# """Masks a band of frequencies at a center frequency
# from the audio.
# Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
class FrequencyMask(SpectralTransform):
"""Masks a band of frequencies at a center frequency
from the audio.
# Parameters
# ----------
# f_center : tuple, optional
# Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
# f_width : tuple, optional
# Width of zero'd out band, by default ("const", 0.1)
# name : str, optional
# Name of this transform, used to identify it in the dictionary
# produced by ``self.instantiate``, by default None
# prob : float, optional
# Probability of applying this transform, by default 1.0
# """
Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
# def __init__(
# self,
# f_center: tuple = ("uniform", 0.0, 1.0),
# f_width: tuple = ("const", 0.1),
# name: str = None,
# prob: float = 1,
# ):
# super().__init__(name=name, prob=prob)
# self.f_center = f_center
# self.f_width = f_width
Parameters
----------
f_center : tuple, optional
Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
f_width : tuple, optional
Width of zero'd out band, by default ("const", 0.1)
name : str, optional
Name of this transform, used to identify it in the dictionary
produced by ``self.instantiate``, by default None
prob : float, optional
Probability of applying this transform, by default 1.0
"""
# def _instantiate(self, state: RandomState, signal: AudioSignal):
# f_center = util.sample_from_dist(self.f_center, state)
# f_width = util.sample_from_dist(self.f_width, state)
def __init__(
self,
f_center: tuple=("uniform", 0.0, 1.0),
f_width: tuple=("const", 0.1),
name: str=None,
prob: float=1, ):
super().__init__(name=name, prob=prob)
self.f_center = f_center
self.f_width = f_width
# fmin = max(f_center - (f_width / 2), 0.0)
# fmax = min(f_center + (f_width / 2), 1.0)
def _instantiate(self, state: RandomState, signal: AudioSignal):
f_center = util.sample_from_dist(self.f_center, state)
f_width = util.sample_from_dist(self.f_width, state)
# fmin_hz = (signal.sample_rate / 2) * fmin
# fmax_hz = (signal.sample_rate / 2) * fmax
fmin = max(f_center - (f_width / 2), 0.0)
fmax = min(f_center + (f_width / 2), 1.0)
# return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
fmin_hz = (signal.sample_rate / 2) * fmin
fmax_hz = (signal.sample_rate / 2) * fmax
# def _transform(self, signal, fmin_hz: float, fmax_hz: float):
# return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
# class TimeMask(SpectralTransform):
# """Masks out contiguous time-steps from signal.
def _transform(self, signal, fmin_hz: float, fmax_hz: float):
return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
# Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
# Parameters
# ----------
# t_center : tuple, optional
# Center time in terms of 0.0 and 1.0 (duration of signal),
# by default ("uniform", 0.0, 1.0)
# t_width : tuple, optional
# Width of dropped out portion, by default ("const", 0.025)
# name : str, optional
# Name of this transform, used to identify it in the dictionary
# produced by ``self.instantiate``, by default None
# prob : float, optional
# Probability of applying this transform, by default 1.0
# """
class TimeMask(SpectralTransform):
"""Masks out contiguous time-steps from signal.
# def __init__(
# self,
# t_center: tuple = ("uniform", 0.0, 1.0),
# t_width: tuple = ("const", 0.025),
# name: str = None,
# prob: float = 1,
# ):
# super().__init__(name=name, prob=prob)
# self.t_center = t_center
# self.t_width = t_width
Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
# def _instantiate(self, state: RandomState, signal: AudioSignal):
# t_center = util.sample_from_dist(self.t_center, state)
# t_width = util.sample_from_dist(self.t_width, state)
Parameters
----------
t_center : tuple, optional
Center time in terms of 0.0 and 1.0 (duration of signal),
by default ("uniform", 0.0, 1.0)
t_width : tuple, optional
Width of dropped out portion, by default ("const", 0.025)
name : str, optional
Name of this transform, used to identify it in the dictionary
produced by ``self.instantiate``, by default None
prob : float, optional
Probability of applying this transform, by default 1.0
"""
# tmin = max(t_center - (t_width / 2), 0.0)
# tmax = min(t_center + (t_width / 2), 1.0)
def __init__(
self,
t_center: tuple=("uniform", 0.0, 1.0),
t_width: tuple=("const", 0.025),
name: str=None,
prob: float=1, ):
super().__init__(name=name, prob=prob)
self.t_center = t_center
self.t_width = t_width
# tmin_s = signal.signal_duration * tmin
# tmax_s = signal.signal_duration * tmax
# return {"tmin_s": tmin_s, "tmax_s": tmax_s}
def _instantiate(self, state: RandomState, signal: AudioSignal):
t_center = util.sample_from_dist(self.t_center, state)
t_width = util.sample_from_dist(self.t_width, state)
tmin = max(t_center - (t_width / 2), 0.0)
tmax = min(t_center + (t_width / 2), 1.0)
tmin_s = signal.signal_duration * tmin
tmax_s = signal.signal_duration * tmax
return {"tmin_s": tmin_s, "tmax_s": tmax_s}
def _transform(self, signal, tmin_s: float, tmax_s: float):
return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
# def _transform(self, signal, tmin_s: float, tmax_s: float):
# return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
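A hedged usage sketch of the two transforms, assuming an `AudioSignal` named `signal` as in the tests elsewhere in this diff:

from audiotools.data import transforms as tfm
transform = tfm.Compose([
    tfm.FrequencyMask(f_center=("uniform", 0.0, 1.0), f_width=("const", 0.1)),
    tfm.TimeMask(t_center=("uniform", 0.0, 1.0), t_width=("const", 0.025)),
])
kwargs = transform.instantiate(0, signal)   # samples the centers and widths
augmented = transform(signal.clone(), **kwargs)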
# class MaskLowMagnitudes(SpectralTransform):
# """Masks low magnitude regions out of signal.
@ -1387,55 +1388,55 @@ class HighPass(BaseTransform):
# def _transform(self, signal, db_cutoff: float):
# return signal.mask_low_magnitudes(db_cutoff)
# class Smoothing(BaseTransform):
# """Convolves the signal with a smoothing window.
# Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
class Smoothing(BaseTransform):
"""Convolves the signal with a smoothing window.
# Parameters
# ----------
# window_type : tuple, optional
# Type of window to use, by default ("const", "average")
# window_length : tuple, optional
# Length of smoothing window, by
# default ("choice", [8, 16, 32, 64, 128, 256, 512])
# name : str, optional
# Name of this transform, used to identify it in the dictionary
# produced by ``self.instantiate``, by default None
# prob : float, optional
# Probability of applying this transform, by default 1.0
# """
Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
# def __init__(
# self,
# window_type: tuple = ("const", "average"),
# window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
# name: str = None,
# prob: float = 1,
# ):
# super().__init__(name=name, prob=prob)
# self.window_type = window_type
# self.window_length = window_length
Parameters
----------
window_type : tuple, optional
Type of window to use, by default ("const", "average")
window_length : tuple, optional
Length of smoothing window, by
default ("choice", [8, 16, 32, 64, 128, 256, 512])
name : str, optional
Name of this transform, used to identify it in the dictionary
produced by ``self.instantiate``, by default None
prob : float, optional
Probability of applying this transform, by default 1.0
"""
# def _instantiate(self, state: RandomState, signal: AudioSignal = None):
# window_type = util.sample_from_dist(self.window_type, state)
# window_length = util.sample_from_dist(self.window_length, state)
# window = signal.get_window(
# window_type=window_type, window_length=window_length, device="cpu"
# )
# return {"window": AudioSignal(window, signal.sample_rate)}
def __init__(
self,
window_type: tuple=("const", "average"),
window_length: tuple=("choice", [8, 16, 32, 64, 128, 256, 512]),
name: str=None,
prob: float=1, ):
super().__init__(name=name, prob=prob)
self.window_type = window_type
self.window_length = window_length
def _instantiate(self, state: RandomState, signal: AudioSignal=None):
window_type = util.sample_from_dist(self.window_type, state)
window_length = util.sample_from_dist(self.window_length, state)
window = signal.get_window(
window_type=window_type, window_length=window_length, device="cpu")
return {"window": AudioSignal(window, signal.sample_rate)}
# def _transform(self, signal, window):
# sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
# sscale[sscale == 0.0] = 1.0
def _transform(self, signal, window):
sscale = signal.audio_data.abs().max(axis=-1, keepdim=True)
sscale[sscale == 0.0] = 1.0
# out = signal.convolve(window)
out = signal.convolve(window)
# oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
# oscale[oscale == 0.0] = 1.0
oscale = out.audio_data.abs().max(axis=-1, keepdim=True)
oscale[oscale == 0.0] = 1.0
out = out * (sscale / oscale)
return out
# out = out * (sscale / oscale)
# return out
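The dropped `.values` is deliberate: Paddle's `Tensor.max(axis=..., keepdim=True)` returns the values tensor directly, unlike torch's named tuple. A minimal sketch:

import paddle
x = paddle.to_tensor([[1.0, -3.0, 2.0]])
peak = x.abs().max(axis=-1, keepdim=True)   # shape [1, 1], the Tensor itself
# torch spelling of the same thing: x.abs().max(dim=-1, keepdim=True).values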
# class TimeNoise(TimeMask):
# """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
@ -1478,45 +1479,51 @@ class HighPass(BaseTransform):
# signal.phase = phase
# return signal
# class FrequencyNoise(FrequencyMask):
# """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
# replaces with noise instead of zeros.
# Parameters
# ----------
# f_center : tuple, optional
# Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
# f_width : tuple, optional
# Width of zero'd out band, by default ("const", 0.1)
# name : str, optional
# Name of this transform, used to identify it in the dictionary
# produced by ``self.instantiate``, by default None
# prob : float, optional
# Probability of applying this transform, by default 1.0
# """
class FrequencyNoise(FrequencyMask):
"""Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
replaces with noise instead of zeros.
# def __init__(
# self,
# f_center: tuple = ("uniform", 0.0, 1.0),
# f_width: tuple = ("const", 0.1),
# name: str = None,
# prob: float = 1,
# ):
# super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
Parameters
----------
f_center : tuple, optional
Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
f_width : tuple, optional
Width of zero'd out band, by default ("const", 0.1)
name : str, optional
Name of this transform, used to identify it in the dictionary
produced by ``self.instantiate``, by default None
prob : float, optional
Probability of applying this transform, by default 1.0
"""
# def _transform(self, signal, fmin_hz: float, fmax_hz: float):
# signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
# mag, phase = signal.magnitude, signal.phase
def __init__(
self,
f_center: tuple=("uniform", 0.0, 1.0),
f_width: tuple=("const", 0.1),
name: str=None,
prob: float=1, ):
super().__init__(
f_center=f_center, f_width=f_width, name=name, prob=prob)
# mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
# mask = (mag == 0.0) * (phase == 0.0)
def _transform(self, signal, fmin_hz: float, fmax_hz: float):
signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
mag, phase = signal.magnitude, signal.phase
# mag[mask] = mag_r[mask]
# phase[mask] = phase_r[mask]
mag_r, phase_r = paddle.randn(
shape=mag.shape, dtype=mag.dtype), paddle.randn(
shape=phase.shape, dtype=phase.dtype)
mask = (mag == 0.0) * (phase == 0.0)
# mag[mask] = mag_r[mask]
# phase[mask] = phase_r[mask]
mag = paddle.where(mask, mag_r, mag)
phase = paddle.where(mask, phase_r, phase)
signal.magnitude = mag
signal.phase = phase
return signal
# signal.magnitude = mag
# signal.phase = phase
# return signal
# class SpectralDenoising(Equalizer):
# """Applies denoising algorithm detailed in

@ -0,0 +1,139 @@
import tempfile
import typing
import zipfile
from pathlib import Path
import markdown2 as md
import matplotlib.pyplot as plt
import paddle
from audiotools import AudioSignal
from IPython.display import HTML
def audio_table(
audio_dict: dict,
first_column: str=None,
format_fn: typing.Callable=None,
**kwargs, ): # pragma: no cover
"""Embeds an audio table into HTML, or as the output cell
in a notebook.
Parameters
----------
audio_dict : dict
Dictionary of data to embed.
first_column : str, optional
The label for the first column of the table, by default None
format_fn : typing.Callable, optional
How to format the data, by default None
Returns
-------
str
Table as a string
Examples
--------
>>> audio_dict = {}
>>> for i in range(signal_batch.batch_size):
>>> audio_dict[i] = {
>>> "input": signal_batch[i],
>>> "output": output_batch[i]
>>> }
>>> audiotools.post.audio_zip(audio_dict)
"""
from audiotools import AudioSignal
output = []
columns = None
def _default_format_fn(label, x, **kwargs):
if paddle.is_tensor(x):
x = x.tolist()
if x is None:
return "."
elif isinstance(x, AudioSignal):
return x.embed(display=False, return_html=True, **kwargs)
else:
return str(x)
if format_fn is None:
format_fn = _default_format_fn
if first_column is None:
first_column = "."
for k, v in audio_dict.items():
if not isinstance(v, dict):
v = {"Audio": v}
v_keys = list(v.keys())
if columns is None:
columns = [first_column] + v_keys
output.append(" | ".join(columns))
layout = "|---" + len(v_keys) * "|:-:"
output.append(layout)
formatted_audio = []
for col in columns[1:]:
formatted_audio.append(format_fn(col, v[col], **kwargs))
row = f"| {k} | "
row += " | ".join(formatted_audio)
output.append(row)
output = "\n" + "\n".join(output)
return output
def in_notebook(): # pragma: no cover
"""Determines if code is running in a notebook.
Returns
-------
bool
Whether or not this is running in a notebook.
"""
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
def disp(obj, **kwargs): # pragma: no cover
"""Displays an object, depending on if its in a notebook
or not.
Parameters
----------
obj : typing.Any
Any object to display.
"""
IN_NOTEBOOK = in_notebook()
if isinstance(obj, AudioSignal):
audio_elem = obj.embed(display=False, return_html=True)
if IN_NOTEBOOK:
return HTML(audio_elem)
else:
print(audio_elem)
if isinstance(obj, dict):
table = audio_table(obj, **kwargs)
if IN_NOTEBOOK:
return HTML(md.markdown(table, extras=["tables"]))
else:
print(table)
if isinstance(obj, plt.Figure):
plt.show()

@ -0,0 +1,11 @@
flatten_dict
gradio
IPython
librosa
markdown2
pyloudnorm
pytest
pytest-xdist
rich
scipy
soundfile

@ -7,13 +7,13 @@ import numpy as np
import paddle
import pytest
import rich
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
import audiotools
from audiotools import AudioSignal
def test_io():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(pathlib.Path(audio_path))
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
@ -61,7 +61,7 @@ def test_io():
assert signal.audio_data.ndim == 3
assert paddle.all(signal.samples == signal.audio_data)
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
assert AudioSignal(audio_path).hash() == AudioSignal(audio_path).hash()
assert AudioSignal(audio_path).hash() != AudioSignal(audio_path).normalize(
-20).hash()
@ -71,7 +71,7 @@ def test_io():
def test_copy_and_clone():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.stft()
signal.loudness()
@ -369,7 +369,7 @@ def test_trim():
def test_to_from_ops():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.stft()
signal.loudness()
@ -384,16 +384,12 @@ def test_to_from_ops():
def test_device():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.to("cpu")
assert str(signal.device) == "Place(cpu)"
signal.stft()
signal.audio_data = None
assert str(signal.device) == "Place(cpu)"
@pytest.mark.parametrize("window_length", [2048, 512])
@pytest.mark.parametrize("hop_length", [512, 128])
@ -401,7 +397,7 @@ def test_device():
def test_stft(window_length, hop_length, window_type):
if hop_length >= window_length:
hop_length = window_length // 2
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
stft_params = audiotools.STFTParams(
window_length=window_length,
hop_length=hop_length,
@ -460,7 +456,7 @@ def test_stft(window_length, hop_length, window_type):
def test_log_magnitude():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
for _ in range(10):
signal = AudioSignal.excerpt(audio_path, duration=5.0)
magnitude = signal.magnitude.numpy()[0, 0]
@ -478,7 +474,7 @@ def test_log_magnitude():
def test_mel_spectrogram(n_mels, window_length, hop_length, window_type):
if hop_length >= window_length:
hop_length = window_length // 2
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
stft_params = audiotools.STFTParams(
window_length=window_length,
hop_length=hop_length,
@ -496,7 +492,7 @@ def test_mel_spectrogram(n_mels, window_length, hop_length, window_type):
def test_mfcc(n_mfcc, n_mels, window_length, hop_length):
if hop_length >= window_length:
hop_length = window_length // 2
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
stft_params = audiotools.STFTParams(
window_length=window_length, hop_length=hop_length)
for _stft_params in [None, stft_params]:

@ -5,7 +5,7 @@ import sys
import unittest
import paddle
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools.core import pure_tone, SplitBands, split_bands

@ -6,7 +6,7 @@ import unittest
import paddle
import paddle.nn.functional as F
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools.core import fft_conv1d, FFTConv1d
TOLERANCE = 1e-4 # as relative delta in percentage

@ -6,7 +6,7 @@ import sys
import unittest
import paddle
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools.core import highpass_filter, highpass_filters

@ -3,7 +3,7 @@ import sys
import numpy as np
import pyloudnorm
import soundfile as sf
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import AudioSignal
from audiotools import datasets
from audiotools import Meter
@ -13,7 +13,7 @@ ATOL = 1e-1
def test_loudness_against_pyln():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=5, duration=10)
signal_loudness = signal.loudness()
@ -24,7 +24,7 @@ def test_loudness_against_pyln():
def test_loudness_short():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=0.25)
signal_loudness = signal.loudness()
@ -58,7 +58,7 @@ def test_batch_loudness():
# Tests below are copied from pyloudnorm
def test_integrated_loudness():
data, rate = sf.read("tests/audio/loudness/sine_1000.wav")
data, rate = sf.read("tests/audiotools/audio/loudness/sine_1000.wav")
meter = Meter(rate)
loudness = meter(data)
@ -67,7 +67,8 @@ def test_integrated_loudness():
def test_rel_gate_test():
data, rate = sf.read("tests/audio/loudness/1770-2_Comp_RelGateTest.wav")
data, rate = sf.read(
"tests/audiotools/audio/loudness/1770-2_Comp_RelGateTest.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -76,7 +77,8 @@ def test_rel_gate_test():
def test_abs_gate_test():
data, rate = sf.read("tests/audio/loudness/1770-2_Comp_AbsGateTest.wav")
data, rate = sf.read(
"tests/audiotools/audio/loudness/1770-2_Comp_AbsGateTest.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -85,7 +87,8 @@ def test_abs_gate_test():
def test_24LKFS_25Hz_2ch():
data, rate = sf.read("tests/audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav")
data, rate = sf.read(
"tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -95,7 +98,7 @@ def test_24LKFS_25Hz_2ch():
def test_24LKFS_100Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -105,7 +108,7 @@ def test_24LKFS_100Hz_2ch():
def test_24LKFS_500Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -115,7 +118,7 @@ def test_24LKFS_500Hz_2ch():
def test_24LKFS_1000Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -125,7 +128,7 @@ def test_24LKFS_1000Hz_2ch():
def test_24LKFS_2000Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -135,7 +138,7 @@ def test_24LKFS_2000Hz_2ch():
def test_24LKFS_10000Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -144,7 +147,8 @@ def test_24LKFS_10000Hz_2ch():
def test_23LKFS_25Hz_2ch():
data, rate = sf.read("tests/audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav")
data, rate = sf.read(
"tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -154,7 +158,7 @@ def test_23LKFS_25Hz_2ch():
def test_23LKFS_100Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -164,7 +168,7 @@ def test_23LKFS_100Hz_2ch():
def test_23LKFS_500Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -174,7 +178,7 @@ def test_23LKFS_500Hz_2ch():
def test_23LKFS_1000Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -184,7 +188,7 @@ def test_23LKFS_1000Hz_2ch():
def test_23LKFS_2000Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -194,7 +198,7 @@ def test_23LKFS_2000Hz_2ch():
def test_23LKFS_10000Hz_2ch():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -204,7 +208,7 @@ def test_23LKFS_10000Hz_2ch():
def test_18LKFS_frequency_sweep():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav")
"tests/audiotools/audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -214,7 +218,7 @@ def test_18LKFS_frequency_sweep():
def test_conf_stereo_vinL_R_23LKFS():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav")
"tests/audiotools/audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -224,7 +228,8 @@ def test_conf_stereo_vinL_R_23LKFS():
def test_conf_monovoice_music_24LKFS():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav")
"tests/audiotools/audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav"
)
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -234,7 +239,8 @@ def test_conf_monovoice_music_24LKFS():
def conf_monovoice_music_24LKFS():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav")
"tests/audiotools/audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav"
)
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -244,7 +250,8 @@ def conf_monovoice_music_24LKFS():
def test_conf_monovoice_music_23LKFS():
data, rate = sf.read(
"tests/audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav")
"tests/audiotools/audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav"
)
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
@ -259,7 +266,7 @@ def test_fir_accuracy():
transforms.HighPass(prob=0.5),
transforms.Equalizer(prob=0.5),
prob=0.5, )
loader = datasets.AudioLoader(sources=["tests/audio/spk.csv"])
loader = datasets.AudioLoader(sources=["tests/audiotools/audio/spk.csv"])
dataset = datasets.AudioDataset(
loader,
44100,
@ -278,6 +285,3 @@ def test_fir_accuracy():
fir_db = signal.clone().loudness(use_fir=True)
assert np.allclose(iir_db, fir_db, atol=1e-2)
test_fir_accuracy()

@ -7,7 +7,7 @@ import unittest
import numpy as np
import paddle
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools.core import LowPassFilter, LowPassFilters, lowpass_filter, resample_frac

@ -7,7 +7,7 @@ import numpy as np
import paddle
import pytest
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import util
from audiotools.core.audio_signal import AudioSignal
@ -66,7 +66,8 @@ def test_find_audio():
assert not audio_files
# Make sure it works with single audio files
audio_files = util.find_audio("tests/audio/spk//f10_script4_produced.wav")
audio_files = util.find_audio(
"tests/audiotools/audio/spk//f10_script4_produced.wav")
# Make sure it works with globs
audio_files = util.find_audio("tests/**/*.wav")

@ -5,7 +5,7 @@ from pathlib import Path
import numpy as np
import pytest
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
import paddle
import audiotools
from audiotools.data import transforms as tfm
@ -45,7 +45,7 @@ def test_audio_dataset():
tfm.Silence(prob=0.5),
], )
loader = audiotools.data.datasets.AudioLoader(
sources=["tests/audio/spk.csv"],
sources=["tests/audiotools/audio/spk.csv"],
transform=transform, )
dataset = audiotools.data.datasets.AudioDataset(
loader,
@ -161,11 +161,11 @@ def test_loader_out_of_range():
def test_dataset_pipeline():
transform = tfm.Compose([
tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
tfm.RoomImpulseResponse(sources=["tests/audiotools/audio/irs.csv"]),
tfm.BackgroundNoise(sources=["tests/audiotools/audio/noises.csv"]),
])
loader = audiotools.data.datasets.AudioLoader(
sources=["tests/audio/spk.csv"])
sources=["tests/audiotools/audio/spk.csv"])
dataset = audiotools.data.datasets.AudioDataset(
loader,
44100,

@ -3,7 +3,7 @@ import tempfile
from pathlib import Path
import paddle
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools.core.util import find_audio
from audiotools.core.util import read_sources
from audiotools.data import preprocess
@ -12,11 +12,13 @@ from audiotools.data import preprocess
def test_create_csv():
with tempfile.NamedTemporaryFile(suffix=".csv") as f:
preprocess.create_csv(
find_audio("./tests/audio/spk", ext=["wav"]), f.name, loudness=True)
find_audio("./tests/audiotools/audio/spk", ext=["wav"]),
f.name,
loudness=True)
def test_create_csv_with_empty_rows():
audio_files = find_audio("./tests/audio/spk", ext=["wav"])
audio_files = find_audio("./tests/audiotools/audio/spk", ext=["wav"])
audio_files.insert(0, "")
audio_files.insert(2, "")

@ -7,7 +7,7 @@ import numpy as np
import paddle
import pytest
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
import audiotools
from audiotools import AudioSignal
from audiotools import util
@ -49,13 +49,13 @@ def test_transform(transform_name):
kwargs = {}
if transform_name == "BackgroundNoise":
kwargs["sources"] = ["tests/audio/noises.csv"]
kwargs["sources"] = ["tests/audiotools/audio/noises.csv"]
if transform_name == "RoomImpulseResponse":
kwargs["sources"] = ["tests/audio/irs.csv"]
kwargs["sources"] = ["tests/audiotools/audio/irs.csv"]
if transform_name == "CrossTalk":
kwargs["sources"] = ["tests/audio/spk.csv"]
kwargs["sources"] = ["tests/audiotools/audio/spk.csv"]
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
signal.metadata["loudness"] = AudioSignal(
audio_path).ffmpeg_loudness().item()
@ -99,18 +99,15 @@ def test_transform(transform_name):
assert output_a == output_b
# test_transform("FrequencyNoise")
def test_compose_basic():
seed = 0
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
transform = tfm.Compose(
[
tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
tfm.RoomImpulseResponse(sources=["tests/audiotools/audio/irs.csv"]),
tfm.BackgroundNoise(sources=["tests/audiotools/audio/noises.csv"]),
], )
kwargs = transform.instantiate(seed, signal)
@ -146,7 +143,7 @@ def test_compose_with_duplicate_transforms():
full_mul = np.prod(muls)
kwargs = transform.instantiate(0)
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
output = transform(signal.clone(), **kwargs)
@ -165,7 +162,7 @@ def test_nested_compose():
full_mul = np.prod(muls)
kwargs = transform.instantiate(0)
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
output = transform(signal.clone(), **kwargs)
@ -179,7 +176,7 @@ def test_compose_filtering():
transform = tfm.Compose([MulTransform(x, name=str(x)) for x in muls])
kwargs = transform.instantiate(0)
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
for s in range(len(muls)):
@ -202,7 +199,7 @@ def test_sequential_compose():
full_mul = np.prod(muls)
kwargs = transform.instantiate(0)
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
output = transform(signal.clone(), **kwargs)
@ -213,11 +210,11 @@ def test_sequential_compose():
def test_choose_basic():
seed = 0
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
transform = tfm.Choose([
tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
tfm.RoomImpulseResponse(sources=["tests/audiotools/audio/irs.csv"]),
tfm.BackgroundNoise(sources=["tests/audiotools/audio/noises.csv"]),
])
kwargs = transform.instantiate(seed, signal)
@ -254,7 +251,7 @@ def test_choose_basic():
def test_choose_weighted():
seed = 0
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
transform = tfm.Choose(
[
MulTransform(0.0),
@ -280,7 +277,7 @@ def test_choose_weighted():
def test_choose_with_compose():
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
transform = tfm.Choose([
@ -299,7 +296,7 @@ def test_choose_with_compose():
def test_repeat():
seed = 0
audio_path = "tests/audio/spk/f10_script4_produced.wav"
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
kwargs = {}
@ -359,7 +356,7 @@ class DummyData(paddle.io.Dataset):
def test_masking():
dataset = DummyData("tests/audio/spk/f10_script4_produced.wav")
dataset = DummyData("tests/audiotools/audio/spk/f10_script4_produced.wav")
dataloader = paddle.io.DataLoader(
dataset,
batch_size=16,
@ -389,7 +386,7 @@ def test_nested_masking():
prob=0.9, )
loader = audiotools.data.datasets.AudioLoader(
sources=["tests/audio/spk.csv"])
sources=["tests/audiotools/audio/spk.csv"])
dataset = audiotools.data.datasets.AudioDataset(
loader,
44100,

@ -1,6 +1,6 @@
import sys
import time
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
import paddle
from visualdl import LogWriter

@ -3,7 +3,7 @@ import tempfile
import paddle
from paddle import nn
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import ml
from audiotools import util
@ -41,7 +41,7 @@ def test_base_model():
x = paddle.randn([10, 1])
model1 = Model()
assert str(model1.device) == 'Place(cpu)'
# assert str(model1.device) == 'Place(cpu)'
out1 = seed_and_run(model1, x)

@ -1,7 +1,7 @@
import sys
from pathlib import Path
sys.path.append("/home/work/pdaudoio")
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import AudioSignal
from audiotools import post
from audiotools import transforms
@ -14,7 +14,7 @@ def test_audio_table():
audio_dict["inputs"] = [
AudioSignal.excerpt(
"tests/audio/spk/f10_script4_produced.wav", duration=5)
"tests/audiotools/audio/spk/f10_script4_produced.wav", duration=5)
for _ in range(3)
]
audio_dict["outputs"] = []