diff --git a/third_party/nnAudio/nnAudio/Spectrogram.py b/third_party/nnAudio/nnAudio/Spectrogram.py
new file mode 100755
index 00000000..c92046ee
--- /dev/null
+++ b/third_party/nnAudio/nnAudio/Spectrogram.py
@@ -0,0 +1,2440 @@
+"""
+Module containing all the spectrogram classes
+"""
+
+# 0.2.0
+
+import torch
+import torch.nn as nn
+from torch.nn.functional import conv1d, conv2d, fold
+import scipy  # used only in CFP
+
+import numpy as np
+from time import time
+
+# from nnAudio.librosa_functions import *  # For debug purposes
+# from nnAudio.utils import *
+
+from .librosa_functions import *
+from .utils import *
+
+sz_float = 4  # size of a float
+epsilon = 10e-8  # fudge factor for normalization
+
+### --------------------------- Spectrogram Classes ---------------------------###
+class STFT(torch.nn.Module):
+    """This function calculates the short-time Fourier transform (STFT) of the input signal.
+    Input signal should be in either of the following shapes.\n
+    1. ``(len_audio)``\n
+    2. ``(num_audio, len_audio)``\n
+    3. ``(num_audio, 1, len_audio)``
+
+    The correct shape will be inferred automatically if the input follows one of these 3 shapes.
+    Most of the arguments follow the convention from librosa.
+    This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
+
+    Parameters
+    ----------
+    n_fft : int
+        Size of the Fourier transform. Default value is 2048.
+
+    win_length : int
+        The size of the window frame and STFT filter.
+        Default: None (treated as equal to n_fft)
+
+    freq_bins : int
+        Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins.
+
+    hop_length : int
+        The hop (or stride) size. Default value is ``None``, which is equivalent to ``n_fft//4``.
+
+    window : str
+        The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
+        scipy documentation for possible windowing functions. The default value is 'hann'.
+
+    freq_scale : 'linear', 'log', or 'no'
+        Determine the spacing between each frequency bin. When 'linear' or 'log' is used,
+        the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bins will
+        start at 0Hz and end at the Nyquist frequency with linear spacing.
+
+    center : bool
+        Whether to put the STFT kernel at the center of the time-step or not. If ``False``, the time
+        index is the beginning of the STFT kernel; if ``True``, the time index is the center of
+        the STFT kernel. Default value is ``True``.
+
+    pad_mode : str
+        The padding method. Default value is 'reflect'.
+
+    iSTFT : bool
+        Whether to activate the iSTFT module or not. By default, it is False to save GPU memory.
+        Note: The iSTFT kernel is not trainable. If you want
+        a trainable iSTFT, use the iSTFT module.
+
+    fmin : int
+        The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument
+        does nothing.
+
+    fmax : int
+        The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument
+        does nothing.
+
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    trainable : bool
+        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for the STFT
+        kernels will also be calculated and the STFT kernels will be updated during model training.
+        Default value is ``False``.
+
+    output_format : str
+        Control the spectrogram output type, either ``Magnitude``, ``Complex``, or ``Phase``.
+        The output_format can also be changed during the ``forward`` method.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    spectrogram : torch.tensor
+        It returns a tensor of spectrograms.
+        ``shape = (num_samples, freq_bins, time_steps)`` if ``output_format='Magnitude'``;
+        ``shape = (num_samples, freq_bins, time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.STFT()
+    >>> specs = spec_layer(x)
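+
+    The spectrogram type can also be switched per call; a minimal sketch, assuming
+    ``x`` is a waveform tensor as above:
+
+    >>> complex_specs = spec_layer(x, output_format='Complex')  # (N, freq_bins, time_steps, 2)
+    >>> phase = spec_layer(x, output_format='Phase')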
+    """
+
+    def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',
+                 freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,
+                 fmin=50, fmax=6000, sr=22050, trainable=False,
+                 output_format="Complex", verbose=True):
+
+        super().__init__()
+
+        # Trying to make the default setting same as librosa
+        if win_length is None: win_length = n_fft
+        if hop_length is None: hop_length = int(win_length // 4)
+
+        self.output_format = output_format
+        self.trainable = trainable
+        self.stride = hop_length
+        self.center = center
+        self.pad_mode = pad_mode
+        self.n_fft = n_fft
+        self.freq_bins = freq_bins
+        self.pad_amount = self.n_fft // 2
+        self.window = window
+        self.win_length = win_length
+        self.iSTFT = iSTFT
+        start = time()
+
+        # Create filter windows for stft
+        kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,
+                                                                                                    win_length=win_length,
+                                                                                                    freq_bins=freq_bins,
+                                                                                                    window=window,
+                                                                                                    freq_scale=freq_scale,
+                                                                                                    fmin=fmin,
+                                                                                                    fmax=fmax,
+                                                                                                    sr=sr,
+                                                                                                    verbose=verbose)
+
+        kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)
+        kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)
+
+        # In this way, the inverse kernel and the forward kernel do not share the same memory...
+        kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)
+        kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)
+
+        if iSTFT:
+            self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))
+            self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))
+
+        # Making all these variables nn.Parameter, so that the model can be used with nn.Parallel
+        # self.kernel_sin = torch.nn.Parameter(self.kernel_sin, requires_grad=self.trainable)
+        # self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
+
+        # Applying window functions to the Fourier kernels
+        window_mask = torch.tensor(window_mask)
+        wsin = kernel_sin * window_mask
+        wcos = kernel_cos * window_mask
+
+        if self.trainable:
+            wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)
+            wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable)
+            self.register_parameter('wsin', wsin)
+            self.register_parameter('wcos', wcos)
+        else:
+            self.register_buffer('wsin', wsin)
+            self.register_buffer('wcos', wcos)
+
+        # Prepare the shape of window mask so that it can be used later in inverse
+        self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
+
+        if verbose:
+            print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
+
+    def forward(self, x, output_format=None):
+        """
+        Convert a batch of waveforms to spectrograms.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape.
+
+        output_format : str
+            Control the type of spectrogram to be returned. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.
+            Default value is ``Complex``.
+        """
+        output_format = output_format or self.output_format
+        self.num_samples = x.shape[-1]
+
+        x = broadcast_dim(x)
+        if self.center:
+            if self.pad_mode == 'constant':
+                padding = nn.ConstantPad1d(self.pad_amount, 0)
+
+            elif self.pad_mode == 'reflect':
+                if self.num_samples < self.pad_amount:
+                    raise AssertionError("Signal length shorter than reflect padding length (n_fft // 2).")
+                padding = nn.ReflectionPad1d(self.pad_amount)
+
+            x = padding(x)
+        spec_imag = conv1d(x, self.wsin, stride=self.stride)
+        spec_real = conv1d(x, self.wcos, stride=self.stride)  # Doing STFT by using conv1d
+
+        # remove redundant parts
+        spec_real = spec_real[:, :self.freq_bins, :]
+        spec_imag = spec_imag[:, :self.freq_bins, :]
+
+        if output_format == 'Magnitude':
+            spec = spec_real.pow(2) + spec_imag.pow(2)
+            if self.trainable:
+                return torch.sqrt(spec + 1e-8)  # prevent NaN gradient from sqrt(0) when the output is 0
+            else:
+                return torch.sqrt(spec)
+
+        elif output_format == 'Complex':
+            return torch.stack((spec_real, -spec_imag), -1)  # Remember the minus sign for the imaginary part
+
+        elif output_format == 'Phase':
+            return torch.atan2(-spec_imag + 0.0, spec_real)  # +0.0 removes -0.0 elements, which lead to errors in calculating phase
+
+    def inverse(self, X, onesided=True, length=None, refresh_win=True):
+        """
+        This function is the same as the :func:`~nnAudio.Spectrogram.iSTFT` class,
+        which is to convert spectrograms back to waveforms.
+        It only works for complex-valued spectrograms. If you have magnitude spectrograms,
+        please use :func:`~nnAudio.Spectrogram.Griffin_Lim`.
+
+        Parameters
+        ----------
+        onesided : bool
+            If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,
+            else use ``onesided=False``.
+
+        length : int
+            To make sure the inverse STFT has the same output length as the original waveform, please
+            set `length` as your intended waveform length. By default, ``length=None``,
+            which will remove ``n_fft//2`` samples from both the start and the end of the output.
+
+        refresh_win : bool
+            Recalculate the window sum square. If you have inputs with a fixed number of timesteps,
+            you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``.
+        """
+        if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):
+            raise NameError("Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`")
+
+        assert X.dim() == 4, "Inverse iSTFT only works for complex numbers," \
+                             "make sure your tensor is in the shape of (batch, freq_bins, timesteps, 2)." \
+                             "\nIf you have a magnitude spectrogram, please consider using Griffin-Lim."
+        if onesided:
+            X = extend_fbins(X)  # extend freq
+
+        X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]
+
+        # broadcast dimensions to support 2D convolution
+        X_real_bc = X_real.unsqueeze(1)
+        X_imag_bc = X_imag.unsqueeze(1)
+        a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1, 1))
+        b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1, 1))
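+        # For a onesided input extended back to full conjugate symmetry above,
+        # the inverse DFT of each frame is real-valued, so only the real part
+        # (a1 - b2) needs to be kept as the reconstructed signal.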
+        # compute the real and imaginary parts; the signal lies in the real part
+        real = a1 - b2
+        real = real.squeeze(-2) * self.window_mask
+
+        # Normalize the amplitude with n_fft
+        real /= self.n_fft
+
+        # Overlap and Add algorithm to connect all the frames
+        real = overlap_add(real, self.stride)
+
+        # Prepare the window sumsquare for division
+        # Only need to create this window once to save time
+        # Unless the input spectrograms have different time steps
+        if not hasattr(self, 'w_sum') or refresh_win == True:
+            self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()
+            self.nonzero_indices = (self.w_sum > 1e-10)
+        real[:, self.nonzero_indices] = real[:, self.nonzero_indices].div(self.w_sum[self.nonzero_indices])
+        # Remove padding
+        if length is None:
+            if self.center:
+                real = real[:, self.pad_amount:-self.pad_amount]
+
+        else:
+            if self.center:
+                real = real[:, self.pad_amount:self.pad_amount + length]
+            else:
+                real = real[:, :length]
+
+        return real
+
+    def extra_repr(self) -> str:
+        return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(
+            self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable
+        )
+
+
+class MelSpectrogram(torch.nn.Module):
+    """This function calculates the Mel spectrogram of the input signal.
+    Input signal should be in either of the following shapes.\n
+    1. ``(len_audio)``\n
+    2. ``(num_audio, len_audio)``\n
+    3. ``(num_audio, 1, len_audio)``
+
+    The correct shape will be inferred automatically if the input follows one of these 3 shapes.
+    Most of the arguments follow the convention from librosa.
+    This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
+
+    Parameters
+    ----------
+    sr : int
+        The sampling rate for the input audio.
+        It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    n_fft : int
+        The window size for the STFT. Default value is 2048.
+
+    win_length : int
+        The size of the window frame and STFT filter.
+        Default: None (treated as equal to n_fft)
+
+    n_mels : int
+        The number of Mel filter banks. The filter banks map the n_fft to Mel bins.
+        Default value is 128.
+
+    hop_length : int
+        The hop (or stride) size. Default value is 512.
+
+    window : str
+        The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
+        scipy documentation for possible windowing functions. The default value is 'hann'.
+
+    center : bool
+        Whether to put the STFT kernel at the center of the time-step or not. If ``False``,
+        the time index is the beginning of the STFT kernel; if ``True``, the time index is the
+        center of the STFT kernel. Default value is ``True``.
+
+    pad_mode : str
+        The padding method. Default value is 'reflect'.
+
+    htk : bool
+        When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the
+        Mel scale is logarithmic. The default value is ``False``.
+
+    fmin : int
+        The starting frequency for the lowest Mel filter bank.
+
+    fmax : int
+        The ending frequency for the highest Mel filter bank.
+
+    norm :
+        If 1, divide the triangular Mel weights by the width of the Mel band
+        (area normalization, AKA 'slaney' default in librosa).
+        Otherwise, leave all the triangles aiming for
+        a peak value of 1.0.
+
+    trainable_mel : bool
+        Determine if the Mel filter banks are trainable or not. If ``True``, the gradients for the Mel
+        filter banks will also be calculated and the Mel filter banks will be updated during model
+        training.
+        Default value is ``False``.
+
+    trainable_STFT : bool
+        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for the STFT
+        kernels will also be calculated and the STFT kernels will be updated during model training.
+        Default value is ``False``.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    spectrogram : torch.tensor
+        It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins, time_steps)``.
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.MelSpectrogram()
+    >>> specs = spec_layer(x)
+    """
+
+    def __init__(self, sr=22050, n_fft=2048, win_length=None, n_mels=128, hop_length=512,
+                 window='hann', center=True, pad_mode='reflect', power=2.0, htk=False,
+                 fmin=0.0, fmax=None, norm=1, trainable_mel=False, trainable_STFT=False,
+                 verbose=True, **kwargs):
+
+        super().__init__()
+        self.stride = hop_length
+        self.center = center
+        self.pad_mode = pad_mode
+        self.n_fft = n_fft
+        self.power = power
+        self.trainable_mel = trainable_mel
+        self.trainable_STFT = trainable_STFT
+
+        # Preparing for the stft layer. No need for center
+        self.stft = STFT(n_fft=n_fft, win_length=win_length, freq_bins=None,
+                         hop_length=hop_length, window=window, freq_scale='no',
+                         center=center, pad_mode=pad_mode, sr=sr, trainable=trainable_STFT,
+                         output_format="Magnitude", verbose=verbose, **kwargs)
+
+        # Creating kernel for mel spectrogram
+        start = time()
+        mel_basis = mel(sr, n_fft, n_mels, fmin, fmax, htk=htk, norm=norm)
+        mel_basis = torch.tensor(mel_basis)
+
+        if verbose:
+            print("STFT filter created, time used = {:.4f} seconds".format(time()-start))
+            print("Mel filter created, time used = {:.4f} seconds".format(time()-start))
+
+        if trainable_mel:
+            # Making everything nn.Parameter, so that this model can support nn.DataParallel
+            mel_basis = torch.nn.Parameter(mel_basis, requires_grad=trainable_mel)
+            self.register_parameter('mel_basis', mel_basis)
+        else:
+            self.register_buffer('mel_basis', mel_basis)
+
+        # if trainable_mel==True:
+        #     self.mel_basis = torch.nn.Parameter(self.mel_basis)
+        # if trainable_STFT==True:
+        #     self.wsin = torch.nn.Parameter(self.wsin)
+        #     self.wcos = torch.nn.Parameter(self.wcos)
+
+    def forward(self, x):
+        """
+        Convert a batch of waveforms to Mel spectrograms.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape.
+        """
+        x = broadcast_dim(x)
+
+        spec = self.stft(x, output_format='Magnitude') ** self.power
+
+        melspec = torch.matmul(self.mel_basis, spec)
+        return melspec
+
+    def extra_repr(self) -> str:
+        return 'Mel filter banks size = {}, trainable_mel={}, trainable_STFT={}'.format(
+            (*self.mel_basis.shape,), self.trainable_mel, self.trainable_STFT
+        )
+
+
+class MFCC(torch.nn.Module):
+    """This function calculates the Mel-frequency cepstral coefficients (MFCCs) of the input signal.
+    This algorithm first extracts Mel spectrograms from the audio clips,
+    then the discrete cosine transform is calculated to obtain the final MFCCs.
+    Therefore, the Mel spectrogram part can be made trainable using
+    ``trainable_mel`` and ``trainable_STFT``.
+    It only supports type-II DCT at the moment. Input signal should be in either of the following shapes.\n
+    1. ``(len_audio)``\n
+    2. ``(num_audio, len_audio)``\n
+    3. ``(num_audio, 1, len_audio)``
+
+    The correct shape will be inferred automatically if the input follows one of these 3 shapes.
+    Most of the arguments follow the convention from librosa.
+    This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
+
+    Parameters
+    ----------
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    n_mfcc : int
+        The number of Mel-frequency cepstral coefficients.
+
+    norm : string
+        The default value is 'ortho'. Normalization for the DCT basis.
+
+    **kwargs
+        Other arguments for Melspectrogram such as n_fft, n_mels, hop_length, and window.
+
+    Returns
+    -------
+    MFCCs : torch.tensor
+        It returns a tensor of MFCCs. shape = ``(num_samples, n_mfcc, time_steps)``.
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.MFCC()
+    >>> mfcc = spec_layer(x)
+    """
+
+    def __init__(self, sr=22050, n_mfcc=20, norm='ortho', verbose=True, ref=1.0, amin=1e-10, top_db=80.0, **kwargs):
+        super().__init__()
+        self.melspec_layer = MelSpectrogram(sr=sr, verbose=verbose, **kwargs)
+        self.m_mfcc = n_mfcc
+
+        # attributes that will be used for _power_to_db
+        if amin <= 0:
+            raise ParameterError('amin must be strictly positive')
+        amin = torch.tensor([amin])
+        ref = torch.abs(torch.tensor([ref]))
+        self.register_buffer('amin', amin)
+        self.register_buffer('ref', ref)
+        self.top_db = top_db
+        self.n_mfcc = n_mfcc
+
+    def _power_to_db(self, S):
+        '''
+        Refer to https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#power_to_db
+        for the original implementation.
+        '''
+        log_spec = 10.0 * torch.log10(torch.max(S, self.amin))
+        log_spec -= 10.0 * torch.log10(torch.max(self.amin, self.ref))
+        if self.top_db is not None:
+            if self.top_db < 0:
+                raise ParameterError('top_db must be non-negative')
+
+            # make the dim same as log_spec so that it can be broadcast
+            batch_wise_max = log_spec.flatten(1).max(1)[0].unsqueeze(1).unsqueeze(1)
+            log_spec = torch.max(log_spec, batch_wise_max - self.top_db)
+
+        return log_spec
+
+    def _dct(self, x, norm=None):
+        '''
+        Refer to https://github.com/zh217/torch-dct for the original implementation.
+        '''
+        x = x.permute(0, 2, 1)  # make freq the last axis, since dct applies to the frequency axis
+        x_shape = x.shape
+        N = x_shape[-1]
+
+        v = torch.cat([x[:, :, ::2], x[:, :, 1::2].flip([2])], dim=2)
+        Vc = torch.rfft(v, 1, onesided=False)
+
+        # TODO: Can make the W_r and W_i trainable here
+        k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
+        W_r = torch.cos(k)
+        W_i = torch.sin(k)
+
+        V = Vc[:, :, :, 0] * W_r - Vc[:, :, :, 1] * W_i
+
+        if norm == 'ortho':
+            V[:, :, 0] /= np.sqrt(N) * 2
+            V[:, :, 1:] /= np.sqrt(N / 2) * 2
+
+        V = 2 * V
+
+        return V.permute(0, 2, 1)  # swapping back the time axis and freq axis
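+
+    # Note on `_dct`: it evaluates the type-II DCT via a full-length FFT of the
+    # even/odd re-ordered input (Makhoul's method), which is why a single
+    # torch.rfft call suffices; with norm='ortho' the result should match
+    # scipy.fft.dct(x, type=2, norm='ortho') up to numerical precision.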
+
+    def forward(self, x):
+        """
+        Convert a batch of waveforms to MFCCs.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape.
+        """
+        x = self.melspec_layer(x)
+        x = self._power_to_db(x)
+        x = self._dct(x, norm='ortho')[:, :self.m_mfcc, :]
+        return x
+
+    def extra_repr(self) -> str:
+        return 'n_mfcc = {}'.format(self.n_mfcc)
+
+
+class Gammatonegram(torch.nn.Module):
+    """
+    This function calculates the Gammatonegram of the input signal.
+    Input signal should be in either of the following shapes.\n
+    1. ``(len_audio)``\n
+    2. ``(num_audio, len_audio)``\n
+    3. ``(num_audio, 1, len_audio)``
+
+    The correct shape will be inferred automatically if the input follows one of these 3 shapes.
+    This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
+
+    Parameters
+    ----------
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    n_fft : int
+        The window size for the STFT. Default value is 2048.
+
+    n_bins : int
+        The number of Gammatone filter banks. The filter banks map the n_fft to Gammatone bins.
+        Default value is 64.
+
+    hop_length : int
+        The hop (or stride) size. Default value is 512.
+
+    window : str
+        The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
+        scipy documentation for possible windowing functions. The default value is 'hann'.
+
+    center : bool
+        Whether to put the STFT kernel at the center of the time-step or not. If ``False``, the time
+        index is the beginning of the STFT kernel; if ``True``, the time index is the center of the
+        STFT kernel. Default value is ``True``.
+
+    pad_mode : str
+        The padding method. Default value is 'reflect'.
+
+    htk : bool
+        When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the Mel
+        scale is logarithmic. The default value is ``False``.
+
+    fmin : int
+        The starting frequency for the lowest Gammatone filter bank.
+
+    fmax : int
+        The ending frequency for the highest Gammatone filter bank.
+
+    trainable_bins : bool
+        Determine if the Gammatone filter banks are trainable or not. If ``True``, the gradients for the
+        Gammatone filter banks will also be calculated and the filter banks will be updated during model
+        training. Default value is ``False``.
+
+    trainable_STFT : bool
+        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for the STFT
+        kernels will also be calculated and the STFT kernels will be updated during model training.
+        Default value is ``False``.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    spectrogram : torch.tensor
+        It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins, time_steps)``.
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.Gammatonegram()
+    >>> specs = spec_layer(x)
+    """
+
+    def __init__(self, sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True, pad_mode='reflect',
+                 power=2.0, htk=False, fmin=20.0, fmax=None, norm=1, trainable_bins=False, trainable_STFT=False,
+                 verbose=True):
+        super(Gammatonegram, self).__init__()
+        self.stride = hop_length
+        self.center = center
+        self.pad_mode = pad_mode
+        self.n_fft = n_fft
+        self.power = power
+
+        # Create filter windows for stft
+        start = time()
+        wsin, wcos, self.bins2freq, _, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no',
+                                                                  sr=sr)
+
+        wsin = torch.tensor(wsin, dtype=torch.float)
+        wcos = torch.tensor(wcos, dtype=torch.float)
+
+        if trainable_STFT:
+            wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
+            wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
+            self.register_parameter('wsin', wsin)
+            self.register_parameter('wcos', wcos)
+        else:
+            self.register_buffer('wsin', wsin)
+            self.register_buffer('wcos', wcos)
+
+        # Creating kernel for Gammatone spectrogram
+        start = time()
+        gammatone_basis = gammatone(sr, n_fft, n_bins, fmin, fmax)
+        gammatone_basis = torch.tensor(gammatone_basis)
+
+        if verbose:
+            print("STFT filter created, time used = {:.4f} seconds".format(time() - start))
+            print("Gammatone filter created, time used = {:.4f} seconds".format(time() - start))
+
+        # Making everything nn.Parameter, so that this model can support nn.DataParallel
+        if trainable_bins:
+            gammatone_basis = torch.nn.Parameter(gammatone_basis, requires_grad=trainable_bins)
+            self.register_parameter('gammatone_basis', gammatone_basis)
+        else:
+            self.register_buffer('gammatone_basis', gammatone_basis)
+
+        # if trainable_mel==True:
+        #     self.mel_basis = torch.nn.Parameter(self.mel_basis)
+        # if trainable_STFT==True:
+        #     self.wsin = torch.nn.Parameter(self.wsin)
+        #     self.wcos = torch.nn.Parameter(self.wcos)
+
+    def forward(self, x):
+        x = broadcast_dim(x)
+        if self.center:
+            if self.pad_mode == 'constant':
+                padding = nn.ConstantPad1d(self.n_fft // 2, 0)
+            elif self.pad_mode == 'reflect':
+                padding = nn.ReflectionPad1d(self.n_fft // 2)
+
+            x = padding(x)
+
+        spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2) \
+                          + conv1d(x, self.wcos, stride=self.stride).pow(2)) ** self.power  # Doing STFT by using conv1d
+
+        gammatonespec = torch.matmul(self.gammatone_basis, spec)
+        return gammatonespec
+
+
+class CQT1992(torch.nn.Module):
+    """
+    This algorithm uses the method proposed in [1], which runs extremely slowly if low frequencies (below 220Hz)
+    are included in the frequency bins.
+    Please refer to :func:`~nnAudio.Spectrogram.CQT1992v2` for a more
+    computation- and memory-efficient version.
+    [1] Brown, Judith C. and Miller Puckette. “An efficient algorithm for the calculation of a
+    constant Q transform.” (1992).
+
+    This function calculates the CQT of the input signal.
+    Input signal should be in either of the following shapes.\n
+    1. ``(len_audio)``\n
+    2. ``(num_audio, len_audio)``\n
+    3. ``(num_audio, 1, len_audio)``
+
+    The correct shape will be inferred automatically if the input follows one of these 3 shapes.
+    Most of the arguments follow the convention from librosa.
+    This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
+
+    Parameters
+    ----------
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    hop_length : int
+        The hop (or stride) size. Default value is 512.
+
+    fmin : float
+        The frequency of the lowest CQT bin. Default is 220Hz for this class, since this
+        algorithm becomes extremely slow when lower frequencies are included (see above).
+
+    fmax : float
+        The frequency of the highest CQT bin. Default is ``None``, therefore the highest CQT bin is
+        inferred from the ``n_bins`` and ``bins_per_octave``.
+        If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored and ``n_bins``
+        will be calculated automatically. Default is ``None``.
+
+    n_bins : int
+        The total number of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
+
+    bins_per_octave : int
+        Number of bins per octave. Default is 12.
+
+    trainable_STFT : bool
+        Determine if the time-to-frequency-domain transformation kernel for the input audio is trainable or not.
+        Default is ``False``.
+
+    trainable_CQT : bool
+        Determine if the frequency-domain CQT kernel is trainable or not.
+        Default is ``False``.
+
+    norm : int
+        Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
+        Default is ``1``, which is the same as the normalization used in librosa.
+
+    window : str
+        The windowing function for CQT. It uses ``scipy.signal.get_window``, please refer to
+        scipy documentation for possible windowing functions. The default value is 'hann'.
+
+    center : bool
+        Whether to put the CQT kernel at the center of the time-step or not. If ``False``, the time index is
+        the beginning of the CQT kernel; if ``True``, the time index is the center of the CQT kernel.
+        Default value is ``True``.
+
+    pad_mode : str
+        The padding method. Default value is 'reflect'.
+
+    output_format : str
+        Determine the return type.
+        ``Magnitude`` will return the magnitude of the CQT result, shape = ``(num_samples, freq_bins, time_steps)``;
+        ``Complex`` will return the CQT result as complex numbers, shape = ``(num_samples, freq_bins, time_steps, 2)``;
+        ``Phase`` will return the phase of the CQT result, shape = ``(num_samples, freq_bins, time_steps, 2)``.
+        The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    spectrogram : torch.tensor
+        It returns a tensor of spectrograms.
+        shape = ``(num_samples, freq_bins, time_steps)`` if ``output_format='Magnitude'``;
+        shape = ``(num_samples, freq_bins, time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.CQT1992()
+    >>> specs = spec_layer(x)
+    """
+
+    def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84,
+                 trainable_STFT=False, trainable_CQT=False, bins_per_octave=12, filter_scale=1,
+                 output_format='Magnitude', norm=1, window='hann', center=True, pad_mode='reflect'):
+
+        super().__init__()
+
+        # norm arg is not functioning
+        self.hop_length = hop_length
+        self.center = center
+        self.pad_mode = pad_mode
+        self.norm = norm
+        self.output_format = output_format
+
+        # creating kernels for CQT
+        Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
+
+        print("Creating CQT kernels ...", end='\r')
+        start = time()
+        cqt_kernels, self.kernel_width, lenghts, freqs = create_cqt_kernels(Q,
+                                                                            sr,
+                                                                            fmin,
+                                                                            n_bins,
+                                                                            bins_per_octave,
+                                                                            norm,
+                                                                            window,
+                                                                            fmax)
+
+        self.register_buffer('lenghts', lenghts)
+        self.frequencies = freqs
+
+        cqt_kernels = fft(cqt_kernels)[:, :self.kernel_width // 2 + 1]
+        print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
+
+        # creating kernels for stft
+        # self.cqt_kernels_real *= lenghts.unsqueeze(1) / self.kernel_width  # Trying to normalize as librosa
+        # self.cqt_kernels_imag *= lenghts.unsqueeze(1) / self.kernel_width
+
+        print("Creating STFT kernels ...", end='\r')
+        start = time()
+        kernel_sin, kernel_cos, self.bins2freq, _, window = create_fourier_kernels(self.kernel_width,
+                                                                                   window='ones',
+                                                                                   freq_scale='no')
+
+        # Converting kernels from numpy arrays to torch tensors
+        wsin = torch.tensor(kernel_sin * window)
+        wcos = torch.tensor(kernel_cos * window)
+
+        cqt_kernels_real = torch.tensor(cqt_kernels.real.astype(np.float32))
+        cqt_kernels_imag = torch.tensor(cqt_kernels.imag.astype(np.float32))
+
+        if trainable_STFT:
+            wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
+            wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
+            self.register_parameter('wsin', wsin)
+            self.register_parameter('wcos', wcos)
+        else:
+            self.register_buffer('wsin', wsin)
+            self.register_buffer('wcos', wcos)
+
+        if trainable_CQT:
+            cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable_CQT)
+            cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable_CQT)
+            self.register_parameter('cqt_kernels_real', cqt_kernels_real)
+            self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
+        else:
+            self.register_buffer('cqt_kernels_real', cqt_kernels_real)
+            self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
+
+        print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
+
+    def forward(self, x, output_format=None, normalization_type='librosa'):
+        """
+        Convert a batch of waveforms to CQT spectrograms.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape.
+        """
+        output_format = output_format or self.output_format
+
+        x = broadcast_dim(x)
+        if self.center:
+            if self.pad_mode == 'constant':
+                padding = nn.ConstantPad1d(self.kernel_width // 2, 0)
+            elif self.pad_mode == 'reflect':
+                padding = nn.ReflectionPad1d(self.kernel_width // 2)
+
+            x = padding(x)
+
+        # STFT
+        fourier_real = conv1d(x, self.wcos, stride=self.hop_length)
+        fourier_imag = conv1d(x, self.wsin, stride=self.hop_length)
+
+        # CQT
+        CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag),
+                                         (fourier_real, fourier_imag))
+
+        CQT = torch.stack((CQT_real, -CQT_imag), -1)
+
+        if normalization_type == 'librosa':
+            CQT *= torch.sqrt(self.lenghts.view(-1, 1, 1)) / self.kernel_width
+        elif normalization_type == 'convolutional':
+            pass
+        elif normalization_type == 'wrap':
+            CQT *= 2 / self.kernel_width
+        else:
+            raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
+
+        # if self.norm:
+        #     CQT = CQT / self.kernel_width * torch.sqrt(self.lenghts.view(-1, 1, 1))
+        # else:
+        #     CQT = CQT * torch.sqrt(self.lenghts.view(-1, 1, 1))
+
+        if output_format == 'Magnitude':
+            # Getting CQT Amplitude
+            return torch.sqrt(CQT.pow(2).sum(-1))
+
+        elif output_format == 'Complex':
+            return CQT
+
+        elif output_format == 'Phase':
+            phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real))
+            phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real))
+            return torch.stack((phase_real, phase_imag), -1)
+
+    def extra_repr(self) -> str:
+        return 'STFT kernel size = {}, CQT kernel size = {}'.format(
+            (*self.wcos.shape,), (*self.cqt_kernels_real.shape,)
+        )
+
+
+class CQT2010(torch.nn.Module):
+    """
+    This algorithm uses the resampling method proposed in [1].
+    Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency
+    spectrum, we make a small CQT kernel covering only the top octave.
+    Then we keep downsampling the input audio by a factor of 2 and convolving it with the
+    small CQT kernel. Every time the input audio is downsampled, the CQT relative to the downsampled
+    input is equivalent to the next lower octave.
+    The kernel creation process is still the same as in the 1992 algorithm. Therefore, we can reuse the code
+    from the 1992 algorithm [2].
+    [1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
+    [2] Brown, Judith C. and Miller Puckette. “An efficient algorithm for the calculation of a
+    constant Q transform.” (1992).
+
+    The early downsampling factor is used to downsample the input audio in order to reduce the CQT kernel size.
+    The results with and without early downsampling are more or less the same, except in the very low
+    frequency region where freq < 40Hz.
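+
+    Usage mirrors the other CQT classes:
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.CQT2010()
+    >>> specs = spec_layer(x)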
+ """ + + def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, bins_per_octave=12, + norm=True, basis_norm=1, window='hann', pad_mode='reflect', trainable_STFT=False, filter_scale=1, + trainable_CQT=False, output_format='Magnitude', earlydownsample=True, verbose=True): + + super().__init__() + + self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft + # basis_norm is for normalizing basis + self.hop_length = hop_length + self.pad_mode = pad_mode + self.n_bins = n_bins + self.output_format = output_format + self.earlydownsample = earlydownsample # TODO: activate early downsampling later if possible + + # This will be used to calculate filter_cutoff and creating CQT kernels + Q = float(filter_scale)/(2**(1/bins_per_octave)-1) + + # Creating lowpass filter and make it a torch tensor + if verbose==True: + print("Creating low pass filter ...", end='\r') + start = time() + lowpass_filter = torch.tensor(create_lowpass_filter( + band_center = 0.5, + kernelLength=256, + transitionBandwidth=0.001 + ) + ) + + # Broadcast the tensor to the shape that fits conv1d + self.register_buffer('lowpass_filter', lowpass_filter[None,None,:]) + + if verbose==True: + print("Low pass filter created, time used = {:.4f} seconds".format(time()-start)) + + # Calculate num of filter requires for the kernel + # n_octaves determines how many resampling requires for the CQT + n_filters = min(bins_per_octave, n_bins) + self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave)) + # print("n_octaves = ", self.n_octaves) + + # Calculate the lowest frequency bin for the top octave kernel + self.fmin_t = fmin*2**(self.n_octaves-1) + remainder = n_bins % bins_per_octave + # print("remainder = ", remainder) + + if remainder==0: + # Calculate the top bin frequency + fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave) + else: + # Calculate the top bin frequency + fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave) + + self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the top minium bins + if fmax_t > sr/2: + raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \ + please reduce the n_bins'.format(fmax_t)) + + if self.earlydownsample == True: # Do early downsampling if this argument is True + if verbose==True: + print("Creating early downsampling filter ...", end='\r') + start = time() + sr, self.hop_length, self.downsample_factor, early_downsample_filter, \ + self.earlydownsample = get_early_downsample_params(sr, + hop_length, + fmax_t, + Q, + self.n_octaves, + verbose) + + self.register_buffer('early_downsample_filter', early_downsample_filter) + if verbose==True: + print("Early downsampling filter created, \ + time used = {:.4f} seconds".format(time()-start)) + else: + self.downsample_factor=1. 
+        # Preparing CQT kernels
+        if verbose:
+            print("Creating CQT kernels ...", end='\r')
+
+        start = time()
+        # print("Q = {}, fmin_t = {}, n_filters = {}".format(Q, self.fmin_t, n_filters))
+        basis, self.n_fft, _, _ = create_cqt_kernels(Q,
+                                                     sr,
+                                                     self.fmin_t,
+                                                     n_filters,
+                                                     bins_per_octave,
+                                                     norm=basis_norm,
+                                                     topbin_check=False)
+
+        # This is for the normalization in the end
+        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
+        self.frequencies = freqs
+
+        lenghts = np.ceil(Q * sr / freqs)
+        lenghts = torch.tensor(lenghts).float()
+        self.register_buffer('lenghts', lenghts)
+
+        self.basis = basis
+        fft_basis = fft(basis)[:, :self.n_fft // 2 + 1]  # Convert the CQT kernel from the time domain to the freq domain
+
+        # These cqt_kernels are already in the frequency domain
+        cqt_kernels_real = torch.tensor(fft_basis.real.astype(np.float32))
+        cqt_kernels_imag = torch.tensor(fft_basis.imag.astype(np.float32))
+
+        if verbose:
+            print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
+
+        # print("Getting cqt kernel done, n_fft = ", self.n_fft)
+        # Preparing kernels for the Short-Time Fourier Transform (STFT)
+        # We set the frequency range in the CQT filter instead of here.
+
+        if verbose:
+            print("Creating STFT kernels ...", end='\r')
+
+        start = time()
+        kernel_sin, kernel_cos, self.bins2freq, _, window = create_fourier_kernels(self.n_fft, window='ones', freq_scale='no')
+        wsin = kernel_sin * window
+        wcos = kernel_cos * window
+
+        wsin = torch.tensor(wsin)
+        wcos = torch.tensor(wcos)
+
+        if verbose:
+            print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
+
+        if trainable_STFT:
+            wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
+            wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
+            self.register_parameter('wsin', wsin)
+            self.register_parameter('wcos', wcos)
+        else:
+            self.register_buffer('wsin', wsin)
+            self.register_buffer('wcos', wcos)
+
+        if trainable_CQT:
+            cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable_CQT)
+            cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable_CQT)
+            self.register_parameter('cqt_kernels_real', cqt_kernels_real)
+            self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
+        else:
+            self.register_buffer('cqt_kernels_real', cqt_kernels_real)
+            self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
+
+        # If center==True, the STFT window will be put in the middle, and paddings at the beginning
+        # and ending are required.
+        if self.pad_mode == 'constant':
+            self.padding = nn.ConstantPad1d(self.n_fft // 2, 0)
+        elif self.pad_mode == 'reflect':
+            self.padding = nn.ReflectionPad1d(self.n_fft // 2)
+
+    def forward(self, x, output_format=None, normalization_type='librosa'):
+        """
+        Convert a batch of waveforms to CQT spectrograms.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape.
+        """
+        output_format = output_format or self.output_format
+
+        x = broadcast_dim(x)
+        if self.earlydownsample == True:
+            x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
+        hop = self.hop_length
+
+        CQT = get_cqt_complex2(x, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding,
+                               wcos=self.wcos, wsin=self.wsin)
+
+        x_down = x  # Preparing a new variable for downsampling
+        for i in range(self.n_octaves - 1):
+            hop = hop // 2
+            x_down = downsampling_by_2(x_down, self.lowpass_filter)
+
+            CQT1 = get_cqt_complex2(x_down, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding,
+                                    wcos=self.wcos, wsin=self.wsin)
+            CQT = torch.cat((CQT1, CQT), 1)
+
+        CQT = CQT[:, -self.n_bins:, :]  # Removing unwanted top bins
+
+        if normalization_type == 'librosa':
+            CQT *= torch.sqrt(self.lenghts.view(-1, 1, 1)) / self.n_fft
+        elif normalization_type == 'convolutional':
+            pass
+        elif normalization_type == 'wrap':
+            CQT *= 2 / self.n_fft
+        else:
+            raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
+
+        if output_format == 'Magnitude':
+            # Getting CQT Amplitude
+            return torch.sqrt(CQT.pow(2).sum(-1))
+
+        elif output_format == 'Complex':
+            return CQT
+
+        elif output_format == 'Phase':
+            phase_real = torch.cos(torch.atan2(CQT[:, :, :, 1], CQT[:, :, :, 0]))
+            phase_imag = torch.sin(torch.atan2(CQT[:, :, :, 1], CQT[:, :, :, 0]))
+            return torch.stack((phase_real, phase_imag), -1)
+
+    def extra_repr(self) -> str:
+        return 'STFT kernel size = {}, CQT kernel size = {}'.format(
+            (*self.wcos.shape,), (*self.cqt_kernels_real.shape,)
+        )
+
+
+class CQT1992v2(torch.nn.Module):
+    """This function calculates the CQT of the input signal.
+    Input signal should be in either of the following shapes.\n
+    1. ``(len_audio)``\n
+    2. ``(num_audio, len_audio)``\n
+    3. ``(num_audio, 1, len_audio)``
+
+    The correct shape will be inferred automatically if the input follows one of these 3 shapes.
+    Most of the arguments follow the convention from librosa.
+    This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
+
+    This algorithm uses the method proposed in [1]. I slightly modify it so that it runs faster
+    than the original 1992 algorithm, that is why I call it version 2.
+    [1] Brown, Judith C. and Miller Puckette. “An efficient algorithm for the calculation of a
+    constant Q transform.” (1992).
+
+    Parameters
+    ----------
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    hop_length : int
+        The hop (or stride) size. Default value is 512.
+
+    fmin : float
+        The frequency of the lowest CQT bin. Default is 32.70Hz, which corresponds to the note C1.
+
+    fmax : float
+        The frequency of the highest CQT bin. Default is ``None``, therefore the highest CQT bin is
+        inferred from the ``n_bins`` and ``bins_per_octave``.
+        If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored and ``n_bins``
+        will be calculated automatically. Default is ``None``.
+
+    n_bins : int
+        The total number of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
+
+    bins_per_octave : int
+        Number of bins per octave. Default is 12.
+
+    filter_scale : float > 0
+        Filter scale factor.
+        Values of filter_scale smaller than 1 can be used to improve the time resolution at the
+        cost of degrading the frequency resolution. It is important to note that setting, for example,
+        filter_scale = 0.5 and bins_per_octave = 48 leads to exactly the same time-frequency resolution
+        trade-off as setting filter_scale = 1 and bins_per_octave = 24, but the former contains twice as many
+        frequency bins per octave. In this sense, values of filter_scale < 1 can be seen as implementing
+        oversampling of the frequency axis, analogously to the use of zero padding when calculating the DFT.
+
+    norm : int
+        Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
+        Default is ``1``, which is the same as the normalization used in librosa.
+
+    window : string, float, or tuple
+        The windowing function for CQT. If it is a string, it uses ``scipy.signal.get_window``. If it is a
+        tuple, only the Gaussian window guarantees a constant Q factor. The Gaussian window should be given as a
+        tuple ('gaussian', att) where att is the attenuation at the border given in dB.
+        Please refer to scipy documentation for possible windowing functions. The default value is 'hann'.
+
+    center : bool
+        Whether to put the CQT kernel at the center of the time-step or not. If ``False``, the time index is
+        the beginning of the CQT kernel; if ``True``, the time index is the center of the CQT kernel.
+        Default value is ``True``.
+
+    pad_mode : str
+        The padding method. Default value is 'reflect'.
+
+    trainable : bool
+        Determine if the CQT kernels are trainable or not. If ``True``, the gradients for the CQT kernels
+        will also be calculated and the CQT kernels will be updated during model training.
+        Default value is ``False``.
+
+    output_format : str
+        Determine the return type.
+        ``Magnitude`` will return the magnitude of the CQT result, shape = ``(num_samples, freq_bins, time_steps)``;
+        ``Complex`` will return the CQT result as complex numbers, shape = ``(num_samples, freq_bins, time_steps, 2)``;
+        ``Phase`` will return the phase of the CQT result, shape = ``(num_samples, freq_bins, time_steps, 2)``.
+        The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    spectrogram : torch.tensor
+        It returns a tensor of spectrograms.
+        shape = ``(num_samples, freq_bins, time_steps)`` if ``output_format='Magnitude'``;
+        shape = ``(num_samples, freq_bins, time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.CQT1992v2()
+    >>> specs = spec_layer(x)
+    """
+
+    def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84,
+                 bins_per_octave=12, filter_scale=1, norm=1, window='hann', center=True, pad_mode='reflect',
+                 trainable=False, output_format='Magnitude', verbose=True):
+
+        super().__init__()
+
+        self.trainable = trainable
+        self.hop_length = hop_length
+        self.center = center
+        self.pad_mode = pad_mode
+        self.output_format = output_format
+
+        # creating kernels for CQT
+        Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
+
+        if verbose:
+            print("Creating CQT kernels ...", end='\r')
+
+        start = time()
+        cqt_kernels, self.kernel_width, lenghts, freqs = create_cqt_kernels(Q,
+                                                                            sr,
+                                                                            fmin,
+                                                                            n_bins,
+                                                                            bins_per_octave,
+                                                                            norm,
+                                                                            window,
+                                                                            fmax)
+
+        self.register_buffer('lenghts', lenghts)
+        self.frequencies = freqs
+
+        cqt_kernels_real = torch.tensor(cqt_kernels.real).unsqueeze(1)
+        cqt_kernels_imag = torch.tensor(cqt_kernels.imag).unsqueeze(1)
+
+        if trainable:
+            cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
+            cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable)
+            self.register_parameter('cqt_kernels_real', cqt_kernels_real)
+            self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
+        else:
+            self.register_buffer('cqt_kernels_real', cqt_kernels_real)
+            self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
+
+        if verbose:
+            print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
+
+    def forward(self, x, output_format=None, normalization_type='librosa'):
+        """
+        Convert a batch of waveforms to CQT spectrograms.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape.
+
+        normalization_type : str
+            Type of the normalisation. The possible options are: \n
+            'librosa' : the output fits the librosa one \n
+            'convolutional' : the output conserves the convolutional inequalities of the wavelet transform:\n
+            for all p ϵ [1, inf] \n
+                - || CQT ||_p <= || f ||_p || g ||_1 \n
+                - || CQT ||_p <= || f ||_1 || g ||_p \n
+                - || CQT ||_2 = || f ||_2 || g ||_2 \n
+            'wrap' : wraps positive and negative frequencies into positive frequencies. This means that the CQT of a
+            sine (or a cosine) with a constant amplitude equal to 1 will have the value 1 in the bin corresponding to
+            its frequency.
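+
+        A minimal sketch of switching conventions on the same layer, assuming
+        ``spec_layer = Spectrogram.CQT1992v2()`` and a waveform tensor ``x``:
+
+        >>> cqt_librosa = spec_layer(x, normalization_type='librosa')
+        >>> cqt_conv = spec_layer(x, normalization_type='convolutional')
+        >>> cqt_wrap = spec_layer(x, normalization_type='wrap')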
+ """ + output_format = output_format or self.output_format + + x = broadcast_dim(x) + if self.center: + if self.pad_mode == 'constant': + padding = nn.ConstantPad1d(self.kernel_width//2, 0) + elif self.pad_mode == 'reflect': + padding = nn.ReflectionPad1d(self.kernel_width//2) + + x = padding(x) + + # CQT + CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length) + CQT_imag = -conv1d(x, self.cqt_kernels_imag, stride=self.hop_length) + + if normalization_type == 'librosa': + CQT_real *= torch.sqrt(self.lenghts.view(-1, 1)) + CQT_imag *= torch.sqrt(self.lenghts.view(-1, 1)) + elif normalization_type == 'convolutional': + pass + elif normalization_type == 'wrap': + CQT_real *= 2 + CQT_imag *= 2 + else: + raise ValueError("The normalization_type %r is not part of our current options." % normalization_type) + + if output_format=='Magnitude': + if self.trainable==False: + # Getting CQT Amplitude + CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) + else: + CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)+1e-8) + return CQT + + elif output_format=='Complex': + return torch.stack((CQT_real,CQT_imag),-1) + + elif output_format=='Phase': + phase_real = torch.cos(torch.atan2(CQT_imag,CQT_real)) + phase_imag = torch.sin(torch.atan2(CQT_imag,CQT_real)) + return torch.stack((phase_real,phase_imag), -1) + + def forward_manual(self,x): + """ + Method for debugging + """ + + x = broadcast_dim(x) + if self.center: + if self.pad_mode == 'constant': + padding = nn.ConstantPad1d(self.kernel_width//2, 0) + elif self.pad_mode == 'reflect': + padding = nn.ReflectionPad1d(self.kernel_width//2) + + x = padding(x) + + # CQT + CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length) + CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=self.hop_length) + + # Getting CQT Amplitude + CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) + return CQT*torch.sqrt(self.lenghts.view(-1,1)) + + +class CQT2010v2(torch.nn.Module): + """This function is to calculate the CQT of the input signal. + Input signal should be in either of the following shapes.\n + 1. ``(len_audio)``\n + 2. ``(num_audio, len_audio)``\n + 3. ``(num_audio, 1, len_audio)`` + + The correct shape will be inferred autommatically if the input follows these 3 shapes. + Most of the arguments follow the convention from librosa. + This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``. + + This alogrithm uses the resampling method proposed in [1]. + Instead of convoluting the STFT results with a gigantic CQT kernel covering the full frequency + spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the + input audio by a factor of 2 to convoluting it with the small CQT kernel. + Everytime the input audio is downsampled, the CQT relative to the downsampled input is equivalent + to the next lower octave. + The kernel creation process is still same as the 1992 algorithm. Therefore, we can reuse the + code from the 1992 alogrithm [2] + [1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010). + [2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a + constant Q transform.” (1992). + + Early downsampling factor is to downsample the input audio to reduce the CQT kernel size. + The result with and without early downsampling are more or less the same except in the very low + frequency region where freq < 40Hz. + + Parameters + ---------- + sr : int + The sampling rate for the input audio. 
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    hop_length : int
+        The hop (or stride) size. Default value is 512.
+
+    fmin : float
+        The frequency of the lowest CQT bin. Default is 32.70Hz, which corresponds to the note C1.
+
+    fmax : float
+        The frequency of the highest CQT bin. Default is ``None``, therefore the highest CQT bin is
+        inferred from the ``n_bins`` and ``bins_per_octave``. If ``fmax`` is not ``None``, then the
+        argument ``n_bins`` will be ignored and ``n_bins`` will be calculated automatically.
+        Default is ``None``.
+
+    n_bins : int
+        The total number of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
+
+    bins_per_octave : int
+        Number of bins per octave. Default is 12.
+
+    norm : bool
+        Normalization for the CQT result.
+
+    basis_norm : int
+        Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
+        Default is ``1``, which is the same as the normalization used in librosa.
+
+    window : str
+        The windowing function for CQT. It uses ``scipy.signal.get_window``, please refer to
+        scipy documentation for possible windowing functions. The default value is 'hann'.
+
+    pad_mode : str
+        The padding method. Default value is 'reflect'.
+
+    trainable : bool
+        Determine if the CQT kernels are trainable or not. If ``True``, the gradients for the CQT kernels
+        will also be calculated and the CQT kernels will be updated during model training.
+        Default value is ``False``.
+
+    output_format : str
+        Determine the return type.
+        'Magnitude' will return the magnitude of the CQT result, shape = ``(num_samples, freq_bins, time_steps)``;
+        'Complex' will return the CQT result as complex numbers, shape = ``(num_samples, freq_bins, time_steps, 2)``;
+        'Phase' will return the phase of the CQT result, shape = ``(num_samples, freq_bins, time_steps, 2)``.
+        The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    spectrogram : torch.tensor
+        It returns a tensor of spectrograms.
+        shape = ``(num_samples, freq_bins, time_steps)`` if ``output_format='Magnitude'``;
+        shape = ``(num_samples, freq_bins, time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
+
+    Examples
+    --------
+    >>> spec_layer = Spectrogram.CQT2010v2()
+    >>> specs = spec_layer(x)
+    """
+
+# TODO:
+# need to deal with the filter and other tensors
+
+    def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, filter_scale=1,
+                 bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect',
+                 earlydownsample=True, trainable=False, output_format='Magnitude', verbose=True):
+
+        super().__init__()
+
+        self.norm = norm  # Now norm is used to normalize the final CQT result by dividing n_fft
+                          # basis_norm is for normalizing basis
+        self.hop_length = hop_length
+        self.pad_mode = pad_mode
+        self.n_bins = n_bins
+        self.earlydownsample = earlydownsample  # We will activate early downsampling later if possible
+        self.trainable = trainable
+        self.output_format = output_format
+
+        # It will be used to calculate filter_cutoff and creating CQT kernels
+        Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
+
+        # Creating lowpass filter and make it a torch tensor
+        if verbose:
+            print("Creating low pass filter ...", end='\r')
+        start = time()
+        # self.lowpass_filter = torch.tensor(
+        #     create_lowpass_filter(
+        #         band_center = 0.50,
+        #         kernelLength=256,
+        #         transitionBandwidth=0.001))
+        lowpass_filter = torch.tensor(create_lowpass_filter(
+            band_center=0.50,
+            kernelLength=256,
+            transitionBandwidth=0.001)
+        )
+
+        # Broadcast the tensor to the shape that fits conv1d
+        self.register_buffer('lowpass_filter', lowpass_filter[None, None, :])
+        if verbose:
+            print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
+
+        # Calculate the number of filters required for the kernel
+        # n_octaves determines how many resamplings are required for the CQT
+        n_filters = min(bins_per_octave, n_bins)
+        self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
+        if verbose:
+            print("num_octave = ", self.n_octaves)
+
+        # Calculate the lowest frequency bin for the top octave kernel
+        self.fmin_t = fmin * 2 ** (self.n_octaves - 1)
+        remainder = n_bins % bins_per_octave
+        # print("remainder = ", remainder)
+
+        if remainder == 0:
+            # Calculate the top bin frequency
+            fmax_t = self.fmin_t * 2 ** ((bins_per_octave - 1) / bins_per_octave)
+        else:
+            # Calculate the top bin frequency
+            fmax_t = self.fmin_t * 2 ** ((remainder - 1) / bins_per_octave)
+
+        self.fmin_t = fmax_t / 2 ** (1 - 1 / bins_per_octave)  # Adjusting the top minimum bins
+        if fmax_t > sr / 2:
+            raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
+                              please reduce the n_bins'.format(fmax_t))
+
+        if self.earlydownsample == True:  # Do early downsampling if this argument is True
+            if verbose:
+                print("Creating early downsampling filter ...", end='\r')
+            start = time()
+            sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
+                self.earlydownsample = get_early_downsample_params(sr,
+                                                                   hop_length,
+                                                                   fmax_t,
+                                                                   Q,
+                                                                   self.n_octaves,
+                                                                   verbose)
+            self.register_buffer('early_downsample_filter', early_downsample_filter)
+
+            if verbose:
+                print("Early downsampling filter created, \
+                       time used = {:.4f} seconds".format(time()-start))
+        else:
+            self.downsample_factor = 1.
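+
+        # Note: only the top-octave kernels are created below; `forward` reuses
+        # them for every lower octave by halving the hop size and low-pass
+        # downsampling the input, as described in the class docstring.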
+
+        # Preparing CQT kernels
+        if verbose==True:
+            print("Creating CQT kernels ...", end='\r')
+        start = time()
+        basis, self.n_fft, lengths, _ = create_cqt_kernels(Q,
+                                                           sr,
+                                                           self.fmin_t,
+                                                           n_filters,
+                                                           bins_per_octave,
+                                                           norm=basis_norm,
+                                                           topbin_check=False)
+        # For the final normalization.
+        # The freqs returned by create_cqt_kernels cannot be used,
+        # since they cover only the top octave bins;
+        # we need the information for all frequency bins.
+        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
+        self.frequencies = freqs
+
+        lengths = np.ceil(Q * sr / freqs)
+        lengths = torch.tensor(lengths).float()
+        self.register_buffer('lengths', lengths)
+
+        self.basis = basis
+        # These CQT kernels are in the time domain
+        cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
+        cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)
+
+        if trainable:
+            cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
+            cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable)
+            self.register_parameter('cqt_kernels_real', cqt_kernels_real)
+            self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
+        else:
+            self.register_buffer('cqt_kernels_real', cqt_kernels_real)
+            self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
+
+        if verbose==True:
+            print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
+        # print("Getting cqt kernel done, n_fft = ",self.n_fft)
+
+        # If center==True, the STFT window will be put in the middle, and paddings at the beginning
+        # and ending are required.
+        if self.pad_mode == 'constant':
+            self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
+        elif self.pad_mode == 'reflect':
+            self.padding = nn.ReflectionPad1d(self.n_fft//2)
+
+    def forward(self, x, output_format=None, normalization_type='librosa'):
+        """
+        Convert a batch of waveforms to CQT spectrograms.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape
+        """
+        output_format = output_format or self.output_format
+
+        x = broadcast_dim(x)
+        if self.earlydownsample==True:
+            x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
+        hop = self.hop_length
+        CQT = get_cqt_complex(x, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding)  # Getting the top octave CQT
+
+        x_down = x  # Preparing a new variable for downsampling
+
+        for i in range(self.n_octaves-1):
+            hop = hop//2
+            x_down = downsampling_by_2(x_down, self.lowpass_filter)
+            CQT1 = get_cqt_complex(x_down, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding)
+            CQT = torch.cat((CQT1, CQT), 1)
+
+        CQT = CQT[:,-self.n_bins:,:]  # Removing unwanted bottom bins
+        # print("downsample_factor = ",self.downsample_factor)
+        # print(CQT.shape)
+        # print(self.lengths.view(-1,1).shape)
+
+        # Normalize the output with the downsampling factor; 2**(self.n_octaves-1)
+        # makes the magnitude consistent with the 1992 implementation.
+        CQT = CQT*self.downsample_factor
+        # Normalize again to get the same result as librosa
+        if normalization_type == 'librosa':
+            CQT = CQT*torch.sqrt(self.lengths.view(-1,1,1))
+        elif normalization_type == 'convolutional':
+            pass
+        elif normalization_type == 'wrap':
+            CQT *= 2
+        else:
+            raise ValueError("The normalization_type %r is not part of our current options."
+                             % normalization_type)
+
+        if output_format=='Magnitude':
+            if self.trainable==False:
+                # Getting the CQT amplitude
+                return torch.sqrt(CQT.pow(2).sum(-1))
+            else:
+                return torch.sqrt(CQT.pow(2).sum(-1)+1e-8)
+
+        elif output_format=='Complex':
+            return CQT
+
+        elif output_format=='Phase':
+            phase_real = torch.cos(torch.atan2(CQT[:,:,:,1], CQT[:,:,:,0]))
+            phase_imag = torch.sin(torch.atan2(CQT[:,:,:,1], CQT[:,:,:,0]))
+            return torch.stack((phase_real, phase_imag), -1)
+
+
+class CQT(CQT1992v2):
+    """An abbreviation for :func:`~nnAudio.Spectrogram.CQT1992v2`. Please refer to the :func:`~nnAudio.Spectrogram.CQT1992v2` documentation"""
+    pass
+
+
+# The section below is for development purposes.
+# Please don't use the following classes.
+
+class DFT(torch.nn.Module):
+    """
+    Experimental feature before `torch.fft` was made available.
+    The inverse function only works for a single frame, i.e. input shape = (batch, n_fft, 1)
+    """
+    def __init__(self, n_fft=2048, freq_bins=None, hop_length=512,
+                 window='hann', freq_scale='no', center=True, pad_mode='reflect',
+                 fmin=50, fmax=6000, sr=22050):
+
+        super().__init__()
+
+        self.stride = hop_length
+        self.center = center
+        self.pad_mode = pad_mode
+        self.n_fft = n_fft
+
+        # Create filter windows for stft.
+        # create_fourier_kernels returns five values; the bin list and the
+        # window mask are not needed here.
+        wsin, wcos, self.bins2freq, _, _ = create_fourier_kernels(n_fft=n_fft,
+                                                                  freq_bins=n_fft,
+                                                                  window=window,
+                                                                  freq_scale=freq_scale,
+                                                                  fmin=fmin,
+                                                                  fmax=fmax,
+                                                                  sr=sr)
+        self.wsin = torch.tensor(wsin, dtype=torch.float)
+        self.wcos = torch.tensor(wcos, dtype=torch.float)
+
+    def forward(self, x):
+        """
+        Convert a batch of waveforms to spectrums.
+
+        Parameters
+        ----------
+        x : torch tensor
+            Input signal should be in either of the following shapes.\n
+            1. ``(len_audio)``\n
+            2. ``(num_audio, len_audio)``\n
+            3. ``(num_audio, 1, len_audio)``
+            It will be automatically broadcast to the right shape
+        """
+        x = broadcast_dim(x)
+        if self.center:
+            if self.pad_mode == 'constant':
+                padding = nn.ConstantPad1d(self.n_fft//2, 0)
+            elif self.pad_mode == 'reflect':
+                padding = nn.ReflectionPad1d(self.n_fft//2)
+
+            x = padding(x)
+
+        imag = conv1d(x, self.wsin, stride=self.stride)
+        real = conv1d(x, self.wcos, stride=self.stride)
+        return (real, -imag)
+
+    def inverse(self, x_real, x_imag):
+        """
+        Convert a batch of spectra back to waveforms.
+
+        Parameters
+        ----------
+        x_real : torch tensor
+            Real part of the signal.
+        x_imag : torch tensor
+            Imaginary part of the signal.
+        """
+        x_real = broadcast_dim(x_real)
+        x_imag = broadcast_dim(x_imag)
+
+        x_real.transpose_(1,2)  # Prepare the right shape to do inverse
+        x_imag.transpose_(1,2)  # Prepare the right shape to do inverse
+
+        # if self.center:
+        #     if self.pad_mode == 'constant':
+        #         padding = nn.ConstantPad1d(self.n_fft//2, 0)
+        #     elif self.pad_mode == 'reflect':
+        #         padding = nn.ReflectionPad1d(self.n_fft//2)
+
+        #     x_real = padding(x_real)
+        #     x_imag = padding(x_imag)
+
+        # Watch out for the positive and negative signs:
+        # ifft = e^(+2\pi*j)*X
+        #
+        # ifft(X_real) = (a1, a2)
+        #
+        # ifft(X_imag)*1j = (b1, b2)*1j
+        #                 = (-b2, b1)
+
+        a1 = conv1d(x_real, self.wcos, stride=self.stride)
+        a2 = conv1d(x_real, self.wsin, stride=self.stride)
+        b1 = conv1d(x_imag, self.wcos, stride=self.stride)
+        b2 = conv1d(x_imag, self.wsin, stride=self.stride)
+
+        imag = a2+b1
+        real = a1-b2
+        return (real/self.n_fft, imag/self.n_fft)
+
+
+class iSTFT(torch.nn.Module):
+    """This class is to convert spectrograms back to waveforms. It only works for complex-valued spectrograms.
+    If you have magnitude spectrograms, please use :func:`~nnAudio.Spectrogram.Griffin_Lim`.
+    The parameters (e.g. n_fft, window) need to be the same as the STFT in order to obtain the correct inverse.
+    If trainability is not required, it is recommended to use the ``inverse`` method under the ``STFT`` class
+    to save GPU/RAM memory.
+
+    When ``trainable=True`` and ``freq_scale!='no'``, there is no guarantee that the inverse is perfect, please
+    use with extra care.
+
+    Parameters
+    ----------
+    n_fft : int
+        The window size. Default value is 2048.
+
+    freq_bins : int
+        Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins.
+        Please make sure the value is the same as the forward STFT.
+
+    hop_length : int
+        The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
+        Please make sure the value is the same as the forward STFT.
+
+    window : str
+        The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
+        scipy documentation for possible windowing functions. The default value is 'hann'.
+        Please make sure the value is the same as the forward STFT.
+
+    freq_scale : 'linear', 'log', or 'no'
+        Determine the spacing between each frequency bin. When `linear` or `log` is used,
+        the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will
+        start at 0Hz and end at Nyquist frequency with linear spacing.
+        Please make sure the value is the same as the forward STFT.
+
+    center : bool
+        Putting the iSTFT kernel at the center of the time-step or not. If ``False``, the time
+        index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
+        the iSTFT kernel. Default value is ``True``.
+        Please make sure the value is the same as the forward STFT.
+
+    fmin : int
+        The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument
+        does nothing. Please make sure the value is the same as the forward STFT.
+
+    fmax : int
+        The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument
+        does nothing. Please make sure the value is the same as the forward STFT.
+
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    trainable_kernels : bool
+        Determine if the STFT kernels are trainable or not. If ``True``, the gradients for the STFT
+        kernels will also be calculated and the STFT kernels will be updated during model training.
+        Default value is ``False``.
+
+    trainable_window : bool
+        Determine if the window function is trainable or not.
+        Default value is ``False``.
+
+    refresh_win : bool
+        If ``True``, recalculate the window sum-square at every forward pass. Set it to ``False``
+        to speed things up when the input spectrograms always have the same number of timesteps.
+        Default value is ``True``.
+
+    verbose : bool
+        If ``True``, it shows layer information. If ``False``, it suppresses all prints.
+
+    Returns
+    -------
+    waveform : torch.tensor
+        It returns a batch of waveforms.
+ + Examples + -------- + >>> spec_layer = Spectrogram.iSTFT() + >>> specs = spec_layer(x) + """ + + def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann', + freq_scale='no', center=True, fmin=50, fmax=6000, sr=22050, trainable_kernels=False, + trainable_window=False, verbose=True, refresh_win=True): + + super().__init__() + + # Trying to make the default setting same as librosa + if win_length==None: win_length = n_fft + if hop_length==None: hop_length = int(win_length // 4) + + self.n_fft = n_fft + self.win_length = win_length + self.stride = hop_length + self.center = center + + self.pad_amount = self.n_fft // 2 + self.refresh_win = refresh_win + + start = time() + + # Create the window function and prepare the shape for batch-wise-time-wise multiplication + + # Create filter windows for inverse + kernel_sin, kernel_cos, _, _, window_mask = create_fourier_kernels(n_fft, + win_length=win_length, + freq_bins=n_fft, + window=window, + freq_scale=freq_scale, + fmin=fmin, + fmax=fmax, + sr=sr, + verbose=False) + window_mask = get_window(window,int(win_length), fftbins=True) + + # For inverse, the Fourier kernels do not need to be windowed + window_mask = torch.tensor(window_mask).unsqueeze(0).unsqueeze(-1) + + # kernel_sin and kernel_cos have the shape (freq_bins, 1, n_fft, 1) to support 2D Conv + kernel_sin = torch.tensor(kernel_sin, dtype=torch.float).unsqueeze(-1) + kernel_cos = torch.tensor(kernel_cos, dtype=torch.float).unsqueeze(-1) + + # Decide if the Fourier kernels are trainable + if trainable_kernels: + # Making all these variables trainable + kernel_sin = torch.nn.Parameter(kernel_sin, requires_grad=trainable_kernels) + kernel_cos = torch.nn.Parameter(kernel_cos, requires_grad=trainable_kernels) + self.register_parameter('kernel_sin', kernel_sin) + self.register_parameter('kernel_cos', kernel_cos) + + else: + self.register_buffer('kernel_sin', kernel_sin) + self.register_buffer('kernel_cos', kernel_cos) + + # Decide if the window function is trainable + if trainable_window: + window_mask = torch.nn.Parameter(window_mask, requires_grad=trainable_window) + self.register_parameter('window_mask', window_mask) + else: + self.register_buffer('window_mask', window_mask) + + + if verbose==True: + print("iSTFT kernels created, time used = {:.4f} seconds".format(time()-start)) + else: + pass + + + def forward(self, X, onesided=False, length=None, refresh_win=None): + """ + If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``, + else use ``onesided=False`` + To make sure the inverse STFT has the same output length of the original waveform, please + set `length` as your intended waveform length. By default, ``length=None``, + which will remove ``n_fft//2`` samples from the start and the end of the output. + If your input spectrograms X are of the same length, please use ``refresh_win=None`` to increase + computational speed. 
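+        Examples
+        --------
+        A minimal round-trip sketch (an illustration, assuming ``x`` is a waveform
+        tensor and both layers share the same ``n_fft`` and ``hop_length``):
+
+        >>> stft_layer = Spectrogram.STFT(n_fft=2048, hop_length=512, output_format='Complex')
+        >>> istft_layer = Spectrogram.iSTFT(n_fft=2048, hop_length=512)
+        >>> X = stft_layer(x)                                # (batch, n_fft//2+1, time_steps, 2)
+        >>> x_hat = istft_layer(X, onesided=True, length=x.shape[-1])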
+ """ + if refresh_win==None: + refresh_win=self.refresh_win + + assert X.dim()==4 , "Inverse iSTFT only works for complex number," \ + "make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2)" + + # If the input spectrogram contains only half of the n_fft + # Use extend_fbins function to get back another half + if onesided: + X = extend_fbins(X) # extend freq + + + X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1] + + # broadcast dimensions to support 2D convolution + X_real_bc = X_real.unsqueeze(1) + X_imag_bc = X_imag.unsqueeze(1) + + a1 = conv2d(X_real_bc, self.kernel_cos, stride=(1,1)) + b2 = conv2d(X_imag_bc, self.kernel_sin, stride=(1,1)) + + # compute real and imag part. signal lies in the real part + real = a1 - b2 + real = real.squeeze(-2)*self.window_mask + + # Normalize the amplitude with n_fft + real /= (self.n_fft) + + # Overlap and Add algorithm to connect all the frames + real = overlap_add(real, self.stride) + + # Prepare the window sumsqure for division + # Only need to create this window once to save time + # Unless the input spectrograms have different time steps + if hasattr(self, 'w_sum')==False or refresh_win==True: + self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten() + self.nonzero_indices = (self.w_sum>1e-10) + else: + pass + real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices]) + # Remove padding + if length is None: + if self.center: + real = real[:, self.pad_amount:-self.pad_amount] + + else: + if self.center: + real = real[:, self.pad_amount:self.pad_amount + length] + else: + real = real[:, :length] + + return real + + +class Griffin_Lim(torch.nn.Module): + """ + Converting Magnitude spectrograms back to waveforms based on the "fast Griffin-Lim"[1]. + This Griffin Lim is a direct clone from librosa.griffinlim. + + [1] Perraudin, N., Balazs, P., & Søndergaard, P. L. “A fast Griffin-Lim algorithm,” + IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), Oct. 2013. + + Parameters + ---------- + n_fft : int + The window size. Default value is 2048. + + n_iter=32 : int + The number of iterations for Griffin-Lim. The default value is ``32`` + + hop_length : int + The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``. + Please make sure the value is the same as the forward STFT. + + window : str + The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to + scipy documentation for possible windowing functions. The default value is 'hann'. + Please make sure the value is the same as the forward STFT. + + center : bool + Putting the iSTFT keneral at the center of the time-step or not. If ``False``, the time + index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of + the iSTFT kernel. Default value if ``True``. + Please make sure the value is the same as the forward STFT. + + momentum : float + The momentum for the update rule. The default value is ``0.99``. + + device : str + Choose which device to initialize this layer. 
Default value is 'cpu' + + """ + + def __init__(self, + n_fft, + n_iter=32, + hop_length=None, + win_length=None, + window='hann', + center=True, + pad_mode='reflect', + momentum=0.99, + device='cpu'): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.n_iter = n_iter + self.center = center + self.pad_mode = pad_mode + self.momentum = momentum + self.device = device + if win_length==None: + self.win_length=n_fft + else: + self.win_length=win_length + if hop_length==None: + self.hop_length = n_fft//4 + else: + self.hop_length = hop_length + + # Creating window function for stft and istft later + self.w = torch.tensor(get_window(window, + int(self.win_length), + fftbins=True), + device=device).float() + + def forward(self, S): + """ + Convert a batch of magnitude spectrograms to waveforms. + + Parameters + ---------- + S : torch tensor + Spectrogram of the shape ``(batch, n_fft//2+1, timesteps)`` + """ + + assert S.dim()==3 , "Please make sure your input is in the shape of (batch, freq_bins, timesteps)" + + # Initializing Random Phase + rand_phase = torch.randn(*S.shape, device=self.device) + angles = torch.empty((*S.shape,2), device=self.device) + angles[:, :,:,0] = torch.cos(2 * np.pi * rand_phase) + angles[:,:,:,1] = torch.sin(2 * np.pi * rand_phase) + + # Initializing the rebuilt magnitude spectrogram + rebuilt = torch.zeros(*angles.shape, device=self.device) + + for _ in range(self.n_iter): + tprev = rebuilt # Saving previous rebuilt magnitude spec + + # spec2wav conversion +# print(f'win_length={self.win_length}\tw={self.w.shape}') + inverse = torch.istft(S.unsqueeze(-1) * angles, + self.n_fft, + self.hop_length, + win_length=self.win_length, + window=self.w, + center=self.center) + # wav2spec conversion + rebuilt = torch.stft(inverse, + self.n_fft, + self.hop_length, + win_length=self.win_length, + window=self.w, + pad_mode=self.pad_mode) + + # Phase update rule + angles[:,:,:] = rebuilt[:,:,:] - (self.momentum / (1 + self.momentum)) * tprev[:,:,:] + + # Phase normalization + angles = angles.div(torch.sqrt(angles.pow(2).sum(-1)).unsqueeze(-1) + 1e-16) # normalizing the phase + + # Using the final phase to reconstruct the waveforms + inverse = torch.istft(S.unsqueeze(-1) * angles, + self.n_fft, + self.hop_length, + win_length=self.win_length, + window=self.w, + center=self.center) + return inverse + + + +class Combined_Frequency_Periodicity(nn.Module): + """ + Vectorized version of the code in https://github.com/leo-so/VocalMelodyExtPatchCNN/blob/master/MelodyExt.py. + This feature is described in 'Combining Spectral and Temporal Representations for Multipitch Estimation of Polyphonic Music' + https://ieeexplore.ieee.org/document/7118691 + + Under development, please report any bugs you found + """ + def __init__(self,fr=2, fs=16000, hop_length=320, + window_size=2049, fc=80, tc=1/1000, + g=[0.24, 0.6, 1], NumPerOct=48): + super().__init__() + + self.window_size = window_size + self.hop_length = hop_length + + # variables for STFT part + self.N = int(fs/float(fr)) # Will be used to calculate padding + self.f = fs*np.linspace(0, 0.5, np.round(self.N//2), endpoint=True) # it won't be used but will be returned + self.pad_value = ((self.N-window_size)) + # Create window function, always blackmanharris? 
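+        # (For orientation: with the defaults above, N = fs/fr = 16000/2 = 8000,
+        # so each 2049-sample Blackman-Harris frame is zero-padded to an
+        # 8000-point FFT and the bin spacing works out to fr = 2 Hz.
+        # These numbers are illustrative arithmetic, not computed here.)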
+        h = scipy.signal.blackmanharris(window_size).astype(np.float32)  # window function for STFT
+        self.register_buffer('h', torch.tensor(h))
+
+        # variables for CFP
+        self.NumofLayer = np.size(g)
+        self.g = g
+        self.tc_idx = round(fs*tc)  # index to filter out the top tc_idx and bottom tc_idx bins
+        self.fc_idx = round(fc/fr)  # index to filter out the top fc_idx and bottom fc_idx bins
+        self.HighFreqIdx = int(round((1/tc)/fr)+1)
+        self.HighQuefIdx = int(round(fs/fc)+1)
+
+        # attributes to be returned
+        self.f = self.f[:self.HighFreqIdx]
+        self.q = np.arange(self.HighQuefIdx)/float(fs)
+
+        # filters for the final step
+        freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
+        self.register_buffer('freq2logfreq_matrix', torch.tensor(freq2logfreq_matrix.astype(np.float32)))
+        self.register_buffer('quef2logfreq_matrix', torch.tensor(quef2logfreq_matrix.astype(np.float32)))
+
+    def _CFP(self, spec):
+        spec = torch.relu(spec).pow(self.g[0])
+
+        if self.NumofLayer >= 2:
+            for gc in range(1, self.NumofLayer):
+                if np.remainder(gc, 2) == 1:
+                    ceps = torch.rfft(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
+                    ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
+                else:
+                    spec = torch.rfft(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
+                    spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)
+
+        return spec, ceps
+
+    def forward(self, x):
+        tfr0 = torch.stft(x, self.N, hop_length=self.hop_length, win_length=self.window_size,
+                          window=self.h, onesided=False, pad_mode='constant')
+        tfr0 = torch.sqrt(tfr0.pow(2).sum(-1))/torch.norm(self.h)  # calculate magnitude
+        tfr0 = tfr0.transpose(1,2)[:,1:-1]  # transpose the F and T axes and discard the first and last frames
+        # The transpose is necessary for rfft later
+        # (batch, timesteps, n_fft)
+        tfr, ceps = self._CFP(tfr0)
+
+#       return tfr0
+        # removing duplicate bins
+        tfr0 = tfr0[:,:,:int(round(self.N/2))]
+        tfr = tfr[:,:,:int(round(self.N/2))]
+        ceps = ceps[:,:,:int(round(self.N/2))]
+
+        # Crop up to the highest frequency
+        tfr0 = tfr0[:,:,:self.HighFreqIdx]
+        tfr = tfr[:,:,:self.HighFreqIdx]
+        ceps = ceps[:,:,:self.HighQuefIdx]
+        tfrL0 = torch.matmul(self.freq2logfreq_matrix, tfr0.transpose(1,2))
+        tfrLF = torch.matmul(self.freq2logfreq_matrix, tfr.transpose(1,2))
+        tfrLQ = torch.matmul(self.quef2logfreq_matrix, ceps.transpose(1,2))
+        Z = tfrLF * tfrLQ
+
+        # Only need to calculate this once
+        self.t = np.arange(self.hop_length,
+                           np.ceil(len(x)/float(self.hop_length))*self.hop_length,
+                           self.hop_length)  # it won't be used but will be returned
+
+        return Z, tfrL0, tfrLF, tfrLQ
+
+    def nonlinear_func(self, X, g, cutoff):
+        cutoff = int(cutoff)
+        if g!=0:
+            X = torch.relu(X)
+            X[:, :, :cutoff] = 0
+            X[:, :, -cutoff:] = 0
+            X = X.pow(g)
+        else:  # when g=0, it converges to log
+            X = torch.log(X)
+            X[:, :, :cutoff] = 0
+            X[:, :, -cutoff:] = 0
+        return X
+
+    def create_logfreq_matrix(self, f, q, fr, fc, tc, NumPerOct, fs):
+        StartFreq = fc
+        StopFreq = 1/tc
+        Nest = int(np.ceil(np.log2(StopFreq/StartFreq))*NumPerOct)
+        central_freq = []  # A list holding the frequencies in log scale
+
+        for i in range(0, Nest):
+            CenFreq = StartFreq*pow(2, float(i)/NumPerOct)
+            if CenFreq < StopFreq:
+                central_freq.append(CenFreq)
+            else:
+                break
+
+        Nest = len(central_freq)
+        freq_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)
+
+        # Calculating the freq_band_transformation
+        for i in range(1, Nest-1):
+            l = int(round(central_freq[i-1]/fr))
+            r = int(round(central_freq[i+1]/fr)+1)
+            # rounding
+            if l >= r-1:
+                freq_band_transformation[i, l] = 1
+            else:
+                for j in range(l, r):
+                    if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
+                        freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1])
+                    elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
+                        freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
+
+        # Calculating the quef_band_transformation
+        f = 1/q  # divide by 0, do I need to fix this?
+        quef_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)
+        for i in range(1, Nest-1):
+            for j in range(int(round(fs/central_freq[i+1])), int(round(fs/central_freq[i-1])+1)):
+                if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
+                    quef_band_transformation[i, j] = (f[j] - central_freq[i-1])/(central_freq[i] - central_freq[i-1])
+                elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
+                    quef_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
+
+        return freq_band_transformation, quef_band_transformation
+
+
+class CFP(nn.Module):
+    """
+    This is the modified version so that the number of timesteps fits with other classes
+
+    Under development, please report any bugs you found
+    """
+    def __init__(self, fr=2, fs=16000, hop_length=320,
+                 window_size=2049, fc=80, tc=1/1000,
+                 g=[0.24, 0.6, 1], NumPerOct=48):
+        super().__init__()
+
+        self.window_size = window_size
+        self.hop_length = hop_length
+
+        # variables for the STFT part
+        self.N = int(fs/float(fr))  # Will be used to calculate padding
+        self.f = fs*np.linspace(0, 0.5, np.round(self.N//2), endpoint=True)  # it won't be used but will be returned
+        self.pad_value = ((self.N-window_size))
+        # Create window function, always blackmanharris?
+        h = scipy.signal.blackmanharris(window_size).astype(np.float32)  # window function for STFT
+        self.register_buffer('h', torch.tensor(h))
+
+        # variables for CFP
+        self.NumofLayer = np.size(g)
+        self.g = g
+        self.tc_idx = round(fs*tc)  # index to filter out the top tc_idx and bottom tc_idx bins
+        self.fc_idx = round(fc/fr)  # index to filter out the top fc_idx and bottom fc_idx bins
+        self.HighFreqIdx = int(round((1/tc)/fr)+1)
+        self.HighQuefIdx = int(round(fs/fc)+1)
+
+        # attributes to be returned
+        self.f = self.f[:self.HighFreqIdx]
+        self.q = np.arange(self.HighQuefIdx)/float(fs)
+
+        # filters for the final step
+        freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
+        self.register_buffer('freq2logfreq_matrix', torch.tensor(freq2logfreq_matrix.astype(np.float32)))
+        self.register_buffer('quef2logfreq_matrix', torch.tensor(quef2logfreq_matrix.astype(np.float32)))
+
+    def _CFP(self, spec):
+        spec = torch.relu(spec).pow(self.g[0])
+
+        if self.NumofLayer >= 2:
+            for gc in range(1, self.NumofLayer):
+                if np.remainder(gc, 2) == 1:
+                    ceps = torch.rfft(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
+                    ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
+                else:
+                    spec = torch.rfft(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
+                    spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)
+
+        return spec, ceps
+
+    def forward(self, x):
+        tfr0 = torch.stft(x, self.N, hop_length=self.hop_length, win_length=self.window_size,
+                          window=self.h, onesided=False, pad_mode='constant')
+        tfr0 = torch.sqrt(tfr0.pow(2).sum(-1))/torch.norm(self.h)  # calculate magnitude
+        tfr0 = tfr0.transpose(1,2)  # transpose the F and T axes (no frames are discarded in this version)
+        # The transpose is necessary for rfft later
+        # (batch, timesteps, n_fft)
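+        # _CFP alternates between the two domains: odd-numbered layers rfft the
+        # rectified, g-powered spectrum into a generalized cepstrum and zero out
+        # the lowest/highest tc_idx quefrency bins; even-numbered layers
+        # transform back and zero out fc_idx frequency bins.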
+        tfr, ceps = self._CFP(tfr0)
+
+#       return tfr0
+        # removing duplicate bins
+        tfr0 = tfr0[:,:,:int(round(self.N/2))]
+        tfr = tfr[:,:,:int(round(self.N/2))]
+        ceps = ceps[:,:,:int(round(self.N/2))]
+
+        # Crop up to the highest frequency
+        tfr0 = tfr0[:,:,:self.HighFreqIdx]
+        tfr = tfr[:,:,:self.HighFreqIdx]
+        ceps = ceps[:,:,:self.HighQuefIdx]
+        tfrL0 = torch.matmul(self.freq2logfreq_matrix, tfr0.transpose(1,2))
+        tfrLF = torch.matmul(self.freq2logfreq_matrix, tfr.transpose(1,2))
+        tfrLQ = torch.matmul(self.quef2logfreq_matrix, ceps.transpose(1,2))
+        Z = tfrLF * tfrLQ
+
+        # Only need to calculate this once
+        self.t = np.arange(self.hop_length,
+                           np.ceil(len(x)/float(self.hop_length))*self.hop_length,
+                           self.hop_length)  # it won't be used but will be returned
+
+        return Z  # , tfrL0, tfrLF, tfrLQ
+
+    def nonlinear_func(self, X, g, cutoff):
+        cutoff = int(cutoff)
+        if g!=0:
+            X = torch.relu(X)
+            X[:, :, :cutoff] = 0
+            X[:, :, -cutoff:] = 0
+            X = X.pow(g)
+        else:  # when g=0, it converges to log
+            X = torch.log(X)
+            X[:, :, :cutoff] = 0
+            X[:, :, -cutoff:] = 0
+        return X
+
+    def create_logfreq_matrix(self, f, q, fr, fc, tc, NumPerOct, fs):
+        StartFreq = fc
+        StopFreq = 1/tc
+        Nest = int(np.ceil(np.log2(StopFreq/StartFreq))*NumPerOct)
+        central_freq = []  # A list holding the frequencies in log scale
+
+        for i in range(0, Nest):
+            CenFreq = StartFreq*pow(2, float(i)/NumPerOct)
+            if CenFreq < StopFreq:
+                central_freq.append(CenFreq)
+            else:
+                break
+
+        Nest = len(central_freq)
+        freq_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)
+
+        # Calculating the freq_band_transformation
+        for i in range(1, Nest-1):
+            l = int(round(central_freq[i-1]/fr))
+            r = int(round(central_freq[i+1]/fr)+1)
+            # rounding
+            if l >= r-1:
+                freq_band_transformation[i, l] = 1
+            else:
+                for j in range(l, r):
+                    if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
+                        freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1])
+                    elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
+                        freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
+
+        # Calculating the quef_band_transformation
+        f = 1/q  # divide by 0, do I need to fix this?
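+        # Note on the division above: f[0] becomes inf since q[0] == 0, but with
+        # the default parameters the loops below start at j >= 1, so the inf
+        # entry is never read; numpy only emits a divide-by-zero warning.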
+        quef_band_transformation = np.zeros((Nest-1, len(f)), dtype=float)
+        for i in range(1, Nest-1):
+            for j in range(int(round(fs/central_freq[i+1])), int(round(fs/central_freq[i-1])+1)):
+                if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
+                    quef_band_transformation[i, j] = (f[j] - central_freq[i-1])/(central_freq[i] - central_freq[i-1])
+                elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
+                    quef_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
+
+        return freq_band_transformation, quef_band_transformation
diff --git a/third_party/nnAudio/nnAudio/__init__.py b/third_party/nnAudio/nnAudio/__init__.py
new file mode 100755
index 00000000..984fc572
--- /dev/null
+++ b/third_party/nnAudio/nnAudio/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.2.2"
\ No newline at end of file
diff --git a/third_party/nnAudio/nnAudio/librosa_functions.py b/third_party/nnAudio/nnAudio/librosa_functions.py
new file mode 100755
index 00000000..0d779217
--- /dev/null
+++ b/third_party/nnAudio/nnAudio/librosa_functions.py
@@ -0,0 +1,490 @@
+"""
+Module containing functions cloned from librosa
+
+To make sure nnAudio would not become broken when updating librosa
+"""
+
+import numpy as np
+import warnings
+
+class ParameterError(Exception):
+    """Minimal stand-in for librosa.util.exceptions.ParameterError,
+    which is raised by the cloned functions below."""
+    pass
+
+### ----------------Functions for generating kernels for Mel Spectrogram------------ ###
+# This code is equivalent to `from librosa.filters import mel`.
+# By doing so, we can run nnAudio without installing librosa.
+def fft2gammatonemx(sr=20000, n_fft=2048, n_bins=64, width=1.0, fmin=0.0,
+                    fmax=11025, maxlen=1024):
+    """
+    # Ellis' description in MATLAB:
+    # [wts,cfreqa] = fft2gammatonemx(nfft, sr, nfilts, width, minfreq, maxfreq, maxlen)
+    # Generate a matrix of weights to combine FFT bins into
+    # Gammatone bins.  nfft defines the source FFT size at
+    # sampling rate sr.  Optional nfilts specifies the number of
+    # output bands required (default 64), and width is the
+    # constant width of each band in Bark (default 1).
+    # minfreq, maxfreq specify range covered in Hz (100, sr/2).
+    # While wts has nfft columns, the second half are all zero.
+    # Hence, aud spectrum is
+    # fft2gammatonemx(nfft,sr)*abs(fft(xincols,nfft));
+    # maxlen truncates the rows to this many bins.
+    # cfreqs returns the actual center frequencies of each
+    # gammatone band in Hz.
+    #
+    # 2009/02/22 02:29:25 Dan Ellis dpwe@ee.columbia.edu  based on rastamat/audspec.m
+    # Sat May 27 15:37:50 2017 Maddie Cusimano, mcusi@mit.edu 27 May 2017: convert to python
+    """
+
+    wts = np.zeros([n_bins, n_fft], dtype=np.float32)
+
+    # after Slaney's MakeERBFilters
+    EarQ = 9.26449
+    minBW = 24.7
+    order = 1
+
+    nFr = np.array(range(n_bins)) + 1
+    em = EarQ * minBW
+    cfreqs = (fmax + em) * np.exp(nFr * (-np.log(fmax + em) + np.log(fmin + em)) / n_bins) - em
+    cfreqs = cfreqs[::-1]
+
+    GTord = 4
+    ucircArray = np.array(range(int(n_fft / 2 + 1)))
+    ucirc = np.exp(1j * 2 * np.pi * ucircArray / n_fft)
+    # justpoles = 0 :taking out the 'if' corresponding to this.
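+    # For reference: ucirc holds points on the upper half of the unit circle,
+    # exp(2j*pi*k/n_fft) for k = 0..n_fft/2. The per-band gammatone magnitude
+    # response below is assembled from the distances between these points and
+    # the filter's poles and zeros.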
+
+    ERB = width * np.power(np.power(cfreqs / EarQ, order) + np.power(minBW, order), 1 / order)
+    B = 1.019 * 2 * np.pi * ERB
+    r = np.exp(-B / sr)
+    theta = 2 * np.pi * cfreqs / sr
+    pole = r * np.exp(1j * theta)
+    T = 1 / sr
+    ebt = np.exp(B * T)
+    cpt = 2 * cfreqs * np.pi * T
+    ccpt = 2 * T * np.cos(cpt)
+    scpt = 2 * T * np.sin(cpt)
+    A11 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2)
+    A12 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2)
+    A13 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2)
+    A14 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2)
+    zros = -np.array([A11, A12, A13, A14]) / T
+    wIdx = range(int(n_fft / 2 + 1))
+    gain = np.abs((-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
+        -(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
+                       np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 - 2 ** (3 / 2)) * np.sin(
+                   2 * cfreqs * np.pi * T))) * (-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
+        -(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (np.cos(2 * cfreqs * np.pi * T) + np.sqrt(
+        3 - 2 ** (3 / 2)) * np.sin(2 * cfreqs * np.pi * T))) * (
+                      -2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
+                  -(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
+                          np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
+                      2 * cfreqs * np.pi * T))) * (
+                      -2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
+                  -(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
+                          np.cos(2 * cfreqs * np.pi * T) + np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
+                      2 * cfreqs * np.pi * T))) / (
+                      -2 / np.exp(2 * B * T) - 2 * np.exp(4 * 1j * cfreqs * np.pi * T) + 2 * (
+                          1 + np.exp(4 * 1j * cfreqs * np.pi * T)) / np.exp(B * T)) ** 4)
+    # in MATLAB, there used to be 64 where here it says n_bins:
+    wts[:, wIdx] = ((T ** 4) / np.reshape(gain, (n_bins, 1))) * np.abs(
+        ucirc - np.reshape(zros[0], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[1], (n_bins, 1))) * np.abs(
+        ucirc - np.reshape(zros[2], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[3], (n_bins, 1))) * (np.abs(
+        np.power(np.multiply(np.reshape(pole, (n_bins, 1)) - ucirc, np.conj(np.reshape(pole, (n_bins, 1))) - ucirc),
+                 -GTord)))
+    wts = wts[:, range(maxlen)]
+
+    return wts, cfreqs
+
+def gammatone(sr, n_fft, n_bins=64, fmin=20.0, fmax=None, htk=False,
+              norm=1, dtype=np.float32):
+    """Create a Filterbank matrix to combine FFT bins into Gammatone bins
+    Parameters
+    ----------
+    sr : number > 0 [scalar]
+        sampling rate of the incoming signal
+    n_fft : int > 0 [scalar]
+        number of FFT components
+    n_bins : int > 0 [scalar]
+        number of Gammatone bands to generate
+    fmin : float >= 0 [scalar]
+        lowest frequency (in Hz)
+    fmax : float >= 0 [scalar]
+        highest frequency (in Hz).
+        If `None`, use `fmax = sr / 2.0`
+    htk : bool [scalar]
+        use HTK formula instead of Slaney
+    norm : {None, 1, np.inf} [scalar]
+        if 1, divide the triangular weights by the width of the band
+        (area normalization). Otherwise, leave all the triangles aiming for
+        a peak value of 1.0
+    dtype : np.dtype
+        The data type of the output basis.
+        By default, uses 32-bit (single-precision) floating point.
+    Returns
+    -------
+    G : np.ndarray [shape=(n_bins, 1 + n_fft/2)]
+        Gammatone transform matrix
+    """
+
+    if fmax is None:
+        fmax = float(sr) / 2
+    n_bins = int(n_bins)
+
+    weights, _ = fft2gammatonemx(sr=sr, n_fft=n_fft, n_bins=n_bins, fmin=fmin, fmax=fmax, maxlen=int(n_fft//2+1))
+
+    return (1/n_fft)*weights
+
+def mel_to_hz(mels, htk=False):
+    """Convert mel bin numbers to frequencies
+    Examples
+    --------
+    >>> librosa.mel_to_hz(3)
+    200.
+    >>> librosa.mel_to_hz([1,2,3,4,5])
+    array([  66.667,  133.333,  200.   ,  266.667,  333.333])
+    Parameters
+    ----------
+    mels : np.ndarray [shape=(n,)], float
+        mel bins to convert
+    htk : bool
+        use HTK formula instead of Slaney
+    Returns
+    -------
+    frequencies : np.ndarray [shape=(n,)]
+        input mels in Hz
+    See Also
+    --------
+    hz_to_mel
+    """
+
+    mels = np.asanyarray(mels)
+
+    if htk:
+        return 700.0 * (10.0**(mels / 2595.0) - 1.0)
+
+    # Fill in the linear scale
+    f_min = 0.0
+    f_sp = 200.0 / 3
+    freqs = f_min + f_sp * mels
+
+    # And now the nonlinear scale
+    min_log_hz = 1000.0                         # beginning of log region (Hz)
+    min_log_mel = (min_log_hz - f_min) / f_sp   # same (Mels)
+    logstep = np.log(6.4) / 27.0                # step size for log region
+
+    if mels.ndim:
+        # If we have vector data, vectorize
+        log_t = (mels >= min_log_mel)
+        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
+    elif mels >= min_log_mel:
+        # If we have scalar data, check directly
+        freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
+
+    return freqs
+
+def hz_to_mel(frequencies, htk=False):
+    """Convert Hz to Mels
+    Examples
+    --------
+    >>> librosa.hz_to_mel(60)
+    0.9
+    >>> librosa.hz_to_mel([110, 220, 440])
+    array([ 1.65,  3.3 ,  6.6 ])
+    Parameters
+    ----------
+    frequencies : number or np.ndarray [shape=(n,)], float
+        scalar or array of frequencies
+    htk : bool
+        use HTK formula instead of Slaney
+    Returns
+    -------
+    mels : number or np.ndarray [shape=(n,)]
+        input frequencies in Mels
+    See Also
+    --------
+    mel_to_hz
+    """
+
+    frequencies = np.asanyarray(frequencies)
+
+    if htk:
+        return 2595.0 * np.log10(1.0 + frequencies / 700.0)
+
+    # Fill in the linear part
+    f_min = 0.0
+    f_sp = 200.0 / 3
+
+    mels = (frequencies - f_min) / f_sp
+
+    # Fill in the log-scale part
+    min_log_hz = 1000.0                         # beginning of log region (Hz)
+    min_log_mel = (min_log_hz - f_min) / f_sp   # same (Mels)
+    logstep = np.log(6.4) / 27.0                # step size for log region
+
+    if frequencies.ndim:
+        # If we have array data, vectorize
+        log_t = (frequencies >= min_log_hz)
+        mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep
+    elif frequencies >= min_log_hz:
+        # If we have scalar data, check directly
+        mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
+
+    return mels
+
+def fft_frequencies(sr=22050, n_fft=2048):
+    '''Alternative implementation of `np.fft.fftfreq`
+    Parameters
+    ----------
+    sr : number > 0 [scalar]
+        Audio sampling rate
+    n_fft : int > 0 [scalar]
+        FFT window size
+    Returns
+    -------
+    freqs : np.ndarray [shape=(1 + n_fft/2,)]
+        Frequencies `(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)`
+    Examples
+    --------
+    >>> librosa.fft_frequencies(sr=22050, n_fft=16)
+    array([     0.   ,   1378.125,   2756.25 ,   4134.375,
+             5512.5  ,   6890.625,   8268.75 ,   9646.875,  11025.   ])
+    '''
+
+    return np.linspace(0,
+                       float(sr) / 2,
+                       int(1 + n_fft//2),
+                       endpoint=True)
+
+def mel_frequencies(n_mels=128, fmin=0.0, fmax=11025.0, htk=False):
+    """
+    This function is cloned from librosa 0.7.
+    Please refer to the original
+    `documentation `__
+    for more info.
+ + Parameters + ---------- + n_mels : int > 0 [scalar] + Number of mel bins. + + fmin : float >= 0 [scalar] + Minimum frequency (Hz). + + fmax : float >= 0 [scalar] + Maximum frequency (Hz). + + htk : bool + If True, use HTK formula to convert Hz to mel. + Otherwise (False), use Slaney's Auditory Toolbox. + + Returns + ------- + bin_frequencies : ndarray [shape=(n_mels,)] + Vector of n_mels frequencies in Hz which are uniformly spaced on the Mel + axis. + + Examples + -------- + >>> librosa.mel_frequencies(n_mels=40) + array([ 0. , 85.317, 170.635, 255.952, + 341.269, 426.586, 511.904, 597.221, + 682.538, 767.855, 853.173, 938.49 , + 1024.856, 1119.114, 1222.042, 1334.436, + 1457.167, 1591.187, 1737.532, 1897.337, + 2071.84 , 2262.393, 2470.47 , 2697.686, + 2945.799, 3216.731, 3512.582, 3835.643, + 4188.417, 4573.636, 4994.285, 5453.621, + 5955.205, 6502.92 , 7101.009, 7754.107, + 8467.272, 9246.028, 10096.408, 11025. ]) + """ + + # 'Center freqs' of mel bands - uniformly spaced between limits + min_mel = hz_to_mel(fmin, htk=htk) + max_mel = hz_to_mel(fmax, htk=htk) + + mels = np.linspace(min_mel, max_mel, n_mels) + + return mel_to_hz(mels, htk=htk) + +def mel(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, + norm=1, dtype=np.float32): + """ + This function is cloned from librosa 0.7. + Please refer to the original + `documentation `__ + for more info. + Create a Filterbank matrix to combine FFT bins into Mel-frequency bins + + + Parameters + ---------- + sr : number > 0 [scalar] + sampling rate of the incoming signal + n_fft : int > 0 [scalar] + number of FFT components + n_mels : int > 0 [scalar] + number of Mel bands to generate + fmin : float >= 0 [scalar] + lowest frequency (in Hz) + fmax : float >= 0 [scalar] + highest frequency (in Hz). + If `None`, use `fmax = sr / 2.0` + htk : bool [scalar] + use HTK formula instead of Slaney + norm : {None, 1, np.inf} [scalar] + if 1, divide the triangular mel weights by the width of the mel band + (area normalization). Otherwise, leave all the triangles aiming for + a peak value of 1.0 + dtype : np.dtype + The data type of the output basis. + By default, uses 32-bit (single-precision) floating point. + + Returns + ------- + M : np.ndarray [shape=(n_mels, 1 + n_fft/2)] + Mel transform matrix + + Notes + ----- + This function caches at level 10. + + Examples + -------- + >>> melfb = librosa.filters.mel(22050, 2048) + >>> melfb + array([[ 0. , 0.016, ..., 0. , 0. ], + [ 0. , 0. , ..., 0. , 0. ], + ..., + [ 0. , 0. , ..., 0. , 0. ], + [ 0. , 0. , ..., 0. , 0. ]]) + Clip the maximum frequency to 8KHz + >>> librosa.filters.mel(22050, 2048, fmax=8000) + array([[ 0. , 0.02, ..., 0. , 0. ], + [ 0. , 0. , ..., 0. , 0. ], + ..., + [ 0. , 0. , ..., 0. , 0. ], + [ 0. , 0. , ..., 0. , 0. 
]])
+    >>> import matplotlib.pyplot as plt
+    >>> plt.figure()
+    >>> librosa.display.specshow(melfb, x_axis='linear')
+    >>> plt.ylabel('Mel filter')
+    >>> plt.title('Mel filter bank')
+    >>> plt.colorbar()
+    >>> plt.tight_layout()
+    >>> plt.show()
+    """
+
+    if fmax is None:
+        fmax = float(sr) / 2
+
+    if norm is not None and norm != 1 and norm != np.inf:
+        raise ParameterError('Unsupported norm: {}'.format(repr(norm)))
+
+    # Initialize the weights
+    n_mels = int(n_mels)
+    weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+
+    # Center freqs of each FFT bin
+    fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
+
+    # 'Center freqs' of mel bands - uniformly spaced between limits
+    mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
+
+    fdiff = np.diff(mel_f)
+    ramps = np.subtract.outer(mel_f, fftfreqs)
+
+    for i in range(n_mels):
+        # lower and upper slopes for all bins
+        lower = -ramps[i] / fdiff[i]
+        upper = ramps[i+2] / fdiff[i+1]
+
+        # .. then intersect them with each other and zero
+        weights[i] = np.maximum(0, np.minimum(lower, upper))
+
+    if norm == 1:
+        # Slaney-style mel is scaled to be approx constant energy per channel
+        enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels])
+        weights *= enorm[:, np.newaxis]
+
+    # Only check weights if f_mel[0] is positive
+    if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
+        # This means we have an empty channel somewhere
+        warnings.warn('Empty filters detected in mel frequency basis. '
+                      'Some channels will produce empty responses. '
+                      'Try increasing your sampling rate (and fmax) or '
+                      'reducing n_mels.')
+
+    return weights
+### ------------------End of Functions for generating kernels for Mel Spectrogram ----------------###
+
+
+### ------------------Functions for making STFT same as librosa ---------------------------------###
+def pad_center(data, size, axis=-1, **kwargs):
+    '''Wrapper for np.pad to automatically center an array prior to padding.
+    This is analogous to `str.center()`
+
+    Examples
+    --------
+    >>> # Generate a vector
+    >>> data = np.ones(5)
+    >>> librosa.util.pad_center(data, 10, mode='constant')
+    array([ 0.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.])
+
+    >>> # Pad a matrix along its first dimension
+    >>> data = np.ones((3, 5))
+    >>> librosa.util.pad_center(data, 7, axis=0)
+    array([[ 0.,  0.,  0.,  0.,  0.],
+           [ 0.,  0.,  0.,  0.,  0.],
+           [ 1.,  1.,  1.,  1.,  1.],
+           [ 1.,  1.,  1.,  1.,  1.],
+           [ 1.,  1.,  1.,  1.,  1.],
+           [ 0.,  0.,  0.,  0.,  0.],
+           [ 0.,  0.,  0.,  0.,  0.]])
+    >>> # Or its second dimension
+    >>> librosa.util.pad_center(data, 7, axis=1)
+    array([[ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
+           [ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
+           [ 0.,  1.,  1.,  1.,  1.,  1.,  0.]])
+
+    Parameters
+    ----------
+    data : np.ndarray
+        Vector to be padded and centered
+
+    size : int >= len(data) [scalar]
+        Length to pad `data`
+
+    axis : int
+        Axis along which to pad and center the data
+
+    kwargs : additional keyword arguments
+        arguments passed to `np.pad()`
+
+    Returns
+    -------
+    data_padded : np.ndarray
+        `data` centered and padded to length `size` along the
+        specified axis
+
+    Raises
+    ------
+    ParameterError
+        If `size < data.shape[axis]`
+
+    See Also
+    --------
+    numpy.pad
+    '''
+
+    kwargs.setdefault('mode', 'constant')
+
+    n = data.shape[axis]
+
+    lpad = int((size - n) // 2)
+
+    lengths = [(0, 0)] * data.ndim
+    lengths[axis] = (lpad, int(size - n - lpad))
+
+    if lpad < 0:
+        raise ParameterError(('Target size ({:d}) must be '
+                              'at least input size ({:d})').format(size, n))
+
+    return np.pad(data, lengths, **kwargs)
+
+### ------------------End of functions for making STFT same as librosa ---------------------------###
diff --git a/third_party/nnAudio/nnAudio/utils.py b/third_party/nnAudio/nnAudio/utils.py
new file mode 100644
index 00000000..a5ac366c
--- /dev/null
+++ b/third_party/nnAudio/nnAudio/utils.py
@@ -0,0 +1,535 @@
+"""
+Module containing helper functions such as overlap-add and Fourier kernel generators
+"""
+
+import torch
+from torch.nn.functional import conv1d, fold
+
+import numpy as np
+from time import time
+import math
+from scipy.signal import get_window
+from scipy import signal
+from scipy import fft
+import warnings
+
+from nnAudio.librosa_functions import *
+
+## --------------------------- Filter Design ---------------------------##
+def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
+    w_stacks = w.unsqueeze(-1).repeat((1,n_frames)).unsqueeze(0)
+    # Window length + stride*(frames-1)
+    output_len = w_stacks.shape[1] + stride*(w_stacks.shape[2]-1)
+    return fold(w_stacks**power, (1,output_len), kernel_size=(1,n_fft), stride=stride)
+
+def overlap_add(X, stride):
+    n_fft = X.shape[1]
+    output_len = n_fft + stride*(X.shape[2]-1)
+
+    return fold(X, (1,output_len), kernel_size=(1,n_fft), stride=stride).flatten(1)
+
+def uniform_distribution(r1, r2, *size, device):
+    return (r1 - r2) * torch.rand(*size, device=device) + r2
+
+def extend_fbins(X):
+    """Extending the number of frequency bins from `n_fft//2+1` back to `n_fft` by
+    reversing all bins except DC and Nyquist and appending them on top of the existing spectrogram"""
+    X_upper = torch.flip(X[:,1:-1], (0,1))
+    X_upper[:,:,:,1] = -X_upper[:,:,:,1]  # For the imaginary part, it is an odd function
+    return torch.cat((X[:, :, :], X_upper), 1)
+
+
+def downsampling_by_n(x, filterKernel, n):
+    """A helper function that downsamples the audio by an arbitrary factor n.
+    It is used in CQT2010 and CQT2010v2.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
+
+    filterKernel : torch.Tensor
+        Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
+
+    n : int
+        The downsampling factor
+
+    Returns
+    -------
+    torch.Tensor
+        The downsampled waveform
+
+    Examples
+    --------
+    >>> x_down = downsampling_by_n(x, filterKernel, n)
+    """
+
+    x = conv1d(x, filterKernel, stride=n, padding=(filterKernel.shape[-1]-1)//2)
+    return x
+
+
+def downsampling_by_2(x, filterKernel):
+    """A helper function that downsamples the audio by half. It is used in CQT2010 and CQT2010v2
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
+
+    filterKernel : torch.Tensor
+        Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
+
+    Returns
+    -------
+    torch.Tensor
+        The downsampled waveform
+
+    Examples
+    --------
+    >>> x_down = downsampling_by_2(x, filterKernel)
+    """
+
+    x = conv1d(x, filterKernel, stride=2, padding=(filterKernel.shape[-1]-1)//2)
+    return x
+
+
+## Basic tools for computation ##
+def nextpow2(A):
+    """A helper function to calculate the exponent of the next power of 2 above ``A``,
+    i.e. ``ceil(log2(A))``.
+
+    Parameters
+    ----------
+    A : float
+        A float number that is going to be rounded up to the next power of 2
+
+    Returns
+    -------
+    int
+        The exponent of the next power of 2 above ``A``
+
+    Examples
+    --------
+
+    >>> nextpow2(6)
+    3
+    """
+
+    return int(np.ceil(np.log2(A)))
+
+## Basic tools for computation ##
+def prepow2(A):
+    """A helper function to calculate the exponent of the previous power of 2 below ``A``,
+    i.e. ``floor(log2(A))``.
+
+    Parameters
+    ----------
+    A : float
+        A float number that is going to be rounded down to the previous power of 2
+
+    Returns
+    -------
+    int
+        The exponent of the previous power of 2 below ``A``
+
+    Examples
+    --------
+
+    >>> prepow2(6)
+    2
+    """
+
+    return int(np.floor(np.log2(A)))
+
+
+def complex_mul(cqt_filter, stft):
+    """Since older versions of PyTorch do not support complex tensors and their operations,
+    we need to write our own complex multiplication function. This one is specially
+    designed for CQT usage.
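+    In terms of components this is the usual complex product,
+    (a + bj)(c + dj) = (ac - bd) + (ad + bc)j, realized with four
+    ``torch.matmul`` calls on the paired real tensors.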
+
+    Parameters
+    ----------
+    cqt_filter : tuple of torch.Tensor
+        The tuple is in the format of ``(real_torch_tensor, imag_torch_tensor)``
+
+    stft : tuple of torch.Tensor
+        The STFT result, also in the format of ``(real_torch_tensor, imag_torch_tensor)``
+
+    Returns
+    -------
+    tuple of torch.Tensor
+        The output is in the format of ``(real_torch_tensor, imag_torch_tensor)``
+    """
+
+    cqt_filter_real = cqt_filter[0]
+    cqt_filter_imag = cqt_filter[1]
+    fourier_real = stft[0]
+    fourier_imag = stft[1]
+
+    CQT_real = torch.matmul(cqt_filter_real, fourier_real) - torch.matmul(cqt_filter_imag, fourier_imag)
+    CQT_imag = torch.matmul(cqt_filter_real, fourier_imag) + torch.matmul(cqt_filter_imag, fourier_real)
+
+    return CQT_real, CQT_imag
+
+
+def broadcast_dim(x):
+    """
+    Auto broadcast input so that it can fit into a Conv1d
+    """
+
+    if x.dim() == 2:
+        x = x[:, None, :]
+    elif x.dim() == 1:
+        # If nn.DataParallel is used, this broadcast doesn't work
+        x = x[None, None, :]
+    elif x.dim() == 3:
+        pass
+    else:
+        raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
+    return x
+
+
+def broadcast_dim_conv2d(x):
+    """
+    Auto broadcast input so that it can fit into a Conv2d
+    """
+
+    if x.dim() == 3:
+        x = x[:, None, :, :]
+    else:
+        raise ValueError("Only support input with shape = (batch, freq_bins, time_steps)")
+    return x
+
+
+## Kernel generation functions ##
+def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50, fmax=6000, sr=44100,
+                           freq_scale='linear', window='hann', verbose=True):
+    """ This function creates the Fourier Kernel for STFT, Melspectrogram and CQT.
+    Most of the parameters follow librosa conventions. Part of the code comes from
+    pytorch_musicnet. https://github.com/jthickstun/pytorch_musicnet
+
+    Parameters
+    ----------
+    n_fft : int
+        The window size
+
+    win_length : int
+        The size of the window frame. The window function has length ``win_length`` and
+        is centered (zero-padded) to length ``n_fft``. Default is ``None``, which means
+        ``win_length = n_fft``.
+
+    freq_bins : int
+        Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
+
+    fmin : int
+        The starting frequency for the lowest frequency bin.
+        If freq_scale is ``no``, this argument does nothing.
+
+    fmax : int
+        The ending frequency for the highest frequency bin.
+        If freq_scale is ``no``, this argument does nothing.
+
+    sr : int
+        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
+        Setting the correct sampling rate is very important for calculating the correct frequency.
+
+    freq_scale : 'linear', 'log', or 'no'
+        Determine the spacing between each frequency bin.
+        When 'linear' or 'log' is used, the bin spacing can be controlled by ``fmin`` and ``fmax``.
+        If 'no' is used, the bin will start at 0Hz and end at Nyquist frequency with linear spacing.
+
+    Returns
+    -------
+    wsin : numpy.array
+        Imaginary Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
+
+    wcos : numpy.array
+        Real Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
+
+    bins2freq : list
+        Mapping each frequency bin to frequency in Hz.
+
+    binslist : list
+        The normalized frequency ``k`` in the digital domain.
+        This ``k`` is the bin index in the Discrete Fourier Transform equation
+        $$ X[k] = \sum_{n=0}^{N-1} x[n] e^{-2\pi i k n / N} $$
+
+    window_mask : numpy.array
+        The window function, padded from ``win_length`` to ``n_fft``, that is applied
+        to the Fourier kernels.
+    """
+
+    if freq_bins==None: freq_bins = n_fft//2+1
+    if win_length==None: win_length = n_fft
+
+    s = np.arange(0, n_fft, 1.)
+    wsin = np.empty((freq_bins,1,n_fft))
+    wcos = np.empty((freq_bins,1,n_fft))
+    start_freq = fmin
+    end_freq = fmax
+    bins2freq = []
+    binslist = []
+
+    # num_cycles = start_freq*d/44000.
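+    # Sanity check of the kernel construction (illustrative only): with
+    # freq_scale='no', wcos[k,0,n] = cos(2*pi*k*n/n_fft), so convolving the
+    # window-masked kernels against a padded signal yields the real part of
+    # the DFT at bin k for every frame.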
+    # scaling_ind = np.log(end_freq/start_freq)/k
+
+    # Choosing window shape
+    window_mask = get_window(window, int(win_length), fftbins=True)
+    window_mask = pad_center(window_mask, n_fft)
+
+    if freq_scale == 'linear':
+        if verbose==True:
+            print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to "
+                  f"get a valid freq range")
+        start_bin = start_freq*n_fft/sr
+        scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins
+
+        for k in range(freq_bins):  # Only half of the bins contain useful info
+            # print("linear freq = {}".format((k*scaling_ind+start_bin)*sr/n_fft))
+            bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)
+            binslist.append((k*scaling_ind+start_bin))
+            wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
+            wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
+
+    elif freq_scale == 'log':
+        if verbose==True:
+            print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to "
+                  f"get a valid freq range")
+        start_bin = start_freq*n_fft/sr
+        scaling_ind = np.log(end_freq/start_freq)/freq_bins
+
+        for k in range(freq_bins):  # Only half of the bins contain useful info
+            # print("log freq = {}".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))
+            bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)
+            binslist.append((np.exp(k*scaling_ind)*start_bin))
+            wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
+            wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
+
+    elif freq_scale == 'no':
+        for k in range(freq_bins):  # Only half of the bins contain useful info
+            bins2freq.append(k*sr/n_fft)
+            binslist.append(k)
+            wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)
+            wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)
+    else:
+        print("Please select a valid frequency scale: 'linear', 'log', or 'no'")
+    return wsin.astype(np.float32), wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)
+
+
+# Tools for CQT
+
+def create_cqt_kernels(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1,
+                       window='hann', fmax=None, topbin_check=True):
+    """
+    Automatically create CQT kernels in the time domain
+    """
+
+    fftLen = 2**nextpow2(np.ceil(Q * fs / fmin))
+    # minWin = 2**nextpow2(np.ceil(Q * fs / fmax))
+
+    if (fmax != None) and (n_bins == None):
+        n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin))  # Calculate the number of bins
+        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
+
+    elif (fmax == None) and (n_bins != None):
+        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
+
+    else:
+        warnings.warn('If fmax is given, n_bins will be ignored', SyntaxWarning)
+        n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin))  # Calculate the number of bins
+        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
+
+    if np.max(freqs) > fs/2 and topbin_check==True:
+        raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
+                          please reduce the n_bins'.format(np.max(freqs)))
+
+    tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
+    specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
+
+    lengths = np.ceil(Q * fs / freqs)
+    for k in range(0, int(n_bins)):
+        freq = freqs[k]
+        l = np.ceil(Q * fs / freq)
+
+        # Centering the kernels
+        if l%2==1:  # pad more zeros on RHS
+            start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1
+        else:
+            start = int(np.ceil(fftLen / 2.0 - l / 2.0))
+
+        sig = get_window_dispatch(window, int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l
+
+        if norm:  # Normalizing the filter, trying to normalize like librosa
+            tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm)
+        else:
+            tempKernel[k, start:start + int(l)] = sig
+        # specKernel[k, :] = fft(tempKernel[k])
+
+    # return specKernel[:,:fftLen//2+1], fftLen, torch.tensor(lengths).float()
+    return tempKernel, fftLen, torch.tensor(lengths).float(), freqs
+
+
+def get_window_dispatch(window, N, fftbins=True):
+    if isinstance(window, str):
+        return get_window(window, N, fftbins=fftbins)
+    elif isinstance(window, tuple):
+        if window[0] == 'gaussian':
+            assert window[1] >= 0
+            sigma = np.floor(-N / 2 / np.sqrt(-2 * np.log(10**(-window[1] / 20))))
+            return get_window(('gaussian', sigma), N, fftbins=fftbins)
+        else:
+            warnings.warn("Tuple windows may have undesired behaviour regarding Q factor")
+    elif isinstance(window, float):
+        warnings.warn("You are using a Kaiser window with beta factor " + str(window) +
+                      ". Correct behaviour not checked.")
+    else:
+        raise Exception("The function get_window from scipy only supports strings, tuples and floats.")
+
+
+def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
+    """Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
+    for how to multiply the STFT result with the CQT kernel.
+    [1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
+    a constant Q transform.” (1992)."""
+
+    # STFT, converting the audio input from time domain to frequency domain
+    try:
+        x = padding(x)  # When center == True, we need padding at the beginning and ending
+    except Exception:
+        warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
+                      "padding with reflection mode might not be the best choice, try using constant padding",
+                      UserWarning)
+        x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
+    CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
+    CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
+
+    return torch.stack((CQT_real, CQT_imag), -1)
+
+def get_cqt_complex2(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None):
+    """Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
+    for how to multiply the STFT result with the CQT kernel.
+    [1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
+    a constant Q transform.” (1992)."""
+
+
+def get_window_dispatch(window, N, fftbins=True):
+    if isinstance(window, str):
+        return get_window(window, N, fftbins=fftbins)
+    elif isinstance(window, tuple):
+        if window[0] == 'gaussian':
+            assert window[1] >= 0
+            sigma = np.floor(- N / 2 / np.sqrt(- 2 * np.log(10**(- window[1] / 20))))
+            return get_window(('gaussian', sigma), N, fftbins=fftbins)
+        else:
+            warnings.warn("Tuple windows may have undesired behaviour regarding Q factor")
+            return get_window(window, N, fftbins=fftbins)
+    elif isinstance(window, float):
+        warnings.warn("You are using Kaiser window with beta factor " + str(window)
+                      + ". Correct behaviour not checked.")
+        return get_window(window, N, fftbins=fftbins)
+    else:
+        raise ValueError("The function get_window from scipy only supports strings, tuples and floats.")
+
+
+def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
+    """Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
+    for how to multiply the STFT result with the CQT kernel
+    [1] Brown, Judith C. and Miller Puckette. “An efficient algorithm for the calculation of
+    a constant Q transform.” (1992)."""
+
+    # STFT, converting the audio input from time domain to frequency domain
+    try:
+        x = padding(x) # When center == True, we need padding at the beginning and ending
+    except Exception:
+        warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
+                      "padding with reflection mode might not be the best choice, try using constant padding",
+                      UserWarning)
+        x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
+    CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
+    CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
+
+    return torch.stack((CQT_real, CQT_imag), -1)
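+
+# A hedged shape sketch for get_cqt_complex, reusing the kernels/fftLen names from
+# the create_cqt_kernels sketch above (illustrative, not library API):
+# >>> x = torch.randn(1, 1, 44100)                    # (batch, 1, len_audio)
+# >>> pad = torch.nn.ReflectionPad1d(fftLen // 2)
+# >>> kr = torch.tensor(kernels.real).unsqueeze(1)    # (n_bins, 1, fftLen)
+# >>> ki = torch.tensor(kernels.imag).unsqueeze(1)
+# >>> out = get_cqt_complex(x, kr, ki, hop_length=512, padding=pad)
+# out has shape (batch, n_bins, time_steps, 2), real/imaginary in the last axis.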
+
+
+def get_cqt_complex2(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None):
+    """Same as get_cqt_complex, but when Fourier kernels (wcos, wsin) are given, it
+    computes the STFT first and multiplies the result with the CQT kernel in the
+    frequency domain, as in the 1992 CQT paper [1]
+    [1] Brown, Judith C. and Miller Puckette. “An efficient algorithm for the calculation of
+    a constant Q transform.” (1992)."""
+
+    # STFT, converting the audio input from time domain to frequency domain
+    try:
+        x = padding(x) # When center == True, we need padding at the beginning and ending
+    except Exception:
+        warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
+                      "padding with reflection mode might not be the best choice, try using constant padding",
+                      UserWarning)
+        x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
+
+    if wcos is None or wsin is None:
+        CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
+        CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
+
+    else:
+        fourier_real = conv1d(x, wcos, stride=hop_length)
+        fourier_imag = conv1d(x, wsin, stride=hop_length)
+        # Multiplying input with the CQT kernel in freq domain
+        CQT_real, CQT_imag = complex_mul((cqt_kernels_real, cqt_kernels_imag),
+                                         (fourier_real, fourier_imag))
+
+    return torch.stack((CQT_real, CQT_imag), -1)
+
+
+def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
+    """
+    Calculate the highest frequency we need to preserve and the lowest frequency we allow
+    to pass through.
+    Note that frequency is on a scale from 0 to 1, where 0 is DC and 1 is the Nyquist
+    frequency of the signal BEFORE downsampling.
+    """
+
+    # transitionBandwidth = 0.03
+    passbandMax = band_center / (1 + transitionBandwidth)
+    stopbandMin = band_center * (1 + transitionBandwidth)
+
+    # firwin2 does not allow us to specify how closely the filter
+    # matches our specifications. Instead, we specify the length of
+    # the kernel. The longer the kernel is, the more precisely it
+    # will match.
+    # kernelLength = 256
+
+    # We specify a list of key frequencies for which we will require
+    # that the filter match a specific output gain.
+    # From [0.0 to passbandMax] is the frequency range we want to keep
+    # untouched and [stopbandMin, 1.0] is the range we want to remove
+    keyFrequencies = [0.0, passbandMax, stopbandMin, 1.0]
+
+    # We specify a list of output gains to correspond to the key
+    # frequencies listed above.
+    # The first two gains are 1.0 because they correspond to the first
+    # two key frequencies (the passband); the second two are 0.0 because
+    # they correspond to the stopband frequencies
+    gainAtKeyFrequencies = [1.0, 1.0, 0.0, 0.0]
+
+    # This command produces the filter kernel coefficients
+    filterKernel = signal.firwin2(kernelLength, keyFrequencies, gainAtKeyFrequencies)
+
+    return filterKernel.astype(np.float32)
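+
+# A hedged usage sketch for create_lowpass_filter (illustrative values): design a
+# half-band kernel, smooth with conv1d, then decimate by 2. This mirrors, but is
+# not identical to, the downsampling path used by the CQT2010 classes:
+# >>> lp = torch.tensor(create_lowpass_filter(band_center=0.5))[None, None, :]
+# >>> y = torch.randn(1, 1, 44100)
+# >>> y_filt = torch.nn.functional.conv1d(y, lp, padding=lp.shape[-1]//2)
+# >>> y_down = y_filt[:, :, ::2]    # keep every second sample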
+
+
+def get_early_downsample_params(sr, hop_length, fmax_t, Q, n_octaves, verbose):
+    """Used in CQT2010 and CQT2010v2"""
+
+    window_bandwidth = 1.5 # for hann window
+    filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
+    sr, hop_length, downsample_factor = early_downsample(sr,
+                                                         hop_length,
+                                                         n_octaves,
+                                                         sr//2,
+                                                         filter_cutoff)
+    if downsample_factor != 1:
+        if verbose==True:
+            print("Can do early downsample, factor = ", downsample_factor)
+        earlydownsample = True
+        # print("new sr = ", sr)
+        # print("new hop_length = ", hop_length)
+        early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor,
+                                                        kernelLength=256,
+                                                        transitionBandwidth=0.03)
+        early_downsample_filter = torch.tensor(early_downsample_filter)[None, None, :]
+
+    else:
+        if verbose==True:
+            print("No early downsampling is required, downsample_factor = ", downsample_factor)
+        early_downsample_filter = None
+        earlydownsample = False
+
+    return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample
+
+
+# The following two downsampling functions are obtained from librosa CQT.
+# They are used to determine the number of pre-resamplings if the starting and
+# ending frequencies are both in low frequency regions.
+
+def early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves):
+    '''Compute the number of early downsampling operations'''
+
+    downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist /
+                                                   filter_cutoff)) - 1) - 1)
+    # print("downsample_count1 = ", downsample_count1)
+    num_twos = nextpow2(hop_length)
+    downsample_count2 = max(0, num_twos - n_octaves + 1)
+    # print("downsample_count2 = ", downsample_count2)
+
+    return min(downsample_count1, downsample_count2)
+
+
+def early_downsample(sr, hop_length, n_octaves,
+                     nyquist, filter_cutoff):
+    '''Return the new sampling rate and hop length after early downsampling'''
+    downsample_count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
+    # print("downsample_count = ", downsample_count)
+    downsample_factor = 2**(downsample_count)
+
+    hop_length //= downsample_factor # Getting the new hop_length
+    new_sr = sr / float(downsample_factor) # Getting the new sampling rate
+    sr = new_sr
+
+    return sr, hop_length, downsample_factor
\ No newline at end of file
diff --git a/third_party/nnAudio/setup.py b/third_party/nnAudio/setup.py
new file mode 100755
index 00000000..9b2f3688
--- /dev/null
+++ b/third_party/nnAudio/setup.py
@@ -0,0 +1,37 @@
+import setuptools
+import codecs
+import os.path
+
+with open("README.md", "r") as fh:
+    long_description = fh.read()
+
+def read(rel_path):
+    here = os.path.abspath(os.path.dirname(__file__))
+    with codecs.open(os.path.join(here, rel_path), 'r') as fp:
+        return fp.read()
+
+def get_version(rel_path):
+    for line in read(rel_path).splitlines():
+        if line.startswith('__version__'):
+            delim = '"' if '"' in line else "'"
+            return line.split(delim)[1]
+    else:
+        raise RuntimeError("Unable to find version string.")
+
+setuptools.setup(
+    name="nnAudio",
+    version=get_version("nnAudio/__init__.py"),
+    author="KinWaiCheuk",
+    author_email="u3500684@connect.hku.hk",
+    description="A fast GPU audio processing toolbox with 1D convolutional neural network",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/KinWaiCheuk/nnAudio",
+    packages=setuptools.find_packages(),
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    python_requires='>=3.6',
+)
diff --git a/third_party/nnAudio/tests/parameters.py b/third_party/nnAudio/tests/parameters.py
new file mode 100644
index 00000000..c8356ac6
--- /dev/null
+++ b/third_party/nnAudio/tests/parameters.py
@@ -0,0 +1,38 @@
+# Creating parameters for the STFT tests
+"""
+It is equivalent to
+[(1024, 128, 'ones'),
+ (1024, 128, 'hann'),
+ (1024, 128, 'hamming'),
+ (2048, 128, 'ones'),
+ (2048, 512, 'ones'),
+ (2048, 128, 'hann'),
+ (2048, 512, 'hann'),
+ (2048, 128, 'hamming'),
+ (2048, 512, 'hamming'),
+ (256, None, 'hann')]
+"""
+
+stft_parameters = []
+n_fft = [1024, 2048]
+hop_length = [128, 512, 1024]
+window = ['ones', 'hann', 'hamming']
+for i in n_fft:
+    for k in window:
+        for j in hop_length:
+            if j < (i/2):
+                stft_parameters.append((i, j, k))
+stft_parameters.append((256, None, 'hann'))
+
+stft_with_win_parameters = []
+n_fft = [512, 1024]
+win_length = [400, 900]
+hop_length = [128, 256]
+for i in n_fft:
+    for j in win_length:
+        if j < i:
+            for k in hop_length:
+                if k < (i/2):
+                    stft_with_win_parameters.append((i, j, k))
+
+mel_win_parameters = [(512, 400), (1024, 1000)]
\ No newline at end of file
diff --git a/third_party/nnAudio/tests/test_spectrogram.py b/third_party/nnAudio/tests/test_spectrogram.py
new file mode 100644
index 00000000..3aa074c1
--- /dev/null
+++ b/third_party/nnAudio/tests/test_spectrogram.py
@@ -0,0 +1,373 @@
+import pytest
+import librosa
+import torch
+import matplotlib.pyplot as plt
+from scipy.signal import chirp, sweep_poly
+from nnAudio.Spectrogram import *
+from parameters import *
+
+gpu_idx = 0
+
+# librosa example audio for testing
+example_y, example_sr = librosa.load(librosa.util.example_audio_file())
+
+
+@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
+def test_inverse2(n_fft, hop_length, window, device):
+    x = torch.tensor(example_y, device=device)
+    stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
+    istft = iSTFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
+    X = stft(x.unsqueeze(0), output_format="Complex")
+    x_recon = istft(X, length=x.shape[0], onesided=True).squeeze()
+    assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-5, atol=1e-3)
+
+@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
+def test_inverse(n_fft, hop_length, window, device):
+    x = torch.tensor(example_y, device=device)
+    stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, iSTFT=True).to(device)
+    X = stft(x.unsqueeze(0), output_format="Complex")
+    x_recon = stft.inverse(X, length=x.shape[0]).squeeze()
+    assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
+
+
+# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
+# def test_inverse_GPU(n_fft, hop_length, window):
+#     x = torch.tensor(example_y, device=f'cuda:{gpu_idx}')
+#     stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
+#     X = stft(x.unsqueeze(0), output_format="Complex")
+#     x_recon = stft.inverse(X, num_samples=x.shape[0]).squeeze()
+#     assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
+
+
+@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
+def test_stft_complex(n_fft, hop_length, window, device):
+    x = example_y
+    stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
+    X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
+    X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
+    real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
+                           np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
+
+    assert real_diff and imag_diff
+
+# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
+# def test_stft_complex_GPU(n_fft, hop_length, window):
+#     x = example_y
+#     stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
+#     X = stft(torch.tensor(x, device=f'cuda:{gpu_idx}').unsqueeze(0), output_format="Complex")
+#     X_real, X_imag = X[:, :, :, 0].squeeze().detach().cpu(), X[:, :, :, 1].squeeze().detach().cpu()
+#     X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
+#     real_diff, imag_diff = np.allclose(X_real, X_librosa.real, rtol=1e-3, atol=1e-3), \
+#                            np.allclose(X_imag, X_librosa.imag, rtol=1e-3, atol=1e-3)
+
+#     assert real_diff and imag_diff
+
+@pytest.mark.parametrize("n_fft, win_length, hop_length", stft_with_win_parameters)
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_stft_complex_winlength(n_fft, win_length, hop_length, device): + x = example_y + stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex") + X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze() + X_librosa = librosa.stft(x, n_fft=n_fft, win_length=win_length, hop_length=hop_length) + real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \ + np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3) + assert real_diff and imag_diff + +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_stft_magnitude(device): + x = example_y + stft = STFT(n_fft=2048, hop_length=512).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Magnitude").squeeze() + X_librosa, _ = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512)) + assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3) + +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_stft_phase(device): + x = example_y + stft = STFT(n_fft=2048, hop_length=512).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Phase") + X_real, X_imag = torch.cos(X).squeeze(), torch.sin(X).squeeze() + _, X_librosa = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512)) + + real_diff, imag_diff = np.mean(np.abs(X_real.cpu().numpy() - X_librosa.real)), \ + np.mean(np.abs(X_imag.cpu().numpy() - X_librosa.imag)) + + # I find that np.allclose is too strict for allowing phase to be similar to librosa. + # Hence for phase we use average element-wise distance as the test metric. + assert real_diff < 2e-4 and imag_diff < 2e-4 + +@pytest.mark.parametrize("n_fft, win_length", mel_win_parameters) +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_mel_spectrogram(n_fft, win_length, device): + x = example_y + melspec = MelSpectrogram(n_fft=n_fft, win_length=win_length, hop_length=512).to(device) + X = melspec(torch.tensor(x, device=device).unsqueeze(0)).squeeze() + X_librosa = librosa.feature.melspectrogram(x, n_fft=n_fft, win_length=win_length, hop_length=512) + assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_cqt_1992(device): + # Log sweep case + fs = 44100 + t = 1 + f0 = 55 + f1 = 22050 + s = np.linspace(0, t, fs*t) + x = chirp(s, f0, 1, f1, method='logarithmic') + x = x.astype(dtype=np.float32) + + # Magnitude + stft = CQT1992(sr=fs, fmin=220, output_format="Magnitude", + n_bins=80, bins_per_octave=24).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0)) + + + # Complex + stft = CQT1992(sr=fs, fmin=220, output_format="Complex", + n_bins=80, bins_per_octave=24).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0)) + + # Phase + stft = CQT1992(sr=fs, fmin=220, output_format="Phase", + n_bins=160, bins_per_octave=24).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0)) + + assert True + +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_cqt_2010(device): + # Log sweep case + fs = 44100 + t = 1 + f0 = 55 + f1 = 22050 + s = np.linspace(0, t, fs*t) + x = chirp(s, f0, 1, f1, method='logarithmic') + x = x.astype(dtype=np.float32) + + # Magnitude + stft = CQT2010(sr=fs, fmin=110, output_format="Magnitude", + n_bins=160, bins_per_octave=24).to(device) + X = 
+
+    # Complex
+    stft = CQT2010(sr=fs, fmin=110, output_format="Complex",
+                   n_bins=160, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+
+    # Phase
+    stft = CQT2010(sr=fs, fmin=110, output_format="Phase",
+                   n_bins=160, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    assert True
+
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
+def test_cqt_1992_v2_log(device):
+    # Log sweep case
+    fs = 44100
+    t = 1
+    f0 = 55
+    f1 = 22050
+    s = np.linspace(0, t, fs*t)
+    x = chirp(s, f0, 1, f1, method='logarithmic')
+    x = x.astype(dtype=np.float32)
+
+    # Magnitude
+    stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-mag-ground-truth.npy")
+    X = torch.log(X + 1e-5)
+    assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
+
+    # Complex
+    stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-complex-ground-truth.npy")
+    assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
+
+    # Phase
+    stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-phase-ground-truth.npy")
+    assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
+
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
+def test_cqt_1992_v2_linear(device):
+    # Linear sweep case
+    fs = 44100
+    t = 1
+    f0 = 55
+    f1 = 22050
+    s = np.linspace(0, t, fs*t)
+    x = chirp(s, f0, 1, f1, method='linear')
+    x = x.astype(dtype=np.float32)
+
+    # Magnitude
+    stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-mag-ground-truth.npy")
+    X = torch.log(X + 1e-5)
+    assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
+
+    # Complex
+    stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-complex-ground-truth.npy")
+    assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
+
+    # Phase
+    stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-phase-ground-truth.npy")
+    assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
+
+@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
+def test_cqt_2010_v2_log(device):
+    # Log sweep case
+    fs = 44100
+    t = 1
+    f0 = 55
+    f1 = 22050
+    s = np.linspace(0, t, fs*t)
+    x = chirp(s, f0, 1, f1, method='logarithmic')
+    x = x.astype(dtype=np.float32)
+
+    # Magnitude
+    stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
+                     n_bins=207, bins_per_octave=24).to(device)
+    X = stft(torch.tensor(x, device=device).unsqueeze(0))
+    X = torch.log(X + 1e-2)
+#     np.save("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth", X.cpu())
np.save("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth", X.cpu()) + ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth.npy") + assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3) + + # Complex + stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex", + n_bins=207, bins_per_octave=24).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0)) +# np.save("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth", X.cpu()) + ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth.npy") + assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3) + +# # Phase +# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase", +# n_bins=207, bins_per_octave=24) +# X = stft(torch.tensor(x, device=device).unsqueeze(0)) +# # np.save("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth", X.cpu()) +# ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth.npy") +# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3) + +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_cqt_2010_v2_linear(device): + # Linear sweep case + fs = 44100 + t = 1 + f0 = 55 + f1 = 22050 + s = np.linspace(0, t, fs*t) + x = chirp(s, f0, 1, f1, method='linear') + x = x.astype(dtype=np.float32) + + # Magnitude + stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude", + n_bins=207, bins_per_octave=24).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0)) + X = torch.log(X + 1e-2) +# np.save("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth", X.cpu()) + ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth.npy") + assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3) + + # Complex + stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex", + n_bins=207, bins_per_octave=24).to(device) + X = stft(torch.tensor(x, device=device).unsqueeze(0)) +# np.save("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth", X.cpu()) + ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth.npy") + assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3) + + # Phase +# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase", +# n_bins=207, bins_per_octave=24) +# X = stft(torch.tensor(x, device=device).unsqueeze(0)) +# # np.save("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth", X.cpu()) +# ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth.npy") +# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3) + +@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}']) +def test_mfcc(device): + x = example_y + mfcc = MFCC(sr=example_sr).to(device) + X = mfcc(torch.tensor(x, device=device).unsqueeze(0)).squeeze() + X_librosa = librosa.feature.mfcc(x, sr=example_sr) + assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3) + + +x = torch.randn((4,44100)) # Create a batch of input for the following Data.Parallel test + +@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}']) +def test_STFT_Parallel(device): + spec_layer = STFT(hop_length=512, n_fft=2048, window='hann', + freq_scale='no', + output_format='Complex').to(device) + inverse_spec_layer = iSTFT(hop_length=512, n_fft=2048, window='hann', + freq_scale='no').to(device) + + spec_layer_parallel = torch.nn.DataParallel(spec_layer) + inverse_spec_layer_parallel = torch.nn.DataParallel(inverse_spec_layer) + spec = spec_layer_parallel(x) + x_recon = 
+
+    assert np.allclose(x_recon.detach().cpu(), x.detach().cpu(), rtol=1e-3, atol=1e-3)
+
+@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
+def test_MelSpectrogram_Parallel(device):
+    spec_layer = MelSpectrogram(sr=22050, n_fft=2048, n_mels=128, hop_length=512,
+                                window='hann', center=True, pad_mode='reflect',
+                                power=2.0, htk=False, fmin=0.0, fmax=None, norm=1,
+                                verbose=True).to(device)
+    spec_layer_parallel = torch.nn.DataParallel(spec_layer)
+    spec = spec_layer_parallel(x)
+
+@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
+def test_MFCC_Parallel(device):
+    spec_layer = MFCC().to(device)
+    spec_layer_parallel = torch.nn.DataParallel(spec_layer)
+    spec = spec_layer_parallel(x)
+
+@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
+def test_CQT1992_Parallel(device):
+    spec_layer = CQT1992(fmin=110, n_bins=60, bins_per_octave=12).to(device)
+    spec_layer_parallel = torch.nn.DataParallel(spec_layer)
+    spec = spec_layer_parallel(x)
+
+@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
+def test_CQT1992v2_Parallel(device):
+    spec_layer = CQT1992v2().to(device)
+    spec_layer_parallel = torch.nn.DataParallel(spec_layer)
+    spec = spec_layer_parallel(x)
+
+@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
+def test_CQT2010_Parallel(device):
+    spec_layer = CQT2010().to(device)
+    spec_layer_parallel = torch.nn.DataParallel(spec_layer)
+    spec = spec_layer_parallel(x)
+
+@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
+def test_CQT2010v2_Parallel(device):
+    spec_layer = CQT2010v2().to(device)
+    spec_layer_parallel = torch.nn.DataParallel(spec_layer)
+    spec = spec_layer_parallel(x)
\ No newline at end of file
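+
+# Hedged note (not part of the original test suite): these tests load ground-truth
+# arrays via relative paths such as "tests/ground-truths/...", so they are usually
+# run with pytest from a working directory where those paths resolve, e.g.:
+#     pytest tests/test_spectrogram.py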