diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 85fbec1c0..be1f284c4 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -17,14 +17,13 @@ import soundfile from . import util from ._julius import resample_frac +from .display import DisplayMixin from .dsp import DSPMixin from .effects import EffectMixin from .effects import ImpulseResponseMixin from .ffmpeg import FFMPEGMixin from .loudness import LoudnessMixin -# from .display import DisplayMixin - # from .playback import PlayMixin # from .whisper import WhisperMixin @@ -98,7 +97,7 @@ class AudioSignal( # PlayMixin, ImpulseResponseMixin, DSPMixin, - # DisplayMixin, + DisplayMixin, FFMPEGMixin, # WhisperMixin, ): @@ -1498,6 +1497,8 @@ class AudioSignal( amin = amin**2 log_spec = 10.0 * paddle.log10(magnitude.pow(2).clip(min=amin)) + if paddle.is_tensor(ref_value): + ref_value = ref_value.item() log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) if top_db is not None: diff --git a/audio/audiotools/core/display.py b/audio/audiotools/core/display.py new file mode 100644 index 000000000..f755183bc --- /dev/null +++ b/audio/audiotools/core/display.py @@ -0,0 +1,191 @@ +import inspect +import typing +from functools import wraps + +from . import util + + +def format_figure(func): + """Decorator for formatting figures produced by the code below. + See :py:func:`audiotools.core.util.format_figure` for more. + + Parameters + ---------- + func : Callable + Plotting function that is decorated by this function. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + f_keys = inspect.signature(util.format_figure).parameters.keys() + f_kwargs = {} + for k, v in list(kwargs.items()): + if k in f_keys: + kwargs.pop(k) + f_kwargs[k] = v + func(*args, **kwargs) + util.format_figure(**f_kwargs) + + return wrapper + + +class DisplayMixin: + @format_figure + def specshow( + self, + preemphasis: bool=False, + x_axis: str="time", + y_axis: str="linear", + n_mels: int=128, + **kwargs, ): + """Displays a spectrogram, using ``librosa.display.specshow``. + + Parameters + ---------- + preemphasis : bool, optional + Whether or not to apply preemphasis, which makes high + frequency detail easier to see, by default False + x_axis : str, optional + How to label the x axis, by default "time" + y_axis : str, optional + How to label the y axis, by default "linear" + n_mels : int, optional + If displaying a mel spectrogram with ``y_axis = "mel"``, + this controls the number of mels, by default 128. + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + # Always re-compute the STFT data before showing it, in case + # it changed. + signal = self.clone() + signal.stft_data = None + + if preemphasis: + signal.preemphasis() + + ref = signal.magnitude.max() + log_mag = signal.log_magnitude(ref_value=ref) + + if y_axis == "mel": + log_mag = 20 * signal.mel_spectrogram(n_mels).clip(1e-5).log10() + log_mag -= log_mag.max() + + librosa.display.specshow( + log_mag.numpy()[0].mean(axis=0), + x_axis=x_axis, + y_axis=y_axis, + sr=signal.sample_rate, + **kwargs, ) + + @format_figure + def waveplot(self, x_axis: str="time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + audio_data = self.audio_data[0].mean(axis=0) + audio_data = audio_data.cpu().numpy() + + plot_fn = "waveshow" if hasattr(librosa.display, + "waveshow") else "waveplot" + wave_plot_fn = getattr(librosa.display, plot_fn) + wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs) + + @format_figure + def wavespec(self, x_axis: str="time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`. + """ + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + gs = GridSpec(6, 1) + plt.subplot(gs[0, :]) + self.waveplot(x_axis=x_axis) + plt.subplot(gs[1:, :]) + self.specshow(x_axis=x_axis, **kwargs) + + def write_audio_to_tb( + self, + tag: str, + writer, + step: int=None, + plot_fn: typing.Union[typing.Callable, str]="specshow", + **kwargs, ): + """Writes a signal and its spectrogram to Tensorboard. Will show up + under the Audio and Images tab in Tensorboard. + + Parameters + ---------- + tag : str + Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be + written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``). + writer : SummaryWriter + A SummaryWriter object from PyTorch library. + step : int, optional + The step to write the signal to, by default None + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + audio_data = self.audio_data[0, 0].detach().cpu().numpy() + sample_rate = self.sample_rate + writer.add_audio(tag, audio_data, step, sample_rate) + + if plot_fn is not None: + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + fig = plt.figure() + plt.clf() + plot_fn(**kwargs) + writer.add_figure(tag.replace("wav", "png"), fig, step) + + def save_image( + self, + image_path: str, + plot_fn: typing.Union[typing.Callable, str]="specshow", + **kwargs, ): + """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to + a specified file. + + Parameters + ---------- + image_path : str + Where to save the file to. + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + + plt.clf() + plot_fn(**kwargs) + plt.savefig(image_path, bbox_inches="tight", pad_inches=0) + plt.close() diff --git a/audio/audiotools/core/dsp.py b/audio/audiotools/core/dsp.py index aa70f2670..20d7708a9 100644 --- a/audio/audiotools/core/dsp.py +++ b/audio/audiotools/core/dsp.py @@ -7,148 +7,201 @@ from . import _julius from . import util +def _unfold(x, kernel_sizes, strides): + # https://github.com/PaddlePaddle/Paddle/pull/70102 + + if 1 == kernel_sizes[0]: + x_zeros = paddle.zeros_like(x) + x = paddle.concat([x, x_zeros], axis=2) + + kernel_sizes = (2, kernel_sizes[1]) + + unfolded = paddle.nn.functional.unfold( + x, + kernel_sizes=kernel_sizes, + strides=strides, ) + if 2 == kernel_sizes[0]: + unfolded = unfolded[:, :kernel_sizes[1]] + return unfolded + + +def _fold(x, output_sizes, kernel_sizes, strides): + # https://github.com/PaddlePaddle/Paddle/pull/70102 + + if 1 == output_sizes[0] and 1 == kernel_sizes[0]: + x_zeros = paddle.zeros_like(x) + x = paddle.concat([x, x_zeros], axis=1) + + output_sizes = (2, output_sizes[1]) + kernel_sizes = (2, kernel_sizes[1]) + + fold = paddle.nn.functional.fold( + x, + output_sizes=output_sizes, + kernel_sizes=kernel_sizes, + strides=strides, ) + if 2 == kernel_sizes[0]: + fold = fold[:, :, :1] + return fold + + class DSPMixin: _original_batch_size = None _original_num_channels = None _padded_signal_length = None - # def _preprocess_signal_for_windowing(self, window_duration, hop_duration): - # self._original_batch_size = self.batch_size - # self._original_num_channels = self.num_channels - - # window_length = int(window_duration * self.sample_rate) - # hop_length = int(hop_duration * self.sample_rate) - - # if window_length % hop_length != 0: - # factor = window_length // hop_length - # window_length = factor * hop_length - - # self.zero_pad(hop_length, hop_length) - # self._padded_signal_length = self.signal_length - - # return window_length, hop_length - - # def windows( - # self, window_duration: float, hop_duration: float, preprocess: bool = True - # ): - # """Generator which yields windows of specified duration from signal with a specified - # hop length. - - # Parameters - # ---------- - # window_duration : float - # Duration of every window in seconds. - # hop_duration : float - # Hop between windows in seconds. - # preprocess : bool, optional - # Whether to preprocess the signal, so that the first sample is in - # the middle of the first window, by default True - - # Yields - # ------ - # AudioSignal - # Each window is returned as an AudioSignal. - # """ - # if preprocess: - # window_length, hop_length = self._preprocess_signal_for_windowing( - # window_duration, hop_duration - # ) - - # self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length) - - # for b in range(self.batch_size): - # i = 0 - # start_idx = i * hop_length - # while True: - # start_idx = i * hop_length - # i += 1 - # end_idx = start_idx + window_length - # if end_idx > self.signal_length: - # break - # yield self[b, ..., start_idx:end_idx] - - # def collect_windows( - # self, window_duration: float, hop_duration: float, preprocess: bool = True - # ): - # """Reshapes signal into windows of specified duration from signal with a specified - # hop length. Window are placed along the batch dimension. Use with - # :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the - # original signal. - - # Parameters - # ---------- - # window_duration : float - # Duration of every window in seconds. - # hop_duration : float - # Hop between windows in seconds. - # preprocess : bool, optional - # Whether to preprocess the signal, so that the first sample is in - # the middle of the first window, by default True - - # Returns - # ------- - # AudioSignal - # AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)`` - # """ - # if preprocess: - # window_length, hop_length = self._preprocess_signal_for_windowing( - # window_duration, hop_duration - # ) - - # # self.audio_data: (nb, nch, nt). - # unfolded = paddle.nn.functional.unfold( - # self.audio_data.reshape(-1, 1, 1, self.signal_length), - # kernel_size=(1, window_length), - # stride=(1, hop_length), - # ) - # # unfolded: (nb * nch, window_length, num_windows). - # # -> (nb * nch * num_windows, 1, window_length) - # unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length) - # self.audio_data = unfolded - # return self - - # def overlap_and_add(self, hop_duration: float): - # """Function which takes a list of windows and overlap adds them into a - # signal the same length as ``audio_signal``. - - # Parameters - # ---------- - # hop_duration : float - # How much to shift for each window - # (overlap is window_duration - hop_duration) in seconds. - - # Returns - # ------- - # AudioSignal - # overlap-and-added signal. - # """ - # hop_length = int(hop_duration * self.sample_rate) - # window_length = self.signal_length - - # nb, nch = self._original_batch_size, self._original_num_channels - - # unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1) - # folded = paddle.nn.functional.fold( - # unfolded, - # output_size=(1, self._padded_signal_length), - # kernel_size=(1, window_length), - # stride=(1, hop_length), - # ) - - # norm = paddle.ones_like(unfolded, device=unfolded.device) - # norm = paddle.nn.functional.fold( - # norm, - # output_size=(1, self._padded_signal_length), - # kernel_size=(1, window_length), - # stride=(1, hop_length), - # ) - - # folded = folded / norm - - # folded = folded.reshape(nb, nch, -1) - # self.audio_data = folded - # self.trim(hop_length, hop_length) - # return self + def _preprocess_signal_for_windowing(self, window_duration, hop_duration): + self._original_batch_size = self.batch_size + self._original_num_channels = self.num_channels + + window_length = int(window_duration * self.sample_rate) + hop_length = int(hop_duration * self.sample_rate) + + if window_length % hop_length != 0: + factor = window_length // hop_length + window_length = factor * hop_length + + self.zero_pad(hop_length, hop_length) + self._padded_signal_length = self.signal_length + + return window_length, hop_length + + def windows(self, + window_duration: float, + hop_duration: float, + preprocess: bool=True): + """Generator which yields windows of specified duration from signal with a specified + hop length. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Yields + ------ + AudioSignal + Each window is returned as an AudioSignal. + """ + if preprocess: + window_length, hop_length = self._preprocess_signal_for_windowing( + window_duration, hop_duration) + + self.audio_data = self.audio_data.reshape([-1, 1, self.signal_length]) + + for b in range(self.batch_size): + i = 0 + start_idx = i * hop_length + while True: + start_idx = i * hop_length + i += 1 + end_idx = start_idx + window_length + if end_idx > self.signal_length: + break + yield self[b, ..., start_idx:end_idx] + + def collect_windows(self, + window_duration: float, + hop_duration: float, + preprocess: bool=True): + """Reshapes signal into windows of specified duration from signal with a specified + hop length. Window are placed along the batch dimension. Use with + :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the + original signal. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Returns + ------- + AudioSignal + AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)`` + """ + if preprocess: + window_length, hop_length = self._preprocess_signal_for_windowing( + window_duration, hop_duration) + + # self.audio_data: (nb, nch, nt). + # unfolded = paddle.nn.functional.unfold( + # self.audio_data.reshape([-1, 1, 1, self.signal_length]), + # kernel_sizes=(1, window_length), + # strides=(1, hop_length), + # ) + unfolded = _unfold( + self.audio_data.reshape([-1, 1, 1, self.signal_length]), + kernel_sizes=(1, window_length), + strides=(1, hop_length), ) + # unfolded: (nb * nch, window_length, num_windows). + # -> (nb * nch * num_windows, 1, window_length) + unfolded = unfolded.transpose([0, 2, 1]).reshape([-1, 1, window_length]) + self.audio_data = unfolded + return self + + def overlap_and_add(self, hop_duration: float): + """Function which takes a list of windows and overlap adds them into a + signal the same length as ``audio_signal``. + + Parameters + ---------- + hop_duration : float + How much to shift for each window + (overlap is window_duration - hop_duration) in seconds. + + Returns + ------- + AudioSignal + overlap-and-added signal. + """ + hop_length = int(hop_duration * self.sample_rate) + window_length = self.signal_length + + nb, nch = self._original_batch_size, self._original_num_channels + + unfolded = self.audio_data.reshape( + [nb * nch, -1, window_length]).transpose([0, 2, 1]) + # folded = paddle.nn.functional.fold( + # unfolded, + # output_sizes=(1, self._padded_signal_length), + # kernel_sizes=(1, window_length), + # strides=(1, hop_length), + # ) + folded = _fold( + unfolded, + output_sizes=(1, self._padded_signal_length), + kernel_sizes=(1, window_length), + strides=(1, hop_length), ) + + norm = paddle.ones_like(unfolded) + # norm = paddle.nn.functional.fold( + # norm, + # output_sizes=(1, self._padded_signal_length), + # kernel_sizes=(1, window_length), + # strides=(1, hop_length), + # ) + norm = _fold( + norm, + output_sizes=(1, self._padded_signal_length), + kernel_sizes=(1, window_length), + strides=(1, hop_length), ) + + folded = folded / norm + + folded = folded.reshape([nb, nch, -1]) + self.audio_data = folded + self.trim(hop_length, hop_length) + return self def low_pass(self, cutoffs: typing.Union[paddle.Tensor, np.ndarray, float], @@ -312,87 +365,92 @@ class DSPMixin: self.stft_data = mag * paddle.exp(1j * phase) return self - # def mask_low_magnitudes( - # self, db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float], val: float = 0.0 - # ): - # """Mask away magnitudes below a specified threshold, which - # can be different for every item in the batch. - - # Parameters - # ---------- - # db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float] - # Decibel value for which things below it will be masked away. - # val : float, optional - # Value to fill in for masked portions, by default 0.0 - - # Returns - # ------- - # AudioSignal - # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - # masked audio data. - # """ - # mag = self.magnitude - # log_mag = self.log_magnitude() - - # db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) - # mask = log_mag < db_cutoff - # mag = mag.masked_fill(mask, val) - - # self.magnitude = mag - # return self - - # def shift_phase(self, shift: typing.Union[paddle.Tensor, np.ndarray, float]): - # """Shifts the phase by a constant value. - - # Parameters - # ---------- - # shift : typing.Union[paddle.Tensor, np.ndarray, float] - # What to shift the phase by. - - # Returns - # ------- - # AudioSignal - # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - # masked audio data. - # """ - # shift = util.ensure_tensor(shift, ndim=self.phase.ndim) - # self.phase = self.phase + shift - # return self - - # def corrupt_phase(self, scale: typing.Union[paddle.Tensor, np.ndarray, float]): - # """Corrupts the phase randomly by some scaled value. - - # Parameters - # ---------- - # scale : typing.Union[paddle.Tensor, np.ndarray, float] - # Standard deviation of noise to add to the phase. - - # Returns - # ------- - # AudioSignal - # Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - # masked audio data. - # """ - # scale = util.ensure_tensor(scale, ndim=self.phase.ndim) - # self.phase = self.phase + scale * paddle.randn_like(self.phase) - # return self - - # def preemphasis(self, coef: float = 0.85): - # """Applies pre-emphasis to audio signal. - - # Parameters - # ---------- - # coef : float, optional - # How much pre-emphasis to apply, lower values do less. 0 does nothing. - # by default 0.85 - - # Returns - # ------- - # AudioSignal - # Pre-emphasized signal. - # """ - # kernel = paddle.to_tensor([1, -coef, 0]).view(1, 1, -1).to(self.device) - # x = self.audio_data.reshape(-1, 1, self.signal_length) - # x = paddle.nn.functional.conv1d(x, kernel, padding=1) - # self.audio_data = x.reshape(*self.audio_data.shape) - # return self + def mask_low_magnitudes( + self, + db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float], + val: float=0.0): + """Mask away magnitudes below a specified threshold, which + can be different for every item in the batch. + + Parameters + ---------- + db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float] + Decibel value for which things below it will be masked away. + val : float, optional + Value to fill in for masked portions, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + mag = self.magnitude + log_mag = self.log_magnitude() + + db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) + mask = log_mag < db_cutoff + mag = mag.masked_fill(mask, val) + + self.magnitude = mag + return self + + def shift_phase(self, + shift: typing.Union[paddle.Tensor, np.ndarray, float]): + """Shifts the phase by a constant value. + + Parameters + ---------- + shift : typing.Union[paddle.Tensor, np.ndarray, float] + What to shift the phase by. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + shift = util.ensure_tensor(shift, ndim=self.phase.ndim) + self.phase = self.phase + shift + return self + + def corrupt_phase(self, + scale: typing.Union[paddle.Tensor, np.ndarray, float]): + """Corrupts the phase randomly by some scaled value. + + Parameters + ---------- + scale : typing.Union[paddle.Tensor, np.ndarray, float] + Standard deviation of noise to add to the phase. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + scale = util.ensure_tensor(scale, ndim=self.phase.ndim) + self.phase = self.phase + scale * paddle.randn( + shape=self.phase.shape, dtype=self.phase.dtype) + return self + + def preemphasis(self, coef: float=0.85): + """Applies pre-emphasis to audio signal. + + Parameters + ---------- + coef : float, optional + How much pre-emphasis to apply, lower values do less. 0 does nothing. + by default 0.85 + + Returns + ------- + AudioSignal + Pre-emphasized signal. + """ + kernel = paddle.to_tensor([1, -coef, 0]).reshape([1, 1, -1]) + x = self.audio_data.reshape([-1, 1, self.signal_length]) + x = paddle.nn.functional.conv1d( + x.astype(kernel.dtype), kernel, padding=1) + self.audio_data = x.reshape(self.audio_data.shape) + return self diff --git a/audio/audiotools/requirements.txt b/audio/audiotools/requirements.txt index def7f22cc..053954496 100644 --- a/audio/audiotools/requirements.txt +++ b/audio/audiotools/requirements.txt @@ -1,8 +1,8 @@ flatten_dict gradio IPython -librosa -markdown2 +librosa==0.8.1markdown2 +numpy==1.23.5 pyloudnorm pytest pytest-xdist diff --git a/audio/tests/audiotools/core/test_audio_signal✅.py b/audio/tests/audiotools/core/test_audio_signal✅.py index a78f2b785..4aa998bb1 100644 --- a/audio/tests/audiotools/core/test_audio_signal✅.py +++ b/audio/tests/audiotools/core/test_audio_signal✅.py @@ -13,7 +13,7 @@ from audiotools import AudioSignal def test_io(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(pathlib.Path(audio_path)) with tempfile.NamedTemporaryFile(suffix=".wav") as f: @@ -61,7 +61,7 @@ def test_io(): assert signal.audio_data.ndim == 3 assert paddle.all(signal.samples == signal.audio_data) - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" assert AudioSignal(audio_path).hash() == AudioSignal(audio_path).hash() assert AudioSignal(audio_path).hash() != AudioSignal(audio_path).normalize( -20).hash() @@ -71,7 +71,7 @@ def test_io(): def test_copy_and_clone(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path) signal.stft() signal.loudness() @@ -369,7 +369,7 @@ def test_trim(): def test_to_from_ops(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path) signal.stft() signal.loudness() @@ -384,7 +384,7 @@ def test_to_from_ops(): def test_device(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path) signal.to("cpu") @@ -397,7 +397,7 @@ def test_device(): def test_stft(window_length, hop_length, window_type): if hop_length >= window_length: hop_length = window_length // 2 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" stft_params = audiotools.STFTParams( window_length=window_length, hop_length=hop_length, @@ -456,7 +456,7 @@ def test_stft(window_length, hop_length, window_type): def test_log_magnitude(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" for _ in range(10): signal = AudioSignal.excerpt(audio_path, duration=5.0) magnitude = signal.magnitude.numpy()[0, 0] @@ -474,7 +474,7 @@ def test_log_magnitude(): def test_mel_spectrogram(n_mels, window_length, hop_length, window_type): if hop_length >= window_length: hop_length = window_length // 2 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" stft_params = audiotools.STFTParams( window_length=window_length, hop_length=hop_length, @@ -492,7 +492,7 @@ def test_mel_spectrogram(n_mels, window_length, hop_length, window_type): def test_mfcc(n_mfcc, n_mels, window_length, hop_length): if hop_length >= window_length: hop_length = window_length // 2 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" stft_params = audiotools.STFTParams( window_length=window_length, hop_length=hop_length) for _stft_params in [None, stft_params]: diff --git a/audio/tests/audiotools/core/test_display✅.py b/audio/tests/audiotools/core/test_display✅.py new file mode 100644 index 000000000..a0feb34b1 --- /dev/null +++ b/audio/tests/audiotools/core/test_display✅.py @@ -0,0 +1,48 @@ +import sys +from pathlib import Path + +import numpy as np +sys.path.append("/home/aistudio/PaddleSpeech/audio") + +from audiotools import AudioSignal +from visualdl import LogWriter + + +def test_specshow(): + array = np.zeros((1, 16000)) + AudioSignal(array, sample_rate=16000).specshow() + AudioSignal(array, sample_rate=16000).specshow(preemphasis=True) + AudioSignal( + array, sample_rate=16000).specshow( + title="test", preemphasis=True) + AudioSignal( + array, sample_rate=16000).specshow( + format=False, preemphasis=True) + AudioSignal( + array, sample_rate=16000).specshow( + format=False, preemphasis=False, y_axis="mel") + + +def test_waveplot(): + array = np.zeros((1, 16000)) + AudioSignal(array, sample_rate=16000).waveplot() + + +def test_wavespec(): + array = np.zeros((1, 16000)) + AudioSignal(array, sample_rate=16000).wavespec() + + +def test_write_audio_to_tb(): + signal = AudioSignal("./audio/spk/f10_script4_produced.mp3", duration=5) + + Path("./scratch").mkdir(parents=True, exist_ok=True) + writer = LogWriter("./scratch/") + signal.write_audio_to_tb("tag", writer) + + +def test_save_image(): + signal = AudioSignal( + "./audio/spk/f10_script4_produced.wav", duration=10, offset=10) + Path("./scratch").mkdir(parents=True, exist_ok=True) + signal.save_image("./scratch/image.png") diff --git a/audio/tests/audiotools/core/test_dsp✅.py b/audio/tests/audiotools/core/test_dsp✅.py new file mode 100644 index 000000000..b8219c342 --- /dev/null +++ b/audio/tests/audiotools/core/test_dsp✅.py @@ -0,0 +1,178 @@ +import sys + +import numpy as np +import paddle +import pytest +sys.path.append("/home/aistudio/PaddleSpeech/audio") +from audiotools import AudioSignal +from audiotools.core.util import sample_from_dist + + +@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0]) +@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100]) +@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0]) +def test_overlap_add(duration, sample_rate, window_duration): + np.random.seed(0) + if duration > window_duration: + spk_signal = AudioSignal.batch([ + AudioSignal.excerpt( + "./audio/spk/f10_script4_produced.wav", duration=duration) + for _ in range(16) + ]) + spk_signal.resample(sample_rate) + + noise = paddle.randn([16, 1, int(duration * sample_rate)]) + nz_signal = AudioSignal(noise, sample_rate=sample_rate) + + def _test(signal): + hop_duration = window_duration / 2 + windowed_signal = signal.deepcopy().collect_windows(window_duration, + hop_duration) + recombined = windowed_signal.overlap_and_add(hop_duration) + + assert recombined == signal + assert np.allclose(recombined.audio_data, signal.audio_data, 1e-3) + + _test(nz_signal) + _test(spk_signal) + + +@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0]) +@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100]) +@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0]) +def test_inplace_overlap_add(duration, sample_rate, window_duration): + np.random.seed(0) + if duration > window_duration: + spk_signal = AudioSignal.batch([ + AudioSignal.excerpt( + "./audio/spk/f10_script4_produced.wav", duration=duration) + for _ in range(16) + ]) + spk_signal.resample(sample_rate) + + noise = paddle.randn([16, 1, int(duration * sample_rate)]) + nz_signal = AudioSignal(noise, sample_rate=sample_rate) + + def _test(signal): + hop_duration = window_duration / 2 + windowed_signal = signal.deepcopy().collect_windows(window_duration, + hop_duration) + # Compare in-place with unfold results + for i, window in enumerate( + signal.deepcopy().windows(window_duration, hop_duration)): + assert np.allclose(window.audio_data, + windowed_signal.audio_data[i]) + + _test(nz_signal) + _test(spk_signal) + + +def test_low_pass(): + sample_rate = 44100 + f = 440 + t = paddle.arange(0, 1, 1 / sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + window = AudioSignal.get_window("hann", sine_wave.shape[-1]) + sine_wave = sine_wave * window + signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate) + out = signal.deepcopy().low_pass(220) + assert out.audio_data.abs().max() < 1e-4 + + out = signal.deepcopy().low_pass(880) + assert (out - signal).audio_data.abs().max() < 1e-3 + + batch = AudioSignal.batch( + [signal.deepcopy(), signal.deepcopy(), signal.deepcopy()]) + + cutoffs = [220, 880, 220] + out = batch.deepcopy().low_pass(cutoffs) + + assert out.audio_data[0].abs().max() < 1e-4 + assert out.audio_data[2].abs().max() < 1e-4 + assert (out - batch).audio_data[1].abs().max() < 1e-3 + + +def test_high_pass(): + sample_rate = 44100 + f = 440 + t = paddle.arange(0, 1, 1 / sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + window = AudioSignal.get_window("hann", sine_wave.shape[-1]) + sine_wave = sine_wave * window + signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate) + out = signal.deepcopy().high_pass(220) + assert (signal - out).audio_data.abs().max() < 1e-4 + + +def test_mask_frequencies(): + sample_rate = 44100 + fs = paddle.to_tensor([500.0, 2000.0, 8000.0, 32000.0])[None] + t = paddle.arange(0, 1, 1 / sample_rate)[:, None] + sine_wave = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1) + sine_wave = AudioSignal(sine_wave, sample_rate) + masked_sine_wave = sine_wave.mask_frequencies(fmin_hz=1500, fmax_hz=10000) + + fs2 = paddle.to_tensor([500.0, 32000.0])[None] + sine_wave2 = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1) + sine_wave2 = AudioSignal(sine_wave2, sample_rate) + + assert paddle.allclose(masked_sine_wave.audio_data, sine_wave2.audio_data) + + +def test_mask_timesteps(): + sample_rate = 44100 + f = 440 + t = paddle.linspace(0, 1, sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + sine_wave = AudioSignal(sine_wave, sample_rate) + + masked_sine_wave = sine_wave.mask_timesteps(tmin_s=0.25, tmax_s=0.75) + masked_sine_wave.istft() + + mask = ((0.3 < t) & (t < 0.7))[None, None] + assert paddle.allclose( + masked_sine_wave.audio_data[mask], + paddle.zeros_like(masked_sine_wave.audio_data[mask]), ) + + +def test_shift_phase(): + sample_rate = 44100 + f = 440 + t = paddle.linspace(0, 1, sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + sine_wave = AudioSignal(sine_wave, sample_rate) + sine_wave2 = sine_wave.clone() + + shifted_sine_wave = sine_wave.shift_phase(np.pi) + shifted_sine_wave.istft() + + sine_wave2.phase = sine_wave2.phase + np.pi + sine_wave2.istft() + + assert paddle.allclose(shifted_sine_wave.audio_data, sine_wave2.audio_data) + + +def test_corrupt_phase(): + sample_rate = 44100 + f = 440 + t = paddle.linspace(0, 1, sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + sine_wave = AudioSignal(sine_wave, sample_rate) + sine_wave2 = sine_wave.clone() + + shifted_sine_wave = sine_wave.corrupt_phase(scale=np.pi) + shifted_sine_wave.istft() + + assert (sine_wave2.phase - shifted_sine_wave.phase).abs().mean() > 0.0 + assert ((sine_wave2.phase - shifted_sine_wave.phase).std() / np.pi) < 1.0 + + +def test_preemphasis(): + x = AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5) + import matplotlib.pyplot as plt + + x.specshow(preemphasis=False) + + x.specshow(preemphasis=True) + + x.preemphasis() diff --git a/audio/tests/audiotools/core/test_effects✅.py b/audio/tests/audiotools/core/test_effects✅.py index e798f06a6..e15622a23 100644 --- a/audio/tests/audiotools/core/test_effects✅.py +++ b/audio/tests/audiotools/core/test_effects✅.py @@ -8,7 +8,7 @@ from audiotools import AudioSignal def test_normalize(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=10) signal = signal.normalize() assert np.allclose(signal.loudness(), -24, atol=1e-1) @@ -35,7 +35,7 @@ def test_normalize(): def test_volume_change(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=10) boost = 3 @@ -50,10 +50,10 @@ def test_volume_change(): def test_mix(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=10) - audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav" nz = AudioSignal(audio_path, offset=10, duration=10) spk.deepcopy().mix(nz, snr=-10) @@ -61,10 +61,10 @@ def test_mix(): assert np.allclose(snr, -10, atol=1) # Test in batch - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=10) - audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav" nz = AudioSignal(audio_path, offset=10, duration=10) batch_size = 4 @@ -86,7 +86,7 @@ def test_mix(): def test_convolve(): np.random.seed(6) # Found a failing seed - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=10) impulse = np.zeros((1, 16000), dtype="float32") @@ -106,7 +106,7 @@ def test_convolve(): assert convolved == spk_batch # Short duration - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=0.1) impulse = np.zeros((1, 16000), dtype="float32") @@ -128,14 +128,14 @@ def test_convolve(): def test_pipeline(): # An actual IR, no batching - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=5) - audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + audio_path = "./audio/ir/h179_Bar_1txts.wav" ir = AudioSignal(audio_path) spk.deepcopy().convolve(ir) - audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav" nz = AudioSignal(audio_path, offset=10, duration=5) batch_size = 16 @@ -146,7 +146,7 @@ def test_pipeline(): # def test_codec(): -# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" +# audio_path = "./audio/spk/f10_script4_produced.wav" # spk = AudioSignal(audio_path, offset=10, duration=10) # with pytest.raises(ValueError): @@ -156,7 +156,7 @@ def test_pipeline(): # out = spk.deepcopy().apply_codec("8-bit") # def test_pitch_shift(): -# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" +# audio_path = "./audio/spk/f10_script4_produced.wav" # spk = AudioSignal(audio_path, offset=10, duration=1) # single = spk.deepcopy().pitch_shift(5) @@ -169,7 +169,7 @@ def test_pipeline(): # assert np.allclose(batched[0].audio_data, single[0].audio_data) # def test_time_stretch(): -# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" +# audio_path = "./audio/spk/f10_script4_produced.wav" # spk = AudioSignal(audio_path, offset=10, duration=1) # single = spk.deepcopy().time_stretch(0.8) @@ -184,7 +184,7 @@ def test_pipeline(): @pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16]) def test_mel_filterbank(n_bands): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=1) fbank = spk.deepcopy().mel_filterbank(n_bands) @@ -192,8 +192,7 @@ def test_mel_filterbank(n_bands): # Check if it works in batches. spk_batch = AudioSignal.batch([ - AudioSignal.excerpt( - "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) for _ in range(16) ]) fbank = spk_batch.deepcopy().mel_filterbank(n_bands) @@ -203,7 +202,7 @@ def test_mel_filterbank(n_bands): @pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16]) def test_equalizer(n_bands): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=10) db = -3 + 1 * paddle.rand([n_bands]) @@ -212,15 +211,14 @@ def test_equalizer(n_bands): db = -3 + 1 * np.random.rand(n_bands) spk.deepcopy().equalizer(db) - audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + audio_path = "./audio/ir/h179_Bar_1txts.wav" ir = AudioSignal(audio_path) db = -3 + 1 * paddle.rand([n_bands]) spk.deepcopy().convolve(ir.equalizer(db)) spk_batch = AudioSignal.batch([ - AudioSignal.excerpt( - "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) for _ in range(16) ]) @@ -231,13 +229,12 @@ def test_equalizer(n_bands): def test_clip_distortion(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=2) clipped = spk.deepcopy().clip_distortion(0.05) spk_batch = AudioSignal.batch([ - AudioSignal.excerpt( - "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) for _ in range(16) ]) percs = paddle.to_tensor(np.random.uniform(size=(16, ))).astype("float32") @@ -249,7 +246,7 @@ def test_clip_distortion(): @pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128]) def test_quantization(quant_ch): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=2) quantized = spk.deepcopy().quantization(quant_ch) @@ -260,8 +257,7 @@ def test_quantization(quant_ch): assert found_quant_ch <= quant_ch spk_batch = AudioSignal.batch([ - AudioSignal.excerpt( - "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) for _ in range(16) ]) @@ -277,7 +273,7 @@ def test_quantization(quant_ch): @pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128]) def test_mulaw_quantization(quant_ch): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" spk = AudioSignal(audio_path, offset=10, duration=2) quantized = spk.deepcopy().mulaw_quantization(quant_ch) @@ -288,8 +284,7 @@ def test_mulaw_quantization(quant_ch): assert found_quant_ch <= quant_ch spk_batch = AudioSignal.batch([ - AudioSignal.excerpt( - "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) for _ in range(16) ]) @@ -304,7 +299,7 @@ def test_mulaw_quantization(quant_ch): def test_impulse_response_augmentation(): - audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + audio_path = "./audio/ir/h179_Bar_1txts.wav" batch_size = 16 ir = AudioSignal(audio_path) ir_batch = AudioSignal.batch([ir for _ in range(batch_size)]) @@ -330,8 +325,8 @@ def test_impulse_response_augmentation(): def test_apply_ir(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" - ir_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" + ir_path = "./audio/ir/h179_Bar_1txts.wav" spk = AudioSignal(audio_path, offset=10, duration=2) ir = AudioSignal(ir_path) diff --git a/audio/tests/audiotools/core/test_grad✅.py b/audio/tests/audiotools/core/test_grad✅.py index d5ef3f307..57d4df646 100644 --- a/audio/tests/audiotools/core/test_grad✅.py +++ b/audio/tests/audiotools/core/test_grad✅.py @@ -9,8 +9,8 @@ from audiotools import AudioSignal def test_audio_grad(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" - ir_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" + ir_path = "./audio/ir/h179_Bar_1txts.wav" def _test_audio_grad(attr: str, target=True, kwargs: dict={}): signal = AudioSignal(audio_path) @@ -153,7 +153,7 @@ def test_audio_grad(): def test_batch_grad(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path) signal.audio_data.stop_gradient = False diff --git a/audio/tests/audiotools/core/test_highpass✅.py b/audio/tests/audiotools/core/test_highpass✅.py index 1aa1cf171..acea78330 100644 --- a/audio/tests/audiotools/core/test_highpass✅.py +++ b/audio/tests/audiotools/core/test_highpass✅.py @@ -97,75 +97,5 @@ class TestHighPassFilters(_BaseTest): self.assertSimilar(y, y2, x) -# class TestBandPassFilters(_BaseTest): - -# def setUp(self): -# paddle.seed(1234) -# random.seed(1234) - -# def test_keep_or_kill(self): -# for _ in range(10): -# freq = random.uniform(0.01, 0.4) -# sr = 1024 -# tone = pure_tone(freq * sr, sr=sr, dur=10) - -# # For this test we accept 5% tolerance in amplitude, or -26dB in power. -# tol = 5 -# zeros = 16 - -# y_pass = filters.bandpass_filter(tone, 0.9 * freq, 1.1 * freq, zeros=zeros) -# self.assertSimilar(y_pass, tone, tone, f"freq={freq}, pass", tol=tol) - -# y_killed = filters.bandpass_filter(tone, 1.1 * freq, 1.2 * freq, zeros=zeros) -# self.assertSimilar(y_killed, 0 * tone, tone, f"freq={freq}, kill", tol=tol) - -# y_killed = filters.bandpass_filter(tone, 0.8 * freq, 0.9 * freq, zeros=zeros) -# self.assertSimilar(y_killed, 0 * tone, tone, f"freq={freq}, kill", tol=tol) - -# def test_fft_nofft(self): -# for _ in range(10): -# x = paddle.randn([1024]) -# freq = random.uniform(0.01, 0.5) -# freq2 = random.uniform(freq, 0.5) -# y_fft = filters.bandpass_filter(x, freq, freq2, fft=True) -# y_ref = filters.bandpass_filter(x, freq, freq2, fft=False) -# self.assertSimilar(y_fft, y_ref, x, f"freq={freq}", tol=0.01) - -# def test_constant(self): -# x = paddle.ones([2048]) -# for zeros in [4, 10]: -# for freq in [0.01, 0.1]: -# y = filters.bandpass_filter(x, freq, 1.2 * freq, zeros=zeros) -# self.assertLessEqual(y.abs().mean(), 1e-6, (zeros, freq)) - -# def test_stride(self): -# x = paddle.randn([1024]) - -# y = filters.bandpass_filter(x, 0.1, 0.2, stride=1)[::3] -# y2 = filters.bandpass_filter(x, 0.1, 0.2, stride=3) - -# self.assertEqual(y.shape, y2.shape) -# self.assertSimilar(y, y2, x) - -# y = filters.bandpass_filter(x, 0.1, 0.2, stride=1, pad=False)[::3] -# y2 = filters.bandpass_filter(x, 0.1, 0.2, stride=3, pad=False) - -# self.assertEqual(y.shape, y2.shape) -# self.assertSimilar(y, y2, x) - -# def test_same_as_highpass(self): -# x = paddle.randn([1024]) - -# y_ref = highpass_filter(x, 0.2) -# y = filters.bandpass_filter(x, 0.2, 0.5) -# self.assertSimilar(y, y_ref, x) - -# def test_same_as_lowpass(self): -# x = paddle.randn([1024]) - -# y_ref = filters.lowpass_filter(x, 0.2) -# y = filters.bandpass_filter(x, 0.0, 0.2) -# self.assertSimilar(y, y_ref, x) - if __name__ == "__main__": unittest.main() diff --git a/audio/tests/audiotools/core/test_loudness✅.py b/audio/tests/audiotools/core/test_loudness✅.py index f9cbc77ac..308406652 100644 --- a/audio/tests/audiotools/core/test_loudness✅.py +++ b/audio/tests/audiotools/core/test_loudness✅.py @@ -13,7 +13,7 @@ ATOL = 1e-1 def test_loudness_against_pyln(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=5, duration=10) signal_loudness = signal.loudness() @@ -24,7 +24,7 @@ def test_loudness_against_pyln(): def test_loudness_short(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=0.25) signal_loudness = signal.loudness() @@ -58,7 +58,7 @@ def test_batch_loudness(): # Tests below are copied from pyloudnorm def test_integrated_loudness(): - data, rate = sf.read("tests/audiotools/audio/loudness/sine_1000.wav") + data, rate = sf.read("./audio/loudness/sine_1000.wav") meter = Meter(rate) loudness = meter(data) @@ -67,8 +67,7 @@ def test_integrated_loudness(): def test_rel_gate_test(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_RelGateTest.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_RelGateTest.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -77,8 +76,7 @@ def test_rel_gate_test(): def test_abs_gate_test(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_AbsGateTest.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_AbsGateTest.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -87,8 +85,7 @@ def test_abs_gate_test(): def test_24LKFS_25Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -97,8 +94,7 @@ def test_24LKFS_25Hz_2ch(): def test_24LKFS_100Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -107,8 +103,7 @@ def test_24LKFS_100Hz_2ch(): def test_24LKFS_500Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -117,8 +112,7 @@ def test_24LKFS_500Hz_2ch(): def test_24LKFS_1000Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -127,8 +121,7 @@ def test_24LKFS_1000Hz_2ch(): def test_24LKFS_2000Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -137,8 +130,7 @@ def test_24LKFS_2000Hz_2ch(): def test_24LKFS_10000Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -147,8 +139,7 @@ def test_24LKFS_10000Hz_2ch(): def test_23LKFS_25Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -157,8 +148,7 @@ def test_23LKFS_25Hz_2ch(): def test_23LKFS_100Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -167,8 +157,7 @@ def test_23LKFS_100Hz_2ch(): def test_23LKFS_500Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -177,8 +166,7 @@ def test_23LKFS_500Hz_2ch(): def test_23LKFS_1000Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -187,8 +175,7 @@ def test_23LKFS_1000Hz_2ch(): def test_23LKFS_2000Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -197,8 +184,7 @@ def test_23LKFS_2000Hz_2ch(): def test_23LKFS_10000Hz_2ch(): - data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav") + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -208,7 +194,7 @@ def test_23LKFS_10000Hz_2ch(): def test_18LKFS_frequency_sweep(): data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav") + "./audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -218,7 +204,7 @@ def test_18LKFS_frequency_sweep(): def test_conf_stereo_vinL_R_23LKFS(): data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav") + "./audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -228,8 +214,7 @@ def test_conf_stereo_vinL_R_23LKFS(): def test_conf_monovoice_music_24LKFS(): data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav" - ) + "./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -239,8 +224,7 @@ def test_conf_monovoice_music_24LKFS(): def conf_monovoice_music_24LKFS(): data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav" - ) + "./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -250,8 +234,7 @@ def conf_monovoice_music_24LKFS(): def test_conf_monovoice_music_23LKFS(): data, rate = sf.read( - "tests/audiotools/audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav" - ) + "./audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav") meter = Meter(rate) loudness = meter.integrated_loudness(data) @@ -266,7 +249,7 @@ def test_fir_accuracy(): transforms.HighPass(prob=0.5), transforms.Equalizer(prob=0.5), prob=0.5, ) - loader = datasets.AudioLoader(sources=["tests/audiotools/audio/spk.csv"]) + loader = datasets.AudioLoader(sources=["./audio/spk.csv"]) dataset = datasets.AudioDataset( loader, 44100, diff --git a/audio/tests/audiotools/core/test_util✅.py b/audio/tests/audiotools/core/test_util✅.py index 42feeb100..e442a16c0 100644 --- a/audio/tests/audiotools/core/test_util✅.py +++ b/audio/tests/audiotools/core/test_util✅.py @@ -66,8 +66,7 @@ def test_find_audio(): assert not audio_files # Make sure it works with single audio files - audio_files = util.find_audio( - "tests/audiotools/audio/spk//f10_script4_produced.wav") + audio_files = util.find_audio("./audio/spk//f10_script4_produced.wav") # Make sure it works with globs audio_files = util.find_audio("tests/**/*.wav") diff --git a/audio/tests/audiotools/data/test_datasets✅.py b/audio/tests/audiotools/data/test_datasets✅.py index 6412b04f6..61ca94170 100644 --- a/audio/tests/audiotools/data/test_datasets✅.py +++ b/audio/tests/audiotools/data/test_datasets✅.py @@ -45,7 +45,7 @@ def test_audio_dataset(): tfm.Silence(prob=0.5), ], ) loader = audiotools.data.datasets.AudioLoader( - sources=["tests/audiotools/audio/spk.csv"], + sources=["./audio/spk.csv"], transform=transform, ) dataset = audiotools.data.datasets.AudioDataset( loader, @@ -161,11 +161,10 @@ def test_loader_out_of_range(): def test_dataset_pipeline(): transform = tfm.Compose([ - tfm.RoomImpulseResponse(sources=["tests/audiotools/audio/irs.csv"]), - tfm.BackgroundNoise(sources=["tests/audiotools/audio/noises.csv"]), + tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]), + tfm.BackgroundNoise(sources=["./audio/noises.csv"]), ]) - loader = audiotools.data.datasets.AudioLoader( - sources=["tests/audiotools/audio/spk.csv"]) + loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"]) dataset = audiotools.data.datasets.AudioDataset( loader, 44100, diff --git a/audio/tests/audiotools/data/test_preprocess✅.py b/audio/tests/audiotools/data/test_preprocess✅.py index db038a593..d283f3712 100644 --- a/audio/tests/audiotools/data/test_preprocess✅.py +++ b/audio/tests/audiotools/data/test_preprocess✅.py @@ -12,13 +12,11 @@ from audiotools.data import preprocess def test_create_csv(): with tempfile.NamedTemporaryFile(suffix=".csv") as f: preprocess.create_csv( - find_audio("./tests/audiotools/audio/spk", ext=["wav"]), - f.name, - loudness=True) + find_audio("././audio/spk", ext=["wav"]), f.name, loudness=True) def test_create_csv_with_empty_rows(): - audio_files = find_audio("./tests/audiotools/audio/spk", ext=["wav"]) + audio_files = find_audio("././audio/spk", ext=["wav"]) audio_files.insert(0, "") audio_files.insert(2, "") diff --git a/audio/tests/audiotools/data/test_transforms✅.py b/audio/tests/audiotools/data/test_transforms✅.py index 0ec07f8f6..add6a80c8 100644 --- a/audio/tests/audiotools/data/test_transforms✅.py +++ b/audio/tests/audiotools/data/test_transforms✅.py @@ -49,13 +49,13 @@ def test_transform(transform_name): kwargs = {} if transform_name == "BackgroundNoise": - kwargs["sources"] = ["tests/audiotools/audio/noises.csv"] + kwargs["sources"] = ["./audio/noises.csv"] if transform_name == "RoomImpulseResponse": - kwargs["sources"] = ["tests/audiotools/audio/irs.csv"] + kwargs["sources"] = ["./audio/irs.csv"] if transform_name == "CrossTalk": - kwargs["sources"] = ["tests/audiotools/audio/spk.csv"] + kwargs["sources"] = ["./audio/spk.csv"] - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) signal.metadata["loudness"] = AudioSignal( audio_path).ffmpeg_loudness().item() @@ -102,12 +102,12 @@ def test_transform(transform_name): def test_compose_basic(): seed = 0 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) transform = tfm.Compose( [ - tfm.RoomImpulseResponse(sources=["tests/audiotools/audio/irs.csv"]), - tfm.BackgroundNoise(sources=["tests/audiotools/audio/noises.csv"]), + tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]), + tfm.BackgroundNoise(sources=["./audio/noises.csv"]), ], ) kwargs = transform.instantiate(seed, signal) @@ -143,7 +143,7 @@ def test_compose_with_duplicate_transforms(): full_mul = np.prod(muls) kwargs = transform.instantiate(0) - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) output = transform(signal.clone(), **kwargs) @@ -162,7 +162,7 @@ def test_nested_compose(): full_mul = np.prod(muls) kwargs = transform.instantiate(0) - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) output = transform(signal.clone(), **kwargs) @@ -176,7 +176,7 @@ def test_compose_filtering(): transform = tfm.Compose([MulTransform(x, name=str(x)) for x in muls]) kwargs = transform.instantiate(0) - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) for s in range(len(muls)): @@ -199,7 +199,7 @@ def test_sequential_compose(): full_mul = np.prod(muls) kwargs = transform.instantiate(0) - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) output = transform(signal.clone(), **kwargs) @@ -210,11 +210,11 @@ def test_sequential_compose(): def test_choose_basic(): seed = 0 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) transform = tfm.Choose([ - tfm.RoomImpulseResponse(sources=["tests/audiotools/audio/irs.csv"]), - tfm.BackgroundNoise(sources=["tests/audiotools/audio/noises.csv"]), + tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]), + tfm.BackgroundNoise(sources=["./audio/noises.csv"]), ]) kwargs = transform.instantiate(seed, signal) @@ -251,7 +251,7 @@ def test_choose_basic(): def test_choose_weighted(): seed = 0 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" transform = tfm.Choose( [ MulTransform(0.0), @@ -277,7 +277,7 @@ def test_choose_weighted(): def test_choose_with_compose(): - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) transform = tfm.Choose([ @@ -296,7 +296,7 @@ def test_choose_with_compose(): def test_repeat(): seed = 0 - audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + audio_path = "./audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) kwargs = {} @@ -356,7 +356,7 @@ class DummyData(paddle.io.Dataset): def test_masking(): - dataset = DummyData("tests/audiotools/audio/spk/f10_script4_produced.wav") + dataset = DummyData("./audio/spk/f10_script4_produced.wav") dataloader = paddle.io.DataLoader( dataset, batch_size=16, @@ -385,8 +385,7 @@ def test_nested_masking(): ], prob=0.9, ) - loader = audiotools.data.datasets.AudioLoader( - sources=["tests/audiotools/audio/spk.csv"]) + loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"]) dataset = audiotools.data.datasets.AudioDataset( loader, 44100, diff --git a/audio/tests/audiotools/test_audiotools.sh b/audio/tests/audiotools/test_audiotools.sh new file mode 100644 index 000000000..32653174e --- /dev/null +++ b/audio/tests/audiotools/test_audiotools.sh @@ -0,0 +1,4 @@ +python -m pip install -r ../audiotools/requirements.txt +# wget -P ./test_data https://paddlespeech.bj.bcebos.com/datasets/unit_test/asr/static_ds2online_inputs.pickle +# wget +find . -name "*✅.py" | xargs python -m pytest \ No newline at end of file diff --git a/audio/tests/audiotools/test_post✅.py b/audio/tests/audiotools/test_post✅.py index 6bf1cb4bd..f02f18e83 100644 --- a/audio/tests/audiotools/test_post✅.py +++ b/audio/tests/audiotools/test_post✅.py @@ -13,8 +13,7 @@ def test_audio_table(): audio_dict = {} audio_dict["inputs"] = [ - AudioSignal.excerpt( - "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=5) + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5) for _ in range(3) ] audio_dict["outputs"] = [] diff --git a/tests/unit/ci.sh b/tests/unit/ci.sh index 72b4678d6..a2189e626 100644 --- a/tests/unit/ci.sh +++ b/tests/unit/ci.sh @@ -31,6 +31,13 @@ function main(){ cd ${speech_ci_path}/server/offline bash test_server_client.sh echo "End server" + + echo "Start testing audiotools" + cd ${speech_ci_path}/../../audio/tests/audiotools + bash test_audiotools.sh + echo "End testing audiotools" + + } main diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index 3903e6597..be287ed6c 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -115,3 +115,4 @@ paddlespeech whisper --task translate --input ./zh.wav paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav echo -e "\033[32mTest success !!!\033[0m" +