From 9e7dca2bc54aac604ee6aa9bce3ae6b71e53824c Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Wed, 20 Nov 2024 08:27:05 +0000 Subject: [PATCH] add AudioSignal && util --- audio/audiotools/audio_signal.py | 1780 +++++++++++++++++++++++++++++ audio/audiotools/requirements.txt | 5 + audio/audiotools/resample.py | 240 ++++ audio/audiotools/util.py | 669 +++++++++++ 4 files changed, 2694 insertions(+) create mode 100644 audio/audiotools/audio_signal.py create mode 100644 audio/audiotools/requirements.txt create mode 100644 audio/audiotools/resample.py create mode 100644 audio/audiotools/util.py diff --git a/audio/audiotools/audio_signal.py b/audio/audiotools/audio_signal.py new file mode 100644 index 000000000..8ab98eabe --- /dev/null +++ b/audio/audiotools/audio_signal.py @@ -0,0 +1,1780 @@ +import copy +import functools +import hashlib +import math +import pathlib +import tempfile +import typing +import warnings +from collections import namedtuple +from pathlib import Path + +import numpy as np +import soundfile +import paddle +import librosa +from typing import Optional + +import util +from resample import resample_frac + +# from .display import DisplayMixin +# from .dsp import DSPMixin +# from .effects import EffectMixin +# from .effects import ImpulseResponseMixin +# from .ffmpeg import FFMPEGMixinx +# from loudness import LoudnessMixin +# from .playback import PlayMixin +# from .whisper import WhisperMixin + + +def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> paddle.Tensor: + r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), + normalized depending on norm. + + Args: + n_mfcc (int): Number of mfc coefficients to retain + n_mels (int): Number of mel filterbanks + norm (str or None): Norm to use (either "ortho" or None) + + Returns: + paddle.Tensor: The transformation matrix, to be right-multiplied to + row-wise data of size (``n_mels``, ``n_mfcc``). + """ + + if norm is not None and norm != "ortho": + raise ValueError('norm must be either "ortho" or None') + + # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II + n = paddle.arange(float(n_mels)) + k = paddle.arange(float(n_mfcc)).unsqueeze([1]) + dct = paddle.cos( + math.pi / float(n_mels) * (n + 0.5) * k + ) # size (n_mfcc, n_mels) + + if norm is None: + dct *= 2.0 + else: + dct[0] *= 1.0 / math.sqrt(2.0) + dct *= math.sqrt(2.0 / float(n_mels)) + return dct.transpose([1, 0]) + + +STFTParams = namedtuple( + "STFTParams", + [ + "window_length", + "hop_length", + "window_type", + "match_stride", + "padding_type", + ], +) +""" +STFTParams object is a container that holds STFT parameters - window_length, +hop_length, and window_type. Not all parameters need to be specified. Ones that +are not specified will be inferred by the AudioSignal parameters. + +Parameters +---------- +window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. +hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. +window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. 
+match_stride : bool, optional
+    Whether to match the stride of convolutional layers, by default False
+padding_type : str, optional
+    Type of padding to use, by default 'reflect'
+"""
+STFTParams.__new__.__defaults__ = (None, None, None, None, None)
+
+
+class AudioSignal(
+    # EffectMixin,
+    # LoudnessMixin,
+    # PlayMixin,
+    # ImpulseResponseMixin,
+    # DSPMixin,
+    # DisplayMixin,
+    # FFMPEGMixin,
+    # WhisperMixin,
+):
+    """This is the core object of this library. Audio is always
+    loaded into an AudioSignal, which then enables all the features
+    of this library, including audio augmentations, I/O, playback,
+    and more.
+
+    The structure of this object is that the base functionality
+    is defined in ``core/audio_signal.py``, while extensions to
+    that functionality are defined in the other ``core/*.py``
+    files. For example, all the display-based functionality
+    (e.g. plot spectrograms, waveforms, write to tensorboard)
+    are in ``core/display.py``.
+
+    Parameters
+    ----------
+    audio_path_or_array : typing.Union[paddle.Tensor, str, Path, np.ndarray]
+        Object to create AudioSignal from. Can be a tensor, numpy array,
+        or a path to a file. The audio is always reshaped to
+        ``(batch, channels, samples)``.
+    sample_rate : int, optional
+        Sample rate of the audio. Required when passing in an array or
+        tensor; ignored when loading from a file, where the file's native
+        sample rate is used, by default None
+    stft_params : STFTParams, optional
+        Parameters of STFT to use, by default None
+    offset : float, optional
+        Offset in seconds to read from file, by default 0
+    duration : float, optional
+        Duration in seconds to read from file, by default None
+    device : str, optional
+        Device to load audio onto, by default None
+
+    Examples
+    --------
+    Loading an AudioSignal from an array, at a sample rate of
+    44100.
+
+    >>> signal = AudioSignal(paddle.randn([5*44100]), 44100)
+
+    Note, the signal is reshaped to have a batch size, and one
+    audio channel:
+
+    >>> print(signal.shape)
+    (1, 1, 220500)
+
+    You can treat AudioSignals like tensors, and many of the same
+    functions you might use on tensors are defined for AudioSignals
+    as well:
+
+    >>> signal.to("cuda")
+    >>> signal.cuda()
+    >>> signal.clone()
+    >>> signal.detach()
+
+    Indexing AudioSignals returns an AudioSignal:
+
+    >>> signal[..., 3*44100:4*44100]
+
+    The above signal is 1 second long, and is also an AudioSignal.
+    """
+
+    def __init__(
+        self,
+        audio_path_or_array: typing.Union[paddle.Tensor, str, Path, np.ndarray],
+        sample_rate: int = None,
+        stft_params: STFTParams = None,
+        offset: float = 0,
+        duration: float = None,
+        device: str = None,
+    ):
+        # ✅
+        audio_path = None
+        audio_array = None
+
+        if isinstance(audio_path_or_array, str):
+            audio_path = audio_path_or_array
+        elif isinstance(audio_path_or_array, pathlib.Path):
+            audio_path = audio_path_or_array
+        elif isinstance(audio_path_or_array, np.ndarray):
+            audio_array = audio_path_or_array
+        elif paddle.is_tensor(audio_path_or_array):
+            audio_array = audio_path_or_array
+        else:
+            raise ValueError(
+                "audio_path_or_array must be either a Path, "
+                "string, numpy array, or paddle Tensor!"
+            )
+
+        self.path_to_file = None
+
+        self.audio_data = None
+        self.sources = None  # List of AudioSignal objects.
+        self.stft_data = None
+        if audio_path is not None:
+            self.load_from_file(
+                audio_path, offset=offset, duration=duration, device=device
+            )
+        elif audio_array is not None:
+            assert sample_rate is not None, "Must set sample rate!"
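+            # Arrays and tensors carry no sample-rate metadata, so the
+            # caller must always supply `sample_rate` explicitly here.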
+ self.load_from_array(audio_array, sample_rate, device=device) + + self.window = None + self.stft_params = stft_params + + self.metadata = { + "offset": offset, + "duration": duration, + } + + @property + def path_to_input_file( + self, + ): + """✅ + Path to input file, if it exists. + Alias to ``path_to_file`` for backwards compatibility + """ + return self.path_to_file + + @classmethod + def excerpt( + cls, + audio_path: typing.Union[str, Path], + offset: float = None, + duration: float = None, + state: typing.Union[np.random.RandomState, int] = None, + **kwargs, + ): + """✅Randomly draw an excerpt of ``duration`` seconds from an + audio file specified at ``audio_path``, between ``offset`` seconds + and end of file. ``state`` can be used to seed the random draw. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + offset : float, optional + Lower bound for the start time, in seconds drawn from + the file, by default None. + duration : float, optional + Duration of excerpt, in seconds, by default None + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. + + Examples + -------- + >>> signal = AudioSignal.excerpt("path/to/audio", duration=5) + """ + info = util.info(audio_path) + total_duration = info.duration + + state = util.random_state(state) + lower_bound = 0 if offset is None else offset + upper_bound = max(total_duration - duration, 0) + offset = state.uniform(lower_bound, upper_bound) + + signal = cls(audio_path, offset=offset, duration=duration, **kwargs) + signal.metadata["offset"] = offset + signal.metadata["duration"] = duration + + return signal + + @classmethod + def salient_excerpt( + cls, + audio_path: typing.Union[str, Path], + loudness_cutoff: float = None, + num_tries: int = 8, + state: typing.Union[np.random.RandomState, int] = None, + **kwargs, + ): + """❌Similar to AudioSignal.excerpt, except it extracts excerpts only + if they are above a specified loudness threshold, which is computed via + a fast LUFS routine. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + loudness_cutoff : float, optional + Loudness threshold in dB. Typical values are ``-40, -60``, + etc, by default None + num_tries : int, optional + Number of tries to grab an excerpt above the threshold + before giving up, by default 8. + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + kwargs : dict + Keyword arguments to AudioSignal.excerpt + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. + + + .. warning:: + if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can + result in an infinite loop if ``audio_path`` does not have + any loud enough excerpts. 
+
+        Examples
+        --------
+        >>> signal = AudioSignal.salient_excerpt(
+        ...     "path/to/audio",
+        ...     loudness_cutoff=-40,
+        ...     duration=5,
+        ... )
+        """
+        state = util.random_state(state)
+        if loudness_cutoff is None:
+            excerpt = cls.excerpt(audio_path, state=state, **kwargs)
+        else:
+            loudness = -np.inf
+            num_try = 0
+            while loudness <= loudness_cutoff:
+                excerpt = cls.excerpt(audio_path, state=state, **kwargs)
+                loudness = excerpt.loudness()  # <---- NOT IMPLEMENTED YET
+                num_try += 1
+                if num_tries is not None and num_try >= num_tries:
+                    break
+        return excerpt
+
+    @classmethod
+    def zeros(
+        cls,
+        duration: float,
+        sample_rate: int,
+        num_channels: int = 1,
+        batch_size: int = 1,
+        **kwargs,
+    ):
+        """✅Helper function to create an AudioSignal of all zeros.
+
+        Parameters
+        ----------
+        duration : float
+            Duration of AudioSignal, in seconds
+        sample_rate : int
+            Sample rate of AudioSignal
+        num_channels : int, optional
+            Number of channels, by default 1
+        batch_size : int, optional
+            Batch size, by default 1
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal containing all zeros.
+
+        Examples
+        --------
+        Generate 5 seconds of all zeros at a sample rate of 44100.
+
+        >>> signal = AudioSignal.zeros(5.0, 44100)
+        """
+        n_samples = int(duration * sample_rate)
+        return cls(
+            paddle.zeros([batch_size, num_channels, n_samples]),
+            sample_rate,
+            **kwargs,
+        )
+
+    @classmethod
+    def wave(
+        cls,
+        frequency: float,
+        duration: float,
+        sample_rate: int,
+        num_channels: int = 1,
+        shape: str = "sine",
+        **kwargs,
+    ):
+        """✅
+        Generate a waveform of a given frequency and shape.
+
+        Parameters
+        ----------
+        frequency : float
+            Frequency of the waveform
+        duration : float
+            Duration of the waveform, in seconds
+        sample_rate : int
+            Sample rate of the waveform
+        num_channels : int, optional
+            Number of channels, by default 1
+        shape : str, optional
+            Shape of the waveform, by default "sine".
+            One of "sawtooth", "square", "sine", "triangle"
+        kwargs : dict
+            Keyword arguments to AudioSignal
+        """
+        n_samples = int(duration * sample_rate)
+        t = np.linspace(0, duration, n_samples)
+        if shape == "sawtooth":
+            from scipy.signal import sawtooth
+
+            wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
+        elif shape == "square":
+            from scipy.signal import square
+
+            wave_data = square(2 * np.pi * frequency * t)
+        elif shape == "sine":
+            wave_data = np.sin(2 * np.pi * frequency * t)
+        elif shape == "triangle":
+            from scipy.signal import sawtooth
+
+            # frequency is doubled by the abs call, so omit the 2 in 2pi
+            wave_data = sawtooth(np.pi * frequency * t, 0.5)
+            wave_data = -np.abs(wave_data) * 2 + 1
+        else:
+            raise ValueError(f"Invalid shape {shape}")
+
+        wave_data = paddle.to_tensor(wave_data, dtype=paddle.float32)
+        wave_data = wave_data[None, None].expand([1, num_channels, -1])
+        return cls(wave_data, sample_rate, **kwargs)
+
+    @classmethod
+    def batch(
+        cls,
+        audio_signals: list,
+        pad_signals: bool = False,
+        truncate_signals: bool = False,
+        resample: bool = False,
+        dim: int = 0,
+    ):
+        """✅Creates a batched AudioSignal from a list of AudioSignals.
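+
+        All signals are stacked along ``dim`` (the batch dimension, by
+        default) after any requested resampling, padding, or truncation.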
+ + Parameters + ---------- + audio_signals : list[AudioSignal] + List of AudioSignal objects + pad_signals : bool, optional + Whether to pad signals to length of the maximum length + AudioSignal in the list, by default False + truncate_signals : bool, optional + Whether to truncate signals to length of shortest length + AudioSignal in the list, by default False + resample : bool, optional + Whether to resample AudioSignal to the sample rate of + the first AudioSignal in the list, by default False + dim : int, optional + Dimension along which to batch the signals. + + Returns + ------- + AudioSignal + Batched AudioSignal. + + Raises + ------ + RuntimeError + If not all AudioSignals are the same sample rate, and + ``resample=False``, an error is raised. + RuntimeError + If not all AudioSignals are the same the length, and + both ``pad_signals=False`` and ``truncate_signals=False``, + an error is raised. + + Examples + -------- + Batching a bunch of random signals: + + >>> signal_list = [AudioSignal(paddle.randn([44100]), 44100) for _ in range(10)] + >>> signal = AudioSignal.batch(signal_list) + >>> print(signal.shape) + (10, 1, 44100) + + """ + signal_lengths = [x.signal_length for x in audio_signals] + sample_rates = [x.sample_rate for x in audio_signals] + + if len(set(sample_rates)) != 1: + if resample: + for x in audio_signals: + x.resample(sample_rates[0]) + else: + raise RuntimeError( + f"Not all signals had the same sample rate! Got {sample_rates}. " + f"All signals must have the same sample rate, or resample must be True. " + ) + + if len(set(signal_lengths)) != 1: + if pad_signals: + max_length = max(signal_lengths) + for x in audio_signals: + pad_len = max_length - x.signal_length + x.zero_pad(0, pad_len) + elif truncate_signals: + min_length = min(signal_lengths) + for x in audio_signals: + x.truncate_samples(min_length) + else: + raise RuntimeError( + f"Not all signals had the same length! Got {signal_lengths}. " + f"All signals must be the same length, or pad_signals/truncate_signals " + f"must be True. " + ) + # Concatenate along the specified dimension (default 0) + audio_data = paddle.concat( + [x.audio_data for x in audio_signals], axis=dim + ) + audio_paths = [x.path_to_file for x in audio_signals] + + batched_signal = cls( + audio_data, + sample_rate=audio_signals[0].sample_rate, + ) + batched_signal.path_to_file = audio_paths + return batched_signal + + # I/O + def load_from_file( + self, + audio_path: typing.Union[str, Path], + offset: float, + duration: float, + device: str = "cpu", + ): + """✅Loads data from file. Used internally when AudioSignal + is instantiated with a path to a file. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to file + offset : float + Offset in seconds + duration : float + Duration in seconds + device : str, optional + Device to put AudioSignal on, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from file + """ + + data, sample_rate = librosa.load( + audio_path, + offset=offset, + duration=duration, + sr=None, + mono=False, + ) + data = util.ensure_tensor(data) + if data.shape[-1] == 0: + raise RuntimeError( + f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!" 
+ ) + + if data.ndim < 2: + data = data.unsqueeze(0) + if data.ndim < 3: + data = data.unsqueeze(0) + self.audio_data = data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + self.path_to_file = audio_path + return self.to(device) + + def load_from_array( + self, + audio_array: typing.Union[paddle.Tensor, np.ndarray], + sample_rate: int, + device: str = "cpu", + ): + """✅Loads data from array, reshaping it to be exactly 3 + dimensions. Used internally when AudioSignal is called + with a tensor or an array. + + Parameters + ---------- + audio_array : typing.Union[paddle.Tensor, np.ndarray] + Array/tensor of audio of samples. + sample_rate : int + Sample rate of audio + device : str, optional + Device to move audio onto, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from array + """ + audio_data = util.ensure_tensor(audio_array) + + if str(audio_data.dtype) == paddle.float64: + audio_data = audio_data.astype("float32") + + if audio_data.ndim < 2: + audio_data = audio_data.unsqueeze(0) + if audio_data.ndim < 3: + audio_data = audio_data.unsqueeze(0) + self.audio_data = audio_data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + # return self.to(device) + return self + + def write(self, audio_path: typing.Union[str, Path]): + """✅Writes audio to a file. Only writes the audio + that is in the very first item of the batch. To write other items + in the batch, index the signal along the batch dimension + before writing. After writing, the signal's ``path_to_file`` + attribute is updated to the new path. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to write audio to. + + Returns + ------- + AudioSignal + Returns original AudioSignal, so you can use this in a fluent + interface. + + Examples + -------- + Creating and writing a signal to disk: + + >>> signal = AudioSignal(paddle.randn([10, 1, 44100]), 44100) + >>> signal.write("/tmp/out.wav") + + Writing a different element of the batch: + + >>> signal[5].write("/tmp/out.wav") + + Using this in a fluent interface: + + >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav") + + """ + if self.audio_data[0].abs().max() > 1: + warnings.warn("Audio amplitude > 1 clipped when saving") + soundfile.write( + str(audio_path), self.audio_data[0].numpy().T, self.sample_rate + ) + + self.path_to_file = audio_path + return self + + def deepcopy(self): + """✅Copies the signal and all of its attributes. + + Returns + ------- + AudioSignal + Deep copy of the audio signal. + """ + return copy.deepcopy(self) + + def copy(self): + """✅Shallow copy of signal. + + Returns + ------- + AudioSignal + Shallow copy of the audio signal. + """ + return copy.copy(self) + + def clone(self): + """✅Clones all tensors contained in the AudioSignal, + and returns a copy of the signal with everything + cloned. Useful when using AudioSignal within autograd + computation graphs. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Clone of AudioSignal. 
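+
+        Examples
+        --------
+        Cloning a signal before modifying it in place, so the original
+        is preserved (an illustrative sketch):
+
+        >>> signal = AudioSignal(paddle.randn([1, 1, 44100]), 44100)
+        >>> backup = signal.clone()
+        >>> signal.audio_data = signal.audio_data * 0.5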
+ """ + clone = type(self)( + self.audio_data.clone(), + self.sample_rate, + stft_params=self.stft_params, + ) + if self.stft_data is not None: + clone.stft_data = self.stft_data.clone() + if self._loudness is not None: + clone._loudness = self._loudness.clone() + clone.path_to_file = copy.deepcopy(self.path_to_file) + clone.metadata = copy.deepcopy(self.metadata) + return clone + + def detach(self): + """✅Detaches tensors contained in AudioSignal. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Same signal, but with all tensors detached. + """ + if self._loudness is not None: + self._loudness = self._loudness.detach() + if self.stft_data is not None: + self.stft_data = self.stft_data.detach() + + self.audio_data = self.audio_data.detach() + return self + + def hash(self): + """✅Writes the audio data to a temporary file, and then + hashes it using hashlib. Useful for creating a file + name based on the audio content. + + Returns + ------- + str + Hash of audio data. + + Examples + -------- + Creating a signal, and writing it to a unique file name: + + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> hash = signal.hash() + >>> signal.write(f"{hash}.wav") + + """ + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + self.write(f.name) + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(f.name, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + h.update(mv[:n]) + file_hash = h.hexdigest() + return file_hash + + # Signal operations + def to_mono(self): + """✅Converts audio data to mono audio, by taking the mean + along the channels dimension. + + Returns + ------- + AudioSignal + AudioSignal with mean of channels. + """ + self.audio_data = self.audio_data.mean(1, keepdim=True) + return self + + def resample(self, sample_rate: int): + """✅Resamples the audio, using sinc interpolation. This works on both + cpu and gpu, and is much faster on gpu. + + Parameters + ---------- + sample_rate : int + Sample rate to resample to. + + Returns + ------- + AudioSignal + Resampled AudioSignal + """ + if sample_rate == self.sample_rate: + return self + self.audio_data = resample_frac( + self.audio_data, self.sample_rate, sample_rate + ) + self.sample_rate = sample_rate + return self + + # Tensor operations + def to(self, device: str): + """✅Moves all tensors contained in signal to the specified device. + + Parameters + ---------- + device : str + Device to move AudioSignal onto. Typical values are + "gpu", "cpu", or "gpu:x" to specify the nth gpu. + + Returns + ------- + AudioSignal + AudioSignal with all tensors moved to specified device. + """ + if self._loudness is not None: + self._loudness = self._loudness.to(device) + if self.stft_data is not None: + self.stft_data = self.stft_data.to(device) + if self.audio_data is not None: + device = "gpu" if "cuda" == device else device + self.audio_data = self.audio_data.to(device) + return self + + def float(self): + """✅Calls ``.float()`` on ``self.audio_data``. + + Returns + ------- + AudioSignal + """ + self.audio_data = self.audio_data.astype("float32") + return self + + def cpu(self): + """✅Moves AudioSignal to cpu. + + Returns + ------- + AudioSignal + """ + return self.to("cpu") + + def cuda(self): # pragma: no cover + """✅Moves AudioSignal to cuda. + + Returns + ------- + AudioSignal + """ + return self.to("gpu") + + def numpy(self): + """✅Detaches ``self.audio_data``, moves to cpu, and converts to numpy. 
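+        The returned array keeps the ``(batch, channels, samples)`` layout.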
+ + Returns + ------- + np.ndarray + Audio data as a numpy array. + """ + return self.audio_data.detach().cpu().numpy() + + def zero_pad(self, before: int, after: int): + """✅Zero pads the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many zeros to prepend to audio. + after : int + How many zeros to append to audio. + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + self.audio_data = paddle.nn.functional.pad( + self.audio_data, (before, after), data_format="NCL" + ) + return self + + def zero_pad_to(self, length: int, mode: str = "after"): + """✅Pad with zeros to a specified length, either before or after + the audio data. + + Parameters + ---------- + length : int + Length to pad to + mode : str, optional + Whether to prepend or append zeros to signal, by default "after" + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + if mode == "before": + self.zero_pad(max(length - self.signal_length, 0), 0) + elif mode == "after": + self.zero_pad(0, max(length - self.signal_length, 0)) + return self + + def trim(self, before: int, after: int): + """✅Trims the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many samples to trim from beginning. + after : int + How many samples to trim from end. + + Returns + ------- + AudioSignal + AudioSignal with trimming applied. + """ + if after == 0: + self.audio_data = self.audio_data[..., before:] + else: + self.audio_data = self.audio_data[..., before:-after] + return self + + def truncate_samples(self, length_in_samples: int): + """✅Truncate signal to specified length. + + Parameters + ---------- + length_in_samples : int + Truncate to this many samples. + + Returns + ------- + AudioSignal + AudioSignal with truncation applied. + """ + self.audio_data = self.audio_data[..., :length_in_samples] + return self + + @property + def device(self): + """✅Get device that AudioSignal is on. + + Returns + ------- + paddle.device + Device that AudioSignal is on. + """ + if self.audio_data is not None: + device = self.audio_data.place + elif self.stft_data is not None: + device = self.stft_data.place + return device + + # Properties + @property + def audio_data(self): + """✅Returns the audio data tensor in the object. + + Audio data is always of the shape + (batch_size, num_channels, num_samples). If value has less + than 3 dims (e.g. is (num_channels, num_samples)), then it will + be reshaped to (1, num_channels, num_samples) - a batch size of 1. + + Parameters + ---------- + data : typing.Union[paddle.Tensor, np.ndarray] + Audio data to set. + + Returns + ------- + paddle.Tensor + Audio samples. + """ + return self._audio_data + + @audio_data.setter + def audio_data(self, data: typing.Union[paddle.Tensor, np.ndarray]): + if data is not None: + assert paddle.is_tensor(data), "audio_data should be paddle.Tensor" + assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)" + self._audio_data = data + # Old loudness value not guaranteed to be right, reset it. + self._loudness = None + return + + # alias for audio_data + samples = audio_data + + @property + def stft_data(self): + """✅Returns the STFT data inside the signal. Shape is + (batch, channels, frequencies, time). + + Returns + ------- + paddle.Tensor + Complex spectrogram data. 
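+
+        Notes
+        -----
+        Populated by ``stft()`` and reset to ``None`` whenever
+        ``stft_params`` is reassigned, so stale spectrograms are not reused.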
+ """ + return self._stft_data + + @stft_data.setter + def stft_data(self, data: typing.Union[paddle.Tensor, np.ndarray]): + if data is not None: + assert paddle.is_tensor(data) and paddle.is_complex(data) + if ( + self.stft_data is not None + and self.stft_data.shape != data.shape + ): + warnings.warn("stft_data changed shape") + self._stft_data = data + return + + @property + def batch_size(self): + """✅Batch size of audio signal. + + Returns + ------- + int + Batch size of signal. + """ + return self.audio_data.shape[0] + + @property + def signal_length(self): + """✅Length of audio signal. + + Returns + ------- + int + Length of signal in samples. + """ + return self.audio_data.shape[-1] + + # alias for signal_length + length = signal_length + + @property + def shape(self): + """✅Shape of audio data. + + Returns + ------- + tuple + Shape of audio data. + """ + return self.audio_data.shape + + @property + def signal_duration(self): + """✅Length of audio signal in seconds. + + Returns + ------- + float + Length of signal in seconds. + """ + return self.signal_length / self.sample_rate + + # alias for signal_duration + duration = signal_duration + + @property + def num_channels(self): + """✅Number of audio channels. + + Returns + ------- + int + Number of audio channels. + """ + return self.audio_data.shape[1] + + # STFT + @staticmethod + @functools.lru_cache(None) + def get_window(window_type: str, window_length: int, device: str = None): + """✅Wrapper around scipy.signal.get_window so one can also get the + popular sqrt-hann window. This function caches for efficiency + using functools.lru\_cache. + + Parameters + ---------- + window_type : str + Type of window to get + window_length : int + Length of the window + device : str + Device to put window onto. + + Returns + ------- + paddle.Tensor + Window returned by scipy.signal.get_window, as a tensor. + """ + from scipy import signal + + if window_type == "average": + window = np.ones(window_length) / window_length + elif window_type == "sqrt_hann": + window = np.sqrt(signal.get_window("hann", window_length)) + else: + window = signal.get_window(window_type, window_length) + window = paddle.to_tensor(window).astype("float32") + return window + + @property + def stft_params(self): + """✅Returns STFTParams object, which can be re-used to other + AudioSignals. + + This property can be set as well. If values are not defined in STFTParams, + they are inferred automatically from the signal properties. The default is to use + 32ms windows, with 8ms hop length, and the square root of the hann window. + + Returns + ------- + STFTParams + STFT parameters for the AudioSignal. 
+ + Examples + -------- + >>> stft_params = STFTParams(128, 32) + >>> signal1 = AudioSignal(paddle.randn([44100]), 44100, stft_params=stft_params) + >>> signal2 = AudioSignal(paddle.randn([44100]), 44100, stft_params=signal1.stft_params) + >>> signal1.stft_params = STFTParams() # Defaults + """ + return self._stft_params + + @stft_params.setter + def stft_params(self, value: STFTParams): + # ✅ + default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate)))) + default_hop_len = default_win_len // 4 + default_win_type = "hann" + default_match_stride = False + default_padding_type = "reflect" + + default_stft_params = STFTParams( + window_length=default_win_len, + hop_length=default_hop_len, + window_type=default_win_type, + match_stride=default_match_stride, + padding_type=default_padding_type, + )._asdict() + + value = value._asdict() if value else default_stft_params + + for key in default_stft_params: + if value[key] is None: + value[key] = default_stft_params[key] + + self._stft_params = STFTParams(**value) + self.stft_data = None + + def compute_stft_padding( + self, window_length: int, hop_length: int, match_stride: bool + ): + """✅Compute how the STFT should be padded, based on match\_stride. + + Parameters + ---------- + window_length : int + Window length of STFT. + hop_length : int + Hop length of STFT. + match_stride : bool + Whether or not to match stride, making the STFT have the same alignment as + convolutional layers. + + Returns + ------- + tuple + Amount to pad on either side of audio. + """ + length = self.signal_length + + if match_stride: + assert ( + hop_length == window_length // 4 + ), "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(length / hop_length) * hop_length - length + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + return right_pad, pad + + def stft( + self, + window_length: int = None, + hop_length: int = None, + window_type: str = None, + match_stride: bool = None, + padding_type: str = None, + ): + """✅Computes the short-time Fourier transform of the audio data, + with specified STFT parameters. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + padding_type : str, optional + Type of padding to use, by default 'reflect' + + Returns + ------- + paddle.Tensor + STFT of audio data. 
+ + Examples + -------- + Compute the STFT of an AudioSignal: + + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> signal.stft() + + Vary the window and hop length: + + >>> stft_params_list = [STFTParams(128, 32), STFTParams(512, 128)] + >>> for stft_params in stft_params_list: + >>> signal.stft_params = stft_params + >>> signal.stft() + + """ + window_length = ( + self.stft_params.window_length + if window_length is None + else int(window_length) + ) + hop_length = ( + self.stft_params.hop_length + if hop_length is None + else int(hop_length) + ) + window_type = ( + self.stft_params.window_type if window_type is None else window_type + ) + match_stride = ( + self.stft_params.match_stride + if match_stride is None + else match_stride + ) + padding_type = ( + self.stft_params.padding_type + if padding_type is None + else padding_type + ) + + window = self.get_window(window_type, window_length) + # window = window.to(self.audio_data.device) + + audio_data = self.audio_data + right_pad, pad = self.compute_stft_padding( + window_length, hop_length, match_stride + ) + audio_data = paddle.nn.functional.pad( + x=audio_data, + pad=[pad, pad + right_pad], + mode="reflect", + data_format="NCL", + ) + stft_data = paddle.signal.stft( + audio_data.reshape([-1, audio_data.shape[-1]]), + n_fft=window_length, + hop_length=hop_length, + window=window, + # return_complex=True, + center=True, + ) + _, nf, nt = stft_data.shape + stft_data = stft_data.reshape( + [self.batch_size, self.num_channels, nf, nt] + ) + + if match_stride: + # Drop first two and last two frames, which are added + # because of padding. Now num_frames * hop_length = num_samples. + stft_data = stft_data[..., 2:-2] + self.stft_data = stft_data + + return stft_data + + def istft( + self, + window_length: int = None, + hop_length: int = None, + window_type: str = None, + match_stride: bool = None, + length: int = None, + ): + """✅Computes inverse STFT and sets it to audio\_data. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + length : int, optional + Original length of signal, by default None + + Returns + ------- + AudioSignal + AudioSignal with istft applied. + + Raises + ------ + RuntimeError + Raises an error if stft was not called prior to istft on the signal, + or if stft_data is not set. 
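+
+        Examples
+        --------
+        Round-tripping a signal through the STFT and its inverse
+        (an illustrative sketch):
+
+        >>> signal = AudioSignal(paddle.randn([44100]), 44100)
+        >>> _ = signal.stft()
+        >>> signal = signal.istft()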
+ """ + if self.stft_data is None: + raise RuntimeError("Cannot do inverse STFT without self.stft_data!") + + window_length = ( + self.stft_params.window_length + if window_length is None + else int(window_length) + ) + hop_length = ( + self.stft_params.hop_length + if hop_length is None + else int(hop_length) + ) + window_type = ( + self.stft_params.window_type if window_type is None else window_type + ) + match_stride = ( + self.stft_params.match_stride + if match_stride is None + else match_stride + ) + + window = self.get_window( + window_type, window_length, self.stft_data.place + ) + + nb, nch, nf, nt = self.stft_data.shape + stft_data = self.stft_data.reshape([nb * nch, nf, nt]) + right_pad, pad = self.compute_stft_padding( + window_length, hop_length, match_stride + ) + + if length is None: + length = self.original_signal_length + length = length + 2 * pad + right_pad + + if match_stride: + # Zero-pad the STFT on either side, putting back the frames that were + # dropped in stft(). + stft_data = paddle.nn.functional.pad( + stft_data, pad=(2, 2), data_format="NCL" + ) + + audio_data = paddle.signal.istft( + stft_data, + n_fft=window_length, + hop_length=hop_length, + window=window, + length=length, + center=True, + ) + audio_data = audio_data.reshape([nb, nch, -1]) + if match_stride: + audio_data = audio_data[..., pad : -(pad + right_pad)] + self.audio_data = audio_data + + return self + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters( + sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None + ): + """✅Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. + + Parameters + ---------- + sr : int + Sample rate of audio + n_fft : int + Number of FFT bins + n_mels : int + Number of mels + fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + fmax : float, optional + Highest frequency, by default None + + Returns + ------- + np.ndarray [shape=(n_mels, 1 + n_fft/2)] + Mel transform matrix + """ + from librosa.filters import mel as librosa_mel_fn + + return librosa_mel_fn( + sr=sr, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + ) + + def mel_spectrogram( + self, + n_mels: int = 80, + mel_fmin: float = 0.0, + mel_fmax: float = None, + **kwargs, + ): + """✅Computes a Mel spectrogram. + + Parameters + ---------- + n_mels : int, optional + Number of mels, by default 80 + mel_fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + mel_fmax : float, optional + Highest frequency, by default None + kwargs : dict, optional + Keyword arguments to self.stft(). + + Returns + ------- + paddle.Tensor [shape=(batch, channels, mels, time)] + Mel spectrogram. + """ + stft = self.stft(**kwargs) + magnitude = paddle.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters( + sr=self.sample_rate, + n_fft=2 * (nf - 1), + n_mels=n_mels, + fmin=mel_fmin, + fmax=mel_fmax, + ) + mel_basis = paddle.to_tensor(mel_basis) + + mel_spectrogram = magnitude.transpose([0, 1, 3, 2]) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose([0, 1, 3, 2]) + return mel_spectrogram + + @staticmethod + @functools.lru_cache(None) + def get_dct( + n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None + ): + """✅Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), + it can be normalized depending on norm. 
For more information about dct:
+        http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
+
+        Parameters
+        ----------
+        n_mfcc : int
+            Number of mfccs
+        n_mels : int
+            Number of mels
+        norm : str
+            Use "ortho" to get an orthogonal matrix or None, by default "ortho"
+        device : str, optional
+            Device to load the transformation matrix on, by default None
+
+        Returns
+        -------
+        paddle.Tensor [shape=(n_mels, n_mfcc)]
+            The dct transformation matrix.
+        """
+        # from torchaudio.functional import create_dct
+
+        return create_dct(n_mfcc, n_mels, norm)
+
+    def mfcc(
+        self,
+        n_mfcc: int = 40,
+        n_mels: int = 80,
+        log_offset: float = 1e-6,
+        **kwargs,
+    ):
+        """✅Computes mel-frequency cepstral coefficients (MFCCs).
+
+        Parameters
+        ----------
+        n_mfcc : int, optional
+            Number of MFCCs to retain, by default 40
+        n_mels : int, optional
+            Number of mels, by default 80
+        log_offset: float, optional
+            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
+        kwargs : dict, optional
+            Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()
+
+        Returns
+        -------
+        paddle.Tensor [shape=(batch, channels, mfccs, time)]
+            MFCCs.
+        """
+
+        mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
+        mel_spectrogram = paddle.log(mel_spectrogram + log_offset)
+        dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
+
+        mfcc = mel_spectrogram.transpose([0, 1, 3, 2]) @ dct_mat
+        mfcc = mfcc.transpose([0, 1, 3, 2])
+        return mfcc
+
+    @property
+    def magnitude(self):
+        """✅Computes and returns the absolute value of the STFT, which
+        is the magnitude. This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its magnitude
+        matches what this is set to, and modulated by the phase.
+
+        Returns
+        -------
+        paddle.Tensor
+            Magnitude of STFT.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(paddle.randn([44100]), 44100)
+        >>> magnitude = signal.magnitude  # Computes stft if not computed
+        >>> magnitude[magnitude < magnitude.mean()] = 0
+        >>> signal.magnitude = magnitude
+        >>> signal.istft()
+        """
+        if self.stft_data is None:
+            self.stft()
+        return paddle.abs(self.stft_data)
+
+    @magnitude.setter
+    def magnitude(self, value):
+        self.stft_data = value * paddle.exp(1j * self.phase)
+        return
+
+    def log_magnitude(
+        self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
+    ):
+        """✅Computes the log-magnitude of the spectrogram.
+
+        Parameters
+        ----------
+        ref_value : float, optional
+            The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
+            Zeros in the output correspond to positions where ``S == ref``,
+            by default 1.0
+        amin : float, optional
+            Minimum threshold for ``S`` and ``ref``, by default 1e-5
+        top_db : float, optional
+            Threshold the output at ``top_db`` below the peak:
+            ``max(10 * log10(S/ref)) - top_db``, by default 80.0
+
+        Returns
+        -------
+        paddle.Tensor
+            Log-magnitude spectrogram
+        """
+        magnitude = self.magnitude
+
+        amin = amin**2
+        log_spec = 10.0 * paddle.log10(magnitude.pow(2).clip(min=amin))
+        log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
+
+        if top_db is not None:
+            log_spec = paddle.maximum(log_spec, log_spec.max() - top_db)
+        return log_spec
+
+    @property
+    def phase(self):
+        """✅Computes and returns the phase of the STFT.
+        This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its phase
+        matches what this is set to, modulated by the original magnitude.
+
+        Returns
+        -------
+        paddle.Tensor
+            Phase of STFT.
+ + Examples + -------- + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> phase = signal.phase # Computes stft if not computed + >>> phase[phase < phase.mean()] = 0 + >>> signal.phase = phase + >>> signal.istft() + """ + if self.stft_data is None: + self.stft() + return paddle.angle(self.stft_data) + + @phase.setter + def phase(self, value): + # ✅ + self.stft_data = self.magnitude * paddle.exp(1j * value) + return + + # Operator overloading + def __add__(self, other): + new_signal = self.clone() + new_signal.audio_data += util._get_value(other) + return new_signal + + def __iadd__(self, other): + self.audio_data += util._get_value(other) + return self + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + new_signal = self.clone() + new_signal.audio_data -= util._get_value(other) + return new_signal + + def __isub__(self, other): + self.audio_data -= util._get_value(other) + return self + + def __mul__(self, other): + new_signal = self.clone() + new_signal.audio_data *= util._get_value(other) + return new_signal + + def __imul__(self, other): + self.audio_data *= util._get_value(other) + return self + + def __rmul__(self, other): + return self * other + + # Representation + def _info(self): + # ✅ + dur = ( + f"{self.signal_duration:0.3f}" + if self.signal_duration + else "[unknown]" + ) + info = { + "duration": f"{dur} seconds", + "batch_size": self.batch_size, + "path": self.path_to_file if self.path_to_file else "path unknown", + "sample_rate": self.sample_rate, + "num_channels": ( + self.num_channels if self.num_channels else "[unknown]" + ), + "audio_data.shape": self.audio_data.shape, + "stft_params": self.stft_params, + "device": self.device, + } + + return info + + def markdown(self): + """✅Produces a markdown representation of AudioSignal, in a markdown table. + + Returns + ------- + str + Markdown representation of AudioSignal. 
+ + Examples + -------- + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> print(signal.markdown()) + | Key | Value + |---|--- + | duration | 1.000 seconds | + | batch_size | 1 | + | path | path unknown | + | sample_rate | 44100 | + | num_channels | 1 | + | audio_data.shape | paddle.Size([1, 1, 44100]) | + | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) | + | device | cpu | + """ + info = self._info() + + FORMAT = "| Key | Value \n" "|---|--- \n" + for k, v in info.items(): + row = f"| {k} | {v} |\n" + FORMAT += row + return FORMAT + + def __str__(self): + info = self._info() + + desc = "" + for k, v in info.items(): + desc += f"{k}: {v}\n" + return desc + + def __rich__(self): + from rich.table import Table + + info = self._info() + + table = Table(title=f"{self.__class__.__name__}") + table.add_column("Key", style="green") + table.add_column("Value", style="cyan") + + for k, v in info.items(): + table.add_row(k, str(v)) + return table + + # Comparison + def __eq__(self, other): + for k, v in list(self.__dict__.items()): + if paddle.is_tensor(v): + if not paddle.allclose(v, other.__dict__[k], atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False + return True + + # Indexing + def __getitem__(self, key): + if paddle.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + audio_data = self.audio_data + _loudness = self._loudness + stft_data = self.stft_data + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + paddle.is_tensor(key) and key.ndim <= 1 + ): + # Indexing only on the batch dimension. + # Then let's copy over relevant stuff. + # Future work: make this work for time-indexing + # as well, using the hop length. 
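+            # The same batch-dimension key is applied to each cached
+            # tensor, so the returned copy stays internally consistent.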
+ audio_data = self.audio_data[key] + _loudness = ( + self._loudness[key] if self._loudness is not None else None + ) + stft_data = ( + self.stft_data[key] if self.stft_data is not None else None + ) + + sources = None + + copy = type(self)( + audio_data, self.sample_rate, stft_params=self.stft_params + ) + copy._loudness = _loudness + copy._stft_data = stft_data + copy.sources = sources + + return copy + + def __setitem__(self, key, value): + if not isinstance(value, type(self)): + self.audio_data[key] = value + return + + if paddle.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + self.audio_data = value.audio_data + self._loudness = value._loudness + self.stft_data = value.stft_data + return + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + paddle.is_tensor(key) and key.ndim <= 1 + ): + if self.audio_data is not None and value.audio_data is not None: + self.audio_data[key] = value.audio_data + if self._loudness is not None and value._loudness is not None: + self._loudness[key] = value._loudness + if self.stft_data is not None and value.stft_data is not None: + self.stft_data[key] = value.stft_data + return + + def __ne__(self, other): + return not self == other diff --git a/audio/audiotools/requirements.txt b/audio/audiotools/requirements.txt new file mode 100644 index 000000000..925b740fe --- /dev/null +++ b/audio/audiotools/requirements.txt @@ -0,0 +1,5 @@ +soundfile +librosa +scipy +rich +flatten_dict \ No newline at end of file diff --git a/audio/audiotools/resample.py b/audio/audiotools/resample.py new file mode 100644 index 000000000..b5c490925 --- /dev/null +++ b/audio/audiotools/resample.py @@ -0,0 +1,240 @@ +import inspect +from typing import Optional, Sequence +import paddle +import paddle.nn.functional as F +import math + + +def simple_repr( + obj, attrs: Optional[Sequence[str]] = None, overrides: dict = {} +): + """ + Return a simple representation string for `obj`. + If `attrs` is not None, it should be a list of attributes to include. + """ + params = inspect.signature(obj.__class__).parameters + attrs_repr = [] + if attrs is None: + attrs = list(params.keys()) + for attr in attrs: + display = False + if attr in overrides: + value = overrides[attr] + elif hasattr(obj, attr): + value = getattr(obj, attr) + else: + continue + if attr in params: + param = params[attr] + if param.default is inspect._empty or value != param.default: # type: ignore + display = True + else: + display = True + + if display: + attrs_repr.append(f"{attr}={value}") + return f"{obj.__class__.__name__}({','.join(attrs_repr)})" + + +def sinc(x: paddle.Tensor): + """ + Implementation of sinc, i.e. sin(x) / x + + __Warning__: the input is not multiplied by `pi`! + """ + return paddle.where( + x == 0, + paddle.to_tensor(1.0, dtype=x.dtype, place=x.place), + paddle.sin(x) / x, + ) + + +class ResampleFrac(paddle.nn.Layer): + """ + Resampling from the sample rate `old_sr` to `new_sr`. + """ + + def __init__( + self, old_sr: int, new_sr: int, zeros: int = 24, rolloff: float = 0.945 + ): + """ + Args: + old_sr (int): sample rate of the input signal x. + new_sr (int): sample rate of the output. + zeros (int): number of zero crossing to keep in the sinc filter. + rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`, + to ensure sufficient margin due to the imperfection of the FIR filter used. + Lowering this value will reduce anti-aliasing, but will reduce some of the + highest frequencies. 
+ + Shape: + + - Input: `[*, T]` + - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)` + + + .. caution:: + After dividing `old_sr` and `new_sr` by their GCD, both should be small + for this implementation to be fast. + + >>> import paddle + >>> resample = ResampleFrac(4, 5) + >>> x = paddle.randn([1000]) + >>> print(len(resample(x))) + 1250 + """ + super(ResampleFrac, self).__init__() + if not isinstance(old_sr, int) or not isinstance(new_sr, int): + raise ValueError("old_sr and new_sr should be integers") + gcd = math.gcd(old_sr, new_sr) + self.old_sr = old_sr // gcd + self.new_sr = new_sr // gcd + self.zeros = zeros + self.rolloff = rolloff + + self._init_kernels() + + def _init_kernels(self): + if self.old_sr == self.new_sr: + return + + kernels = [] + sr = min(self.new_sr, self.old_sr) + # rolloff will perform antialiasing filtering by removing the highest frequencies. + # At first I thought I only needed this when downsampling, but when upsampling + # you will get edge artifacts without this, the edge is equivalent to zero padding, + # which will add high freq artifacts. + sr *= self.rolloff + + # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) + # using the sinc interpolation formula: + # x(t) = sum_i x[i] sinc(pi * old_sr * (i / old_sr - t)) + # We can then sample the function x(t) with a different sample rate: + # y[j] = x(j / new_sr) + # or, + # y[j] = sum_i x[i] sinc(pi * old_sr * (i / old_sr - j / new_sr)) + + # We see here that y[j] is the convolution of x[i] with a specific filter, for which + # we take an FIR approximation, stopping when we see at least `zeros` zeros crossing. + # But y[j+1] is going to have a different set of weights and so on, until y[j + new_sr]. + # Indeed: + # y[j + new_sr] = sum_i x[i] sinc(pi * old_sr * ((i / old_sr - (j + new_sr) / new_sr)) + # = sum_i x[i] sinc(pi * old_sr * ((i - old_sr) / old_sr - j / new_sr)) + # = sum_i x[i + old_sr] sinc(pi * old_sr * (i / old_sr - j / new_sr)) + # so y[j+new_sr] uses the same filter as y[j], but on a shifted version of x by `old_sr`. + # This will explain the F.conv1d after, with a stride of old_sr. + self._width = math.ceil(self.zeros * self.old_sr / sr) + # If old_sr is still big after GCD reduction, most filters will be very unbalanced, i.e., + # they will have a lot of almost zero values to the left or to the right... + # There is probably a way to evaluate those filters more efficiently, but this is kept for + # future work. + idx = paddle.arange( + -self._width, self._width + self.old_sr, dtype="float32" + ) + for i in range(self.new_sr): + t = (-i / self.new_sr + idx / self.old_sr) * sr + t = paddle.clip(t, -self.zeros, self.zeros) + t *= math.pi + window = paddle.cos(t / self.zeros / 2) ** 2 + kernel = sinc(t) * window + # Renormalize kernel to ensure a constant signal is preserved. + kernel = kernel / kernel.sum() + kernels.append(kernel) + + _kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1]) + self.kernel = self.create_parameter( + shape=_kernel.shape, + dtype=_kernel.dtype, + ) + self.kernel.set_value(_kernel) + + def forward( + self, + x: paddle.Tensor, + output_length: Optional[int] = None, + full: bool = False, + ): + """ + Resample x. + Args: + x (Tensor): signal to resample, time should be the last dimension + output_length (None or int): This can be set to the desired output length + (last dimension). Allowed values are between 0 and + ceil(length * new_sr / old_sr). 
When None (default) is specified, the
+                floored output length will be used. In order to select the largest possible
+                size, use the `full` argument.
+            full (bool): return the longest possible output from the input. This can be useful
+                if you chain resampling operations, and want to give the `output_length` only
+                for the last one, while passing `full=True` to all the other ones.
+        """
+        if self.old_sr == self.new_sr:
+            return x
+        shape = x.shape
+        length = x.shape[-1]
+        x = x.reshape([-1, length])
+        x = F.pad(
+            x.unsqueeze(1),
+            [self._width, self._width + self.old_sr],
+            mode="replicate",
+            data_format="NCL",
+        )
+        ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL")
+        y = ys.transpose([0, 2, 1]).reshape(list(shape[:-1]) + [-1])
+
+        float_output_length = paddle.to_tensor(
+            self.new_sr * length / self.old_sr, dtype="float32"
+        )
+        max_output_length = paddle.ceil(float_output_length).astype("int64")
+        default_output_length = paddle.floor(float_output_length).astype(
+            "int64"
+        )
+
+        if output_length is None:
+            applied_output_length = (
+                max_output_length if full else default_output_length
+            )
+        elif output_length < 0 or output_length > max_output_length:
+            raise ValueError(
+                f"output_length must be between 0 and {max_output_length.numpy()}"
+            )
+        else:
+            applied_output_length = paddle.to_tensor(
+                output_length, dtype="int64"
+            )
+            if full:
+                raise ValueError(
+                    "You cannot pass both full=True and output_length"
+                )
+        return y[..., :applied_output_length]
+
+    def __repr__(self):
+        return simple_repr(self)
+
+
+def resample_frac(
+    x: paddle.Tensor,
+    old_sr: int,
+    new_sr: int,
+    zeros: int = 24,
+    rolloff: float = 0.945,
+    output_length: Optional[int] = None,
+    full: bool = False,
+):
+    """
+    Functional version of `ResampleFrac`, refer to its documentation for more information.
+
+    .. warning::
+        If you call this function repeatedly with the same sample rates, the
+        resampling kernel will be recomputed every time. For best performance,
+        you should use and cache an instance of `ResampleFrac`.
+    """
+    return ResampleFrac(old_sr, new_sr, zeros, rolloff)(
+        x, output_length, full
+    )
+
+
+if __name__ == "__main__":
+
+    resample = ResampleFrac(4, 5)
+    x = paddle.randn([1000])
+    print(len(resample(x)))
diff --git a/audio/audiotools/util.py b/audio/audiotools/util.py
new file mode 100644
index 000000000..e5e7a8e3b
--- /dev/null
+++ b/audio/audiotools/util.py
@@ -0,0 +1,669 @@
+import csv
+import glob
+import math
+import numbers
+import os
+import random
+import typing
+import soundfile
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, Optional, List
+
+import numpy as np
+import paddle
+from flatten_dict import flatten
+from flatten_dict import unflatten
+
+
+@dataclass
+class Info:
+
+    sample_rate: float
+    num_frames: int
+
+    @property
+    def duration(self) -> float:
+        return self.num_frames / self.sample_rate
+
+
+def info(audio_path: str):
+    """✅Shim over ``soundfile.info`` that returns the sample rate and
+    number of frames of an audio file.
+
+    Parameters
+    ----------
+    audio_path : str
+        Path to audio file.
+    """
+    info = soundfile.info(str(audio_path))
+    info = Info(sample_rate=info.samplerate, num_frames=info.frames)
+
+    return info
+
+
+def ensure_tensor(
+    x: typing.Union[np.ndarray, paddle.Tensor, float, int],
+    ndim: int = None,
+    batch_size: int = None,
+):
+    """✅Ensures that the input ``x`` is a tensor of specified
+    dimensions and batch size.
+
+    Parameters
+    ----------
+    x : typing.Union[np.ndarray, paddle.Tensor, float, int]
+        Data that will become a tensor on its way out.
+    ndim : int, optional
+        How many dimensions should be in the output, by default None
+    batch_size : int, optional
+        The batch size of the output, by default None
+
+    Returns
+    -------
+    paddle.Tensor
+        Modified version of ``x`` as a tensor.
+    """
+    if not paddle.is_tensor(x):
+        x = paddle.to_tensor(x)
+    if ndim is not None:
+        assert x.ndim <= ndim
+        while x.ndim < ndim:
+            x = x.unsqueeze(-1)
+    if batch_size is not None:
+        if x.shape[0] != batch_size:
+            shape = list(x.shape)
+            shape[0] = batch_size
+            x = paddle.expand(x, shape)
+    return x
+
+
+def _get_value(other):
+    # ✅
+    # from . import AudioSignal
+    from audio_signal import AudioSignal
+
+    if isinstance(other, AudioSignal):
+        return other.audio_data
+    return other
+
+
+def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int):
+    """Closest frequency bin given a frequency, number
+    of bins, and a sampling rate.
+
+    Parameters
+    ----------
+    hz : paddle.Tensor
+        Tensor of frequencies in Hz.
+    n_fft : int
+        Number of FFT bins.
+    sample_rate : int
+        Sample rate of audio.
+
+    Returns
+    -------
+    paddle.Tensor
+        Closest bins to the data.
+    """
+    shape = hz.shape
+    hz = hz.flatten()
+    freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2)
+    hz[hz > sample_rate / 2] = sample_rate / 2
+
+    closest = (hz[None, :] - freqs[:, None]).abs()
+    closest_bins = paddle.argmin(closest, axis=0)
+
+    return closest_bins.reshape(shape)
+
+
+def random_state(seed: typing.Union[int, np.random.RandomState]):
+    """✅
+    Turn seed into a np.random.RandomState instance.
+
+    Parameters
+    ----------
+    seed : typing.Union[int, np.random.RandomState] or None
+        If seed is None, return the RandomState singleton used by np.random.
+        If seed is an int, return a new RandomState instance seeded with seed.
+        If seed is already a RandomState instance, return it.
+        Otherwise raise ValueError.
+
+    Returns
+    -------
+    np.random.RandomState
+        Random state object.
+
+    Raises
+    ------
+    ValueError
+        If seed is not valid, an error is thrown.
+    """
+    if seed is None or seed is np.random:
+        return np.random.mtrand._rand
+    elif isinstance(seed, (numbers.Integral, np.integer, int)):
+        return np.random.RandomState(seed)
+    elif isinstance(seed, np.random.RandomState):
+        return seed
+    else:
+        raise ValueError(
+            "%r cannot be used to seed a numpy.random.RandomState"
+            " instance" % seed
+        )
+
+
+def seed(random_seed, set_cudnn=False):
+    """
+    Seeds all random states with the same random seed
+    for reproducibility. Seeds ``numpy``, ``random`` and ``paddle``
+    random generators.
+    For full reproducibility, cuDNN must additionally be put into
+    deterministic mode and taken off of benchmark mode. To do this,
+    ``set_cudnn`` must be True. It defaults to False, since setting
+    it to True results in a performance hit.
+
+    Args:
+        random_seed (int): integer corresponding to random seed to
+            use.
+        set_cudnn (bool): Whether or not to set cudnn into deterministic
+            mode and off of benchmark mode. Defaults to False.
+    """
+
+    paddle.seed(random_seed)
+    np.random.seed(random_seed)
+    random.seed(random_seed)
+
+    if set_cudnn:
+        # Paddle exposes cuDNN determinism through global flags rather
+        # than a `backends.cudnn` module.
+        paddle.set_flags(
+            {
+                "FLAGS_cudnn_deterministic": True,
+                "FLAGS_cudnn_exhaustive_search": False,
+            }
+        )
+
+
+@contextmanager
+def _close_temp_files(tmpfiles: list):
+    """Utility function for creating a context and closing all temporary files
+    once the context is exited. For correct functionality, all temporary file
+    handles created inside the context must be appended to the ```tmpfiles```
+    list.
+
+    This function is taken wholesale from Scaper.
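+
+    For example (an illustrative sketch; ``tempfile`` is imported just for
+    the doctest):
+
+    >>> import tempfile
+    >>> tmpfiles = []
+    >>> with _close_temp_files(tmpfiles):
+    ...     tmpfiles.append(tempfile.NamedTemporaryFile(suffix=".wav"))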
+
+
+@contextmanager
+def _close_temp_files(tmpfiles: list):
+    """Utility function for creating a context and closing all temporary files
+    once the context is exited. For correct functionality, all temporary file
+    handles created inside the context must be appended to the ``tmpfiles``
+    list.
+
+    This function is taken wholesale from Scaper.
+
+    Parameters
+    ----------
+    tmpfiles : list
+        List of temporary file handles
+    """
+
+    def _close():
+        for t in tmpfiles:
+            try:
+                t.close()
+                os.unlink(t.name)
+            except Exception:
+                pass
+
+    try:
+        yield
+    except Exception:  # pragma: no cover
+        _close()
+        raise
+    _close()
+
+
+AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
+
+
+def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
+    """Finds all audio files in a directory recursively.
+    Returns a list.
+
+    Parameters
+    ----------
+    folder : str
+        Folder to look for audio files in, recursively.
+    ext : List[str], optional
+        Extensions to look for (including the leading dot), by default
+        ``['.wav', '.flac', '.mp3', '.mp4']``.
+    """
+    folder = Path(folder)
+    # Take care of the case where the user has passed an audio file
+    # directly into one of the calling functions.
+    if str(folder).endswith(tuple(ext)):
+        # If, however, there's a glob in the path, we need to
+        # return the glob, not the file.
+        if "*" in str(folder):
+            return glob.glob(str(folder), recursive=("**" in str(folder)))
+        else:
+            return [folder]
+
+    files = []
+    for x in ext:
+        files += folder.glob(f"**/*{x}")
+    return files
+
+
+def read_sources(
+    sources: List[str],
+    remove_empty: bool = True,
+    relative_path: str = "",
+    ext: List[str] = AUDIO_EXTENSIONS,
+):
+    """Reads audio sources that can either be folders
+    full of audio files, or CSV files that contain paths
+    to audio files. CSV files that adhere to the expected
+    format can be generated by
+    :py:func:`audiotools.data.preprocess.create_csv`.
+
+    Parameters
+    ----------
+    sources : List[str]
+        List of audio sources to be converted into a
+        list of lists of audio files.
+    remove_empty : bool, optional
+        Whether or not to remove rows with an empty "path"
+        from each CSV file, by default True.
+    relative_path : str, optional
+        Path prepended to every audio path, by default "".
+    ext : List[str], optional
+        Extensions to look for when reading from folders, by default
+        ``['.wav', '.flac', '.mp3', '.mp4']``.
+
+    Returns
+    -------
+    list
+        List of lists of rows of CSV files.
+    """
+    files = []
+    relative_path = Path(relative_path)
+    for source in sources:
+        source = str(source)
+        _files = []
+        if source.endswith(".csv"):
+            with open(source, "r") as f:
+                reader = csv.DictReader(f)
+                for x in reader:
+                    if remove_empty and x["path"] == "":
+                        continue
+                    if x["path"] != "":
+                        x["path"] = str(relative_path / x["path"])
+                    _files.append(x)
+        else:
+            for x in find_audio(source, ext=ext):
+                x = str(relative_path / x)
+                _files.append({"path": x})
+        files.append(sorted(_files, key=lambda x: x["path"]))
+    return files
+
+
+def choose_from_list_of_lists(
+    state: np.random.RandomState,
+    list_of_lists: list,
+    p: typing.List[float] = None,
+):
+    """Choose a single item from a list of lists.
+
+    Parameters
+    ----------
+    state : np.random.RandomState
+        Random state to use when choosing an item.
+    list_of_lists : list
+        A list of lists from which items will be drawn.
+    p : List[float], optional
+        Probability of drawing from each list, by default None (uniform).
+
+    Returns
+    -------
+    tuple
+        The chosen item, the index of the list it came from, and its
+        index within that list.
+    """
+    source_idx = state.choice(list(range(len(list_of_lists))), p=p)
+    item_idx = state.randint(len(list_of_lists[source_idx]))
+    return list_of_lists[source_idx][item_idx], source_idx, item_idx
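+
+
+# A sketch of how the three loaders above compose (the folder and CSV
+# names are hypothetical):
+#
+#     >>> state = random_state(0)
+#     >>> sources = read_sources(["vocals/", "drums.csv"])
+#     >>> row, src_idx, item_idx = choose_from_list_of_lists(state, sources)
+#     >>> row["path"]  # path of the randomly chosen audio file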
+
+
+@contextmanager
+def chdir(newdir: typing.Union[Path, str]):
+    """
+    Context manager for switching directories to run a
+    function. Useful for when you want to use relative
+    paths to different runs.
+
+    Parameters
+    ----------
+    newdir : typing.Union[Path, str]
+        Directory to switch to.
+    """
+    curdir = os.getcwd()
+    try:
+        os.chdir(newdir)
+        yield
+    finally:
+        os.chdir(curdir)
+
+
+def prepare_batch(
+    batch: typing.Union[dict, list, paddle.Tensor], device: str = "cpu"
+):
+    """Moves items in a batch (typically generated by a DataLoader as a list
+    or a dict) to the specified device. This works even if dictionaries
+    are nested.
+
+    Parameters
+    ----------
+    batch : typing.Union[dict, list, paddle.Tensor]
+        Batch, typically generated by a dataloader, that will be moved to
+        the device.
+    device : str, optional
+        Device to move batch to, by default "cpu"
+
+    Returns
+    -------
+    typing.Union[dict, list, paddle.Tensor]
+        Batch with all values moved to the specified device.
+    """
+    if isinstance(batch, dict):
+        batch = flatten(batch)
+        for key, val in batch.items():
+            try:
+                batch[key] = val.to(device)
+            except Exception:
+                # Non-tensor values (strings, ints, ...) are left as-is.
+                pass
+        batch = unflatten(batch)
+    elif paddle.is_tensor(batch):
+        batch = batch.to(device)
+    elif isinstance(batch, list):
+        for i in range(len(batch)):
+            try:
+                batch[i] = batch[i].to(device)
+            except Exception:
+                pass
+    return batch
+
+
+def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
+    """Samples from a distribution defined by a tuple. The first
+    item in the tuple is the distribution type, and the rest of the
+    items are arguments to that distribution. The distribution function
+    is looked up on the ``np.random.RandomState`` object.
+
+    Parameters
+    ----------
+    dist_tuple : tuple
+        Distribution tuple
+    state : np.random.RandomState, optional
+        Random state, or seed to use, by default None
+
+    Returns
+    -------
+    typing.Union[float, int, str]
+        Draw from the distribution.
+
+    Examples
+    --------
+    Sample from a uniform distribution:
+
+    >>> dist_tuple = ("uniform", 0, 1)
+    >>> sample_from_dist(dist_tuple)
+
+    Sample from a constant distribution:
+
+    >>> dist_tuple = ("const", 0)
+    >>> sample_from_dist(dist_tuple)
+
+    Sample from a normal distribution:
+
+    >>> dist_tuple = ("normal", 0, 0.5)
+    >>> sample_from_dist(dist_tuple)
+
+    """
+    if dist_tuple[0] == "const":
+        return dist_tuple[1]
+    state = random_state(state)
+    dist_fn = getattr(state, dist_tuple[0])
+    return dist_fn(*dist_tuple[1:])
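+
+
+# Example: moving a nested batch to a device (a minimal sketch; pass
+# "gpu:0" instead of "cpu" on a GPU build of paddle):
+#
+#     >>> batch = {
+#     ...     "audio": paddle.zeros([2, 1, 100]),
+#     ...     "meta": {"idx": paddle.to_tensor([0, 1])},
+#     ... }
+#     >>> batch = prepare_batch(batch, device="cpu")  # nested dicts handled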
+
+
+def collate(list_of_dicts: list, n_splits: int = None):
+    """Collates a list of dictionaries (e.g. as returned by a
+    dataloader) into a dictionary with batched values. This routine
+    uses the default paddle collate function for everything
+    except AudioSignal objects, which are handled by the
+    :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
+    function.
+
+    This function takes ``n_splits`` to enable splitting a batch
+    into multiple sub-batches for the purposes of gradient accumulation,
+    etc.
+
+    Parameters
+    ----------
+    list_of_dicts : list
+        List of dictionaries to be collated.
+    n_splits : int
+        Number of splits to make when creating the batches (split into
+        sub-batches). Useful for things like gradient accumulation.
+
+    Returns
+    -------
+    dict
+        Dictionary containing batched data.
+    """
+    from audio_signal import AudioSignal
+
+    # ``default_collate_fn`` is what paddle's own DataLoader uses; the
+    # torch-style ``paddle.utils.data._utils.collate.default_collate`` path
+    # does not exist in paddle. (The module path below may vary slightly
+    # across paddle versions.)
+    from paddle.io.dataloader.collate import default_collate_fn
+
+    batches = []
+    list_len = len(list_of_dicts)
+
+    return_list = False if n_splits is None else True
+    n_splits = 1 if n_splits is None else n_splits
+    n_items = int(math.ceil(list_len / n_splits))
+
+    for i in range(0, list_len, n_items):
+        # Flatten the dictionaries to avoid recursion.
+        list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
+        dict_of_lists = {
+            k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
+        }
+
+        batch = {}
+        for k, v in dict_of_lists.items():
+            if isinstance(v, list):
+                if all(isinstance(s, AudioSignal) for s in v):
+                    batch[k] = AudioSignal.batch(v, pad_signals=True)
+                else:
+                    # Borrow the default collate fn from paddle.
+                    batch[k] = default_collate_fn(v)
+        batches.append(unflatten(batch))
+
+    batches = batches[0] if not return_list else batches
+    return batches
+
+
+BASE_SIZE = 864
+DEFAULT_FIG_SIZE = (9, 3)
+
+
+def format_figure(
+    fig_size: tuple = None,
+    title: str = None,
+    fig=None,
+    format_axes: bool = True,
+    format: bool = True,
+    font_color: str = "white",
+):
+    """Prettifies the spectrogram and waveform plots. A title
+    can be inset into the top right corner, and the axes can be
+    inset into the figure, allowing the data to take up the entire
+    image. Used in
+
+    - :py:func:`audiotools.core.display.DisplayMixin.specshow`
+    - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
+    - :py:func:`audiotools.core.display.DisplayMixin.wavespec`
+
+    Parameters
+    ----------
+    fig_size : tuple, optional
+        Size of figure, by default (9, 3)
+    title : str, optional
+        Title to inset in top right, by default None
+    fig : matplotlib.figure.Figure, optional
+        Figure object, if None ``plt.gcf()`` will be used, by default None
+    format_axes : bool, optional
+        Format the axes to be inside the figure, by default True
+    format : bool, optional
+        This formatting can be skipped entirely by passing ``format=False``
+        to any of the plotting functions that use this formatter, by default True
+    font_color : str, optional
+        Color of font of axes, by default "white"
+    """
+    import matplotlib
+    import matplotlib.pyplot as plt
+
+    if fig_size is None:
+        fig_size = DEFAULT_FIG_SIZE
+    if not format:
+        return
+    if fig is None:
+        fig = plt.gcf()
+    fig.set_size_inches(*fig_size)
+    axs = fig.axes
+
+    pixels = (fig.get_size_inches() * fig.dpi)[0]
+    font_scale = pixels / BASE_SIZE
+
+    if format_axes:
+        axs = fig.axes
+
+        for ax in axs:
+            ymin, _ = ax.get_ylim()
+            xmin, _ = ax.get_xlim()
+
+            ticks = ax.get_yticks()
+            for t in ticks[2:-1]:
+                t = axs[0].annotate(
+                    f"{(t / 1000):2.1f}k",
+                    xy=(xmin, t),
+                    xycoords="data",
+                    xytext=(5, -5),
+                    textcoords="offset points",
+                    ha="left",
+                    va="top",
+                    color=font_color,
+                    fontsize=12 * font_scale,
+                    alpha=0.75,
+                )
+
+            ticks = ax.get_xticks()[2:]
+            for t in ticks[:-1]:
+                t = axs[0].annotate(
+                    f"{t:2.1f}s",
+                    xy=(t, ymin),
+                    xycoords="data",
+                    xytext=(5, 5),
+                    textcoords="offset points",
+                    ha="center",
+                    va="bottom",
+                    color=font_color,
+                    fontsize=12 * font_scale,
+                    alpha=0.75,
+                )
+
+            ax.margins(0, 0)
+            ax.set_axis_off()
+            ax.xaxis.set_major_locator(plt.NullLocator())
+            ax.yaxis.set_major_locator(plt.NullLocator())
+
+        plt.subplots_adjust(
+            top=1, bottom=0, right=1, left=0, hspace=0, wspace=0
+        )
+
+    if title is not None:
+        t = axs[0].annotate(
+            title,
+            xy=(1, 1),
+            xycoords="axes fraction",
+            fontsize=20 * font_scale,
+            xytext=(-5, -5),
+            textcoords="offset points",
+            ha="right",
+            va="top",
+            color="white",
+        )
+        t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
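+
+
+# Example: collating dataloader items, optionally into sub-batches for
+# gradient accumulation (a minimal sketch with a hypothetical "idx" key):
+#
+#     >>> items = [{"idx": paddle.to_tensor([i])} for i in range(8)]
+#     >>> full = collate(items)                # one dict, values stacked
+#     >>> halves = collate(items, n_splits=2)  # list of two 4-item batches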
+
+
+def generate_chord_dataset(
+    max_voices: int = 8,
+    sample_rate: int = 44100,
+    num_items: int = 5,
+    duration: float = 1.0,
+    min_note: str = "C2",
+    max_note: str = "C6",
+    output_dir: Path = "chords",
+):
+    """
+    Generates a toy multitrack dataset of chords, synthesized from
+    sine waves.
+
+    Parameters
+    ----------
+    max_voices : int, optional
+        Maximum number of voices in a chord, by default 8
+    sample_rate : int, optional
+        Sample rate of audio, by default 44100
+    num_items : int, optional
+        Number of items to generate, by default 5
+    duration : float, optional
+        Duration of each item, by default 1.0
+    min_note : str, optional
+        Minimum note in the dataset, by default "C2"
+    max_note : str, optional
+        Maximum note in the dataset, by default "C6"
+    output_dir : Path, optional
+        Directory to save the dataset, by default "chords"
+    """
+    import librosa
+
+    from audio_signal import AudioSignal
+
+    # NOTE: ``create_csv`` comes from audiotools' ``data.preprocess`` module,
+    # which is not part of this patch; this import will only resolve once
+    # that module is ported as well.
+    from data.preprocess import create_csv
+
+    min_midi = librosa.note_to_midi(min_note)
+    max_midi = librosa.note_to_midi(max_note)
+
+    tracks = []
+    for idx in range(num_items):
+        track = {}
+        # Figure out how many voices to put in this track.
+        num_voices = random.randint(1, max_voices)
+        for voice_idx in range(num_voices):
+            # Choose some random parameters for this voice.
+            midinote = random.randint(min_midi, max_midi)
+            dur = random.uniform(0.85 * duration, duration)
+
+            sig = AudioSignal.wave(
+                frequency=librosa.midi_to_hz(midinote),
+                duration=dur,
+                sample_rate=sample_rate,
+                shape="sine",
+            )
+            track[f"voice_{voice_idx}"] = sig
+        tracks.append(track)
+
+    # Save the tracks to disk.
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True)
+    for idx, track in enumerate(tracks):
+        track_dir = output_dir / f"track_{idx}"
+        track_dir.mkdir(exist_ok=True)
+        for voice_name, sig in track.items():
+            sig.write(track_dir / f"{voice_name}.wav")
+
+    all_voices = list(set([k for track in tracks for k in track.keys()]))
+    voice_lists = {voice: [] for voice in all_voices}
+    for track in tracks:
+        for voice_name in all_voices:
+            if voice_name in track:
+                voice_lists[voice_name].append(track[voice_name].path_to_file)
+            else:
+                voice_lists[voice_name].append("")
+
+    for voice_name, paths in voice_lists.items():
+        create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
+
+    return output_dir
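+
+
+# Example usage (a sketch; requires librosa plus the upstream
+# ``data.preprocess.create_csv`` helper, which this patch does not include):
+#
+#     >>> out = generate_chord_dataset(max_voices=4, num_items=2)
+#     >>> sources = read_sources([str(out / "voice_0.csv")])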