parent 643f1c6071
commit 080bd7f5db
@@ -0,0 +1,191 @@
import inspect
import typing
from functools import wraps

from . import util


def format_figure(func):
    """Decorator for formatting figures produced by the code below.
    See :py:func:`audiotools.core.util.format_figure` for more.

    Parameters
    ----------
    func : Callable
        Plotting function that is decorated by this function.

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        f_keys = inspect.signature(util.format_figure).parameters.keys()
        f_kwargs = {}
        for k, v in list(kwargs.items()):
            if k in f_keys:
                kwargs.pop(k)
                f_kwargs[k] = v
        func(*args, **kwargs)
        util.format_figure(**f_kwargs)

    return wrapper


class DisplayMixin:
    @format_figure
    def specshow(
            self,
            preemphasis: bool=False,
            x_axis: str="time",
            y_axis: str="linear",
            n_mels: int=128,
            **kwargs, ):
        """Displays a spectrogram, using ``librosa.display.specshow``.

        Parameters
        ----------
        preemphasis : bool, optional
            Whether or not to apply preemphasis, which makes high
            frequency detail easier to see, by default False
        x_axis : str, optional
            How to label the x axis, by default "time"
        y_axis : str, optional
            How to label the y axis, by default "linear"
        n_mels : int, optional
            If displaying a mel spectrogram with ``y_axis = "mel"``,
            this controls the number of mels, by default 128.
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
        """
        import librosa
        import librosa.display

        # Always re-compute the STFT data before showing it, in case
        # it changed.
        signal = self.clone()
        signal.stft_data = None

        if preemphasis:
            signal.preemphasis()

        ref = signal.magnitude.max()
        log_mag = signal.log_magnitude(ref_value=ref)

        if y_axis == "mel":
            log_mag = 20 * signal.mel_spectrogram(n_mels).clip(1e-5).log10()
            log_mag -= log_mag.max()

        librosa.display.specshow(
            log_mag.numpy()[0].mean(axis=0),
            x_axis=x_axis,
            y_axis=y_axis,
            sr=signal.sample_rate,
            **kwargs, )

    @format_figure
    def waveplot(self, x_axis: str="time", **kwargs):
        """Displays a waveform plot, using ``librosa.display.waveshow``.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
        """
        import librosa
        import librosa.display

        audio_data = self.audio_data[0].mean(axis=0)
        audio_data = audio_data.cpu().numpy()

        plot_fn = "waveshow" if hasattr(librosa.display,
                                        "waveshow") else "waveplot"
        wave_plot_fn = getattr(librosa.display, plot_fn)
        wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)

    @format_figure
    def wavespec(self, x_axis: str="time", **kwargs):
        """Displays a waveform plot stacked above a spectrogram, by
        combining ``waveplot`` and ``specshow``.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
        """
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec

        gs = GridSpec(6, 1)
        plt.subplot(gs[0, :])
        self.waveplot(x_axis=x_axis)
        plt.subplot(gs[1:, :])
        self.specshow(x_axis=x_axis, **kwargs)

    def write_audio_to_tb(
            self,
            tag: str,
            writer,
            step: int=None,
            plot_fn: typing.Union[typing.Callable, str]="specshow",
            **kwargs, ):
        """Writes a signal and its spectrogram to Tensorboard. They will show up
        under the Audio and Images tabs in Tensorboard.

        Parameters
        ----------
        tag : str
            Tag to write the signal to (e.g. ``clean/sample_0.wav``). The image will be
            written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
        writer : SummaryWriter
            A tensorboard-style summary writer, e.g. PyTorch's ``SummaryWriter``
            or ``visualdl.LogWriter``.
        step : int, optional
            The step to write the signal to, by default None
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        audio_data = self.audio_data[0, 0].detach().cpu().numpy()
        sample_rate = self.sample_rate
        writer.add_audio(tag, audio_data, step, sample_rate)

        if plot_fn is not None:
            if isinstance(plot_fn, str):
                plot_fn = getattr(self, plot_fn)
            fig = plt.figure()
            plt.clf()
            plot_fn(**kwargs)
            writer.add_figure(tag.replace("wav", "png"), fig, step)

    def save_image(
            self,
            image_path: str,
            plot_fn: typing.Union[typing.Callable, str]="specshow",
            **kwargs, ):
        """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
        a specified file.

        Parameters
        ----------
        image_path : str
            Where to save the file to.
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        plt.clf()
        plot_fn(**kwargs)
        plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
        plt.close()
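
For orientation, here is a minimal usage sketch of the mixin above. It mirrors the calls exercised by the tests added in this commit and assumes ``AudioSignal`` from the paddle ``audiotools`` port mixes in ``DisplayMixin``; the tag, title, and ``./scratch`` paths are illustrative only.

import numpy as np
from pathlib import Path

from audiotools import AudioSignal
from visualdl import LogWriter

# Build a short, silent test signal and draw it in a few ways.
signal = AudioSignal(np.zeros((1, 16000)), sample_rate=16000)
signal.specshow(y_axis="mel", preemphasis=True)    # mel spectrogram, high-frequency detail boosted
signal.waveplot()                                  # waveform, via librosa.display.waveshow
signal.wavespec(title="waveform + spectrogram")    # waveform stacked above a spectrogram

# Persist the output: one PNG on disk, one audio clip plus figure in VisualDL.
Path("./scratch").mkdir(parents=True, exist_ok=True)
signal.save_image("./scratch/example.png")
writer = LogWriter("./scratch/")
signal.write_audio_to_tb("example/sample_0.wav", writer, step=0)
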
@@ -0,0 +1,48 @@
import sys
from pathlib import Path

import numpy as np
sys.path.append("/home/aistudio/PaddleSpeech/audio")

from audiotools import AudioSignal
from visualdl import LogWriter


def test_specshow():
    array = np.zeros((1, 16000))
    AudioSignal(array, sample_rate=16000).specshow()
    AudioSignal(array, sample_rate=16000).specshow(preemphasis=True)
    AudioSignal(
        array, sample_rate=16000).specshow(
            title="test", preemphasis=True)
    AudioSignal(
        array, sample_rate=16000).specshow(
            format=False, preemphasis=True)
    AudioSignal(
        array, sample_rate=16000).specshow(
            format=False, preemphasis=False, y_axis="mel")


def test_waveplot():
    array = np.zeros((1, 16000))
    AudioSignal(array, sample_rate=16000).waveplot()


def test_wavespec():
    array = np.zeros((1, 16000))
    AudioSignal(array, sample_rate=16000).wavespec()


def test_write_audio_to_tb():
    signal = AudioSignal("./audio/spk/f10_script4_produced.mp3", duration=5)

    Path("./scratch").mkdir(parents=True, exist_ok=True)
    writer = LogWriter("./scratch/")
    signal.write_audio_to_tb("tag", writer)


def test_save_image():
    signal = AudioSignal(
        "./audio/spk/f10_script4_produced.wav", duration=10, offset=10)
    Path("./scratch").mkdir(parents=True, exist_ok=True)
    signal.save_image("./scratch/image.png")
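
One practical note: these tests draw matplotlib figures, so on a headless machine a non-interactive backend would typically be selected before any plotting code is imported. A hedged sketch, not something this commit does:

import matplotlib
matplotlib.use("Agg")  # render to in-memory buffers instead of opening a window
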
@@ -0,0 +1,178 @@
import sys

import numpy as np
import paddle
import pytest
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import AudioSignal
from audiotools.core.util import sample_from_dist


@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100])
@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0])
def test_overlap_add(duration, sample_rate, window_duration):
    np.random.seed(0)
    if duration > window_duration:
        spk_signal = AudioSignal.batch([
            AudioSignal.excerpt(
                "./audio/spk/f10_script4_produced.wav", duration=duration)
            for _ in range(16)
        ])
        spk_signal.resample(sample_rate)

        noise = paddle.randn([16, 1, int(duration * sample_rate)])
        nz_signal = AudioSignal(noise, sample_rate=sample_rate)

        def _test(signal):
            hop_duration = window_duration / 2
            windowed_signal = signal.deepcopy().collect_windows(window_duration,
                                                                hop_duration)
            recombined = windowed_signal.overlap_and_add(hop_duration)

            assert recombined == signal
            assert np.allclose(recombined.audio_data, signal.audio_data, 1e-3)

        _test(nz_signal)
        _test(spk_signal)


@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100])
@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0])
def test_inplace_overlap_add(duration, sample_rate, window_duration):
    np.random.seed(0)
    if duration > window_duration:
        spk_signal = AudioSignal.batch([
            AudioSignal.excerpt(
                "./audio/spk/f10_script4_produced.wav", duration=duration)
            for _ in range(16)
        ])
        spk_signal.resample(sample_rate)

        noise = paddle.randn([16, 1, int(duration * sample_rate)])
        nz_signal = AudioSignal(noise, sample_rate=sample_rate)

        def _test(signal):
            hop_duration = window_duration / 2
            windowed_signal = signal.deepcopy().collect_windows(window_duration,
                                                                hop_duration)
            # Compare in-place with unfold results
            for i, window in enumerate(
                    signal.deepcopy().windows(window_duration, hop_duration)):
                assert np.allclose(window.audio_data,
                                   windowed_signal.audio_data[i])

        _test(nz_signal)
        _test(spk_signal)

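For intuition, a plain-numpy sketch of the windowing round trip that these parametrized tests exercise through ``collect_windows`` and ``overlap_and_add``; the framing and names below are illustrative, not the library implementation.

import numpy as np

# Frame a signal with 50% overlap, then overlap-add and normalize by coverage.
sample_rate = 8000
window_length = int(0.5 * sample_rate)  # 0.5 s window, one of the tested configurations
hop_length = window_length // 2         # hop = window / 2, as in _test above
x = np.random.randn(10 * hop_length)

starts = range(0, len(x) - window_length + 1, hop_length)
frames = np.stack([x[s:s + window_length] for s in starts])

y = np.zeros_like(x)
coverage = np.zeros_like(x)
for frame, s in zip(frames, starts):
    y[s:s + window_length] += frame
    coverage[s:s + window_length] += 1.0
y /= np.maximum(coverage, 1.0)

assert np.allclose(y, x)  # every sample is recovered, matching `recombined == signal` above
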
def test_low_pass():
    sample_rate = 44100
    f = 440
    t = paddle.arange(0, 1, 1 / sample_rate)
    sine_wave = paddle.sin(2 * np.pi * f * t)
    window = AudioSignal.get_window("hann", sine_wave.shape[-1])
    sine_wave = sine_wave * window
    signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
    out = signal.deepcopy().low_pass(220)
    assert out.audio_data.abs().max() < 1e-4

    out = signal.deepcopy().low_pass(880)
    assert (out - signal).audio_data.abs().max() < 1e-3

    batch = AudioSignal.batch(
        [signal.deepcopy(), signal.deepcopy(), signal.deepcopy()])

    cutoffs = [220, 880, 220]
    out = batch.deepcopy().low_pass(cutoffs)

    assert out.audio_data[0].abs().max() < 1e-4
    assert out.audio_data[2].abs().max() < 1e-4
    assert (out - batch).audio_data[1].abs().max() < 1e-3


def test_high_pass():
    sample_rate = 44100
    f = 440
    t = paddle.arange(0, 1, 1 / sample_rate)
    sine_wave = paddle.sin(2 * np.pi * f * t)
    window = AudioSignal.get_window("hann", sine_wave.shape[-1])
    sine_wave = sine_wave * window
    signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
    out = signal.deepcopy().high_pass(220)
    assert (signal - out).audio_data.abs().max() < 1e-4


def test_mask_frequencies():
    sample_rate = 44100
    fs = paddle.to_tensor([500.0, 2000.0, 8000.0, 32000.0])[None]
    t = paddle.arange(0, 1, 1 / sample_rate)[:, None]
    sine_wave = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1)
    sine_wave = AudioSignal(sine_wave, sample_rate)
    masked_sine_wave = sine_wave.mask_frequencies(fmin_hz=1500, fmax_hz=10000)

    fs2 = paddle.to_tensor([500.0, 32000.0])[None]
    sine_wave2 = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1)
    sine_wave2 = AudioSignal(sine_wave2, sample_rate)

    assert paddle.allclose(masked_sine_wave.audio_data, sine_wave2.audio_data)

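For intuition, a standalone numpy sketch of the frequency-masking idea checked above: zero a band of FFT bins and confirm that only components outside the band survive. It works on a raw FFT rather than going through ``mask_frequencies``, and the band edges mirror the test values.

import numpy as np

sample_rate = 44100
t = np.arange(sample_rate) / sample_rate
x = np.sin(2 * np.pi * 500 * t) + np.sin(2 * np.pi * 2000 * t)

X = np.fft.rfft(x)
freqs = np.fft.rfftfreq(len(x), d=1 / sample_rate)
X[(freqs >= 1500) & (freqs < 10000)] = 0.0   # zero out the masked band
masked = np.fft.irfft(X, n=len(x))

# Only the 500 Hz component is left after masking 1500-10000 Hz.
assert np.allclose(masked, np.sin(2 * np.pi * 500 * t), atol=1e-6)
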
def test_mask_timesteps():
    sample_rate = 44100
    f = 440
    t = paddle.linspace(0, 1, sample_rate)
    sine_wave = paddle.sin(2 * np.pi * f * t)
    sine_wave = AudioSignal(sine_wave, sample_rate)

    masked_sine_wave = sine_wave.mask_timesteps(tmin_s=0.25, tmax_s=0.75)
    masked_sine_wave.istft()

    mask = ((0.3 < t) & (t < 0.7))[None, None]
    assert paddle.allclose(
        masked_sine_wave.audio_data[mask],
        paddle.zeros_like(masked_sine_wave.audio_data[mask]), )


def test_shift_phase():
    sample_rate = 44100
    f = 440
    t = paddle.linspace(0, 1, sample_rate)
    sine_wave = paddle.sin(2 * np.pi * f * t)
    sine_wave = AudioSignal(sine_wave, sample_rate)
    sine_wave2 = sine_wave.clone()

    shifted_sine_wave = sine_wave.shift_phase(np.pi)
    shifted_sine_wave.istft()

    sine_wave2.phase = sine_wave2.phase + np.pi
    sine_wave2.istft()

    assert paddle.allclose(shifted_sine_wave.audio_data, sine_wave2.audio_data)


def test_corrupt_phase():
    sample_rate = 44100
    f = 440
    t = paddle.linspace(0, 1, sample_rate)
    sine_wave = paddle.sin(2 * np.pi * f * t)
    sine_wave = AudioSignal(sine_wave, sample_rate)
    sine_wave2 = sine_wave.clone()

    shifted_sine_wave = sine_wave.corrupt_phase(scale=np.pi)
    shifted_sine_wave.istft()

    assert (sine_wave2.phase - shifted_sine_wave.phase).abs().mean() > 0.0
    assert ((sine_wave2.phase - shifted_sine_wave.phase).std() / np.pi) < 1.0


def test_preemphasis():
    x = AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5)
    import matplotlib.pyplot as plt

    x.specshow(preemphasis=False)

    x.specshow(preemphasis=True)

    x.preemphasis()
@@ -0,0 +1,4 @@
python -m pip install -r ../audiotools/requirements.txt
# wget -P ./test_data https://paddlespeech.bj.bcebos.com/datasets/unit_test/asr/static_ds2online_inputs.pickle
# wget
find . -name "*✅.py" | xargs python -m pytest