From 4ad885f8ca26c8c4fdca7ff6abfc59116814eb7f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 9 Jun 2021 11:31:08 +0000 Subject: [PATCH 1/2] add feature notebook --- .notebook/audio_feature.ipynb | 224 ++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 .notebook/audio_feature.ipynb diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb new file mode 100644 index 00000000..5febb0ae --- /dev/null +++ b/.notebook/audio_feature.ipynb @@ -0,0 +1,224 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 13, + "id": "matched-camera", + "metadata": {}, + "outputs": [], + "source": [ + "from nnAudio import Spectrogram\n", + "from scipy.io import wavfile\n", + "import torch\n", + "import soundfile as sf\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "middle-salem", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "[43 75 69 ... 7 6 3]\n", + "(83792,)\n", + "int16\n", + "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", + "STFT kernels created, time used = 0.2142 seconds\n", + "tensor([[[[-4.0940e+03, 1.2600e+04],\n", + " [ 8.5108e+03, -5.4930e+03],\n", + " [-3.3631e+03, -1.7904e+03],\n", + " ...,\n", + " [ 8.2279e+03, -9.3340e+03],\n", + " [-3.1990e+03, 2.0969e+03],\n", + " [-1.2669e+03, 4.4488e+03]],\n", + "\n", + " [[ 3.4886e+03, -9.9620e+03],\n", + " [-4.5364e+03, 4.1907e+02],\n", + " [ 2.5074e+03, 7.1339e+03],\n", + " ...,\n", + " [-5.4819e+03, 3.9258e+01],\n", + " [ 4.7221e+03, 6.5887e+01],\n", + " [ 9.6492e+02, -3.4386e+03]],\n", + "\n", + " [[-3.4947e+03, 9.2981e+03],\n", + " [-7.5164e+03, 8.1856e+02],\n", + " [-5.3766e+03, -9.0889e+03],\n", + " ...,\n", + " [ 1.4317e+03, 5.7447e+03],\n", + " [-3.1178e+03, 3.0740e+03],\n", + " [-3.4351e+03, 5.6900e+02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 6.7112e+01, -4.5737e+00],\n", + " [-9.6295e+00, 3.5554e+01],\n", + " [ 1.8527e+00, -1.0491e+01],\n", + " ...,\n", + " [-1.1157e+01, 3.4423e+00],\n", + " [ 3.1193e+00, -4.4388e+00],\n", + " [-8.8242e+00, 8.0324e+00]],\n", + "\n", + " [[-6.5080e+01, 2.9543e+00],\n", + " [ 3.9992e+01, -1.3836e+01],\n", + " [-9.2803e+00, 1.0318e+01],\n", + " ...,\n", + " [ 4.2928e+00, 9.2397e+00],\n", + " [ 3.6642e+00, 9.4680e+00],\n", + " [ 4.8932e+00, -2.5199e+01]],\n", + "\n", + " [[ 4.7264e+01, -1.0721e+00],\n", + " [-6.0516e+00, -1.4589e+01],\n", + " [ 1.3127e+01, 1.4995e+00],\n", + " ...,\n", + " [ 1.7333e+01, -1.4380e+01],\n", + " [-3.6046e+00, -6.1019e+00],\n", + " [ 1.3321e+01, 2.3184e+01]]]])\n" + ] + } + ], + "source": [ + "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", + "print(sr)\n", + "print(song)\n", + "print(song.shape)\n", + "print(song.dtype)\n", + "x = song\n", + "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", + "\n", + "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", + " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", + " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", + "\n", + "spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", + "print(spec)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "finished-sterling", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "[43 75 69 ... 7 6 3]\n", + "(83792,)\n", + "int16\n", + "True\n", + "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", + "STFT kernels created, time used = 0.2495 seconds\n", + "torch.Size([1, 1025, 164, 2])\n", + "tensor([[[[-4.0940e+03, 1.2600e+04],\n", + " [ 8.5108e+03, -5.4930e+03],\n", + " [-3.3631e+03, -1.7904e+03],\n", + " ...,\n", + " [ 8.2279e+03, -9.3340e+03],\n", + " [-3.1990e+03, 2.0969e+03],\n", + " [-1.2669e+03, 4.4488e+03]],\n", + "\n", + " [[ 3.4886e+03, -9.9620e+03],\n", + " [-4.5364e+03, 4.1907e+02],\n", + " [ 2.5074e+03, 7.1339e+03],\n", + " ...,\n", + " [-5.4819e+03, 3.9258e+01],\n", + " [ 4.7221e+03, 6.5887e+01],\n", + " [ 9.6492e+02, -3.4386e+03]],\n", + "\n", + " [[-3.4947e+03, 9.2981e+03],\n", + " [-7.5164e+03, 8.1856e+02],\n", + " [-5.3766e+03, -9.0889e+03],\n", + " ...,\n", + " [ 1.4317e+03, 5.7447e+03],\n", + " [-3.1178e+03, 3.0740e+03],\n", + " [-3.4351e+03, 5.6900e+02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 6.7112e+01, -4.5737e+00],\n", + " [-9.6295e+00, 3.5554e+01],\n", + " [ 1.8527e+00, -1.0491e+01],\n", + " ...,\n", + " [-1.1157e+01, 3.4423e+00],\n", + " [ 3.1193e+00, -4.4388e+00],\n", + " [-8.8242e+00, 8.0324e+00]],\n", + "\n", + " [[-6.5080e+01, 2.9543e+00],\n", + " [ 3.9992e+01, -1.3836e+01],\n", + " [-9.2803e+00, 1.0318e+01],\n", + " ...,\n", + " [ 4.2928e+00, 9.2397e+00],\n", + " [ 3.6642e+00, 9.4680e+00],\n", + " [ 4.8932e+00, -2.5199e+01]],\n", + "\n", + " [[ 4.7264e+01, -1.0721e+00],\n", + " [-6.0516e+00, -1.4589e+01],\n", + " [ 1.3127e+01, 1.4995e+00],\n", + " ...,\n", + " [ 1.7333e+01, -1.4380e+01],\n", + " [-3.6046e+00, -6.1019e+00],\n", + " [ 1.3321e+01, 2.3184e+01]]]])\n", + "True\n" + ] + } + ], + "source": [ + "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", + "print(sr)\n", + "print(wav)\n", + "print(wav.shape)\n", + "print(wav.dtype)\n", + "print(np.allclose(wav, song))\n", + "\n", + "x = wav\n", + "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", + "\n", + "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", + " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", + " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", + "\n", + "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", + "print(wav_spec.shape)\n", + "print(wav_spec)\n", + "print(np.allclose(wav_spec, spec))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "running-technology", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From b08384cd35d89525799dc580760c2b358c982027 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 11 Jun 2021 10:14:03 +0000 Subject: [PATCH 2/2] using conv1d to do fft --- .notebook/audio_feature.ipynb | 995 ++++++++++++++++++++- third_party/nnAudio/.gitignore | 3 + third_party/nnAudio/nnAudio/Spectrogram.py | 11 +- third_party/nnAudio/setup.py | 13 +- 4 files changed, 1004 insertions(+), 18 deletions(-) create mode 100644 third_party/nnAudio/.gitignore diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb index 5febb0ae..04b4a392 100644 --- a/.notebook/audio_feature.ipynb +++ b/.notebook/audio_feature.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 13, + "execution_count": 94, "id": "matched-camera", "metadata": {}, "outputs": [], @@ -16,7 +16,34 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 95, + "id": "quarterly-solution", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[43 75 69 ... 7 6 3]\n", + "[43 75 69 ... 7 6 3]\n", + "[43 75 69 ... 7 6 3]\n" + ] + } + ], + "source": [ + "import scipy.io.wavfile as wav\n", + "\n", + "rate,sig = wav.read('./BAC009S0764W0124.wav')\n", + "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", + "sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", + "print(sig)\n", + "print(song)\n", + "print(sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 96, "id": "middle-salem", "metadata": {}, "outputs": [ @@ -29,7 +56,7 @@ "(83792,)\n", "int16\n", "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2142 seconds\n", + "STFT kernels created, time used = 0.2733 seconds\n", "tensor([[[[-4.0940e+03, 1.2600e+04],\n", " [ 8.5108e+03, -5.4930e+03],\n", " [-3.3631e+03, -1.7904e+03],\n", @@ -101,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 97, "id": "finished-sterling", "metadata": {}, "outputs": [ @@ -115,7 +142,7 @@ "int16\n", "True\n", "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2495 seconds\n", + "STFT kernels created, time used = 0.2001 seconds\n", "torch.Size([1, 1025, 164, 2])\n", "tensor([[[[-4.0940e+03, 1.2600e+04],\n", " [ 8.5108e+03, -5.4930e+03],\n", @@ -193,10 +220,966 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 98, "id": "running-technology", "metadata": {}, "outputs": [], + "source": [ + "import decimal\n", + "\n", + "import numpy\n", + "import math\n", + "import logging\n", + "def round_half_up(number):\n", + " return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n", + "\n", + "\n", + "def rolling_window(a, window, step=1):\n", + " # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n", + " shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n", + " strides = a.strides + (a.strides[-1],)\n", + " return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n", + "\n", + "\n", + "def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n", + " \"\"\"Frame a signal into overlapping frames.\n", + "\n", + " :param sig: the audio signal to frame.\n", + " :param frame_len: length of each frame measured in samples.\n", + " :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n", + " :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n", + " :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n", + " :returns: an array of frames. Size is NUMFRAMES by frame_len.\n", + " \"\"\"\n", + " slen = len(sig)\n", + " frame_len = int(round_half_up(frame_len))\n", + " frame_step = int(round_half_up(frame_step))\n", + " if slen <= frame_len:\n", + " numframes = 1\n", + " else:\n", + " numframes = 1 + (( slen - frame_len) // frame_step)\n", + "\n", + " # check kaldi/src/feat/feature-window.h\n", + " padsignal = sig[:(numframes-1)*frame_step+frame_len]\n", + " if wintype is 'povey':\n", + " win = numpy.empty(frame_len)\n", + " for i in range(frame_len):\n", + " win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n", + " else: # the hamming window\n", + " win = numpy.hamming(frame_len)\n", + " \n", + " if stride_trick:\n", + " frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n", + " else:\n", + " indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n", + " numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n", + " indices = numpy.array(indices, dtype=numpy.int32)\n", + " frames = padsignal[indices]\n", + " win = numpy.tile(win, (numframes, 1))\n", + " \n", + " frames = frames.astype(numpy.float32)\n", + " raw_frames = numpy.zeros(frames.shape)\n", + " for frm in range(frames.shape[0]):\n", + " raw_frames[frm,:] = frames[frm,:]\n", + " frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n", + " frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n", + " # raw_frames[frm,:] = frames[frm,:]\n", + " frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n", + "\n", + " return frames * win, raw_frames\n", + "\n", + "\n", + "def magspec(frames, NFFT):\n", + " \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", + "\n", + " :param frames: the array of frames. Each row is a frame.\n", + " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", + " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n", + " \"\"\"\n", + " if numpy.shape(frames)[1] > NFFT:\n", + " logging.warn(\n", + " 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n", + " numpy.shape(frames)[1], NFFT)\n", + " complex_spec = numpy.fft.rfft(frames, NFFT)\n", + " return numpy.absolute(complex_spec)\n", + "\n", + "\n", + "def powspec(frames, NFFT):\n", + " \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", + "\n", + " :param frames: the array of frames. Each row is a frame.\n", + " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", + " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n", + " \"\"\"\n", + " return numpy.square(magspec(frames, NFFT))\n", + "\n", + "\n", + "def do_dither(signal, dither_value=1.0):\n", + " signal += numpy.random.normal(size=signal.shape) * dither_value\n", + " return signal\n", + " \n", + "def do_remove_dc_offset(signal):\n", + " signal -= numpy.mean(signal)\n", + " return signal\n", + "\n", + "def do_preemphasis(signal, coeff=0.97):\n", + " \"\"\"perform preemphasis on the input signal.\n", + "\n", + " :param signal: The signal to filter.\n", + " :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n", + " :returns: the filtered signal.\n", + " \"\"\"\n", + " return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "ignored-retreat", + "metadata": {}, + "outputs": [], + "source": [ + "def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", + " wintype='hamming'):\n", + " highfreq= highfreq or samplerate/2\n", + " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", + " spec = magspec(frames, nfft) # nearly the same until this part\n", + " rspec = magspec(raw_frames, nfft)\n", + " return spec, rspec\n", + "\n", + "\n", + "\n", + "def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", + " wintype='hamming'):\n", + " highfreq= highfreq or samplerate/2\n", + " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", + " return raw_frames" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "federal-teacher", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.nn.functional import conv1d, conv2d, fold\n", + "import scipy # used only in CFP\n", + "\n", + "import numpy as np\n", + "from time import time\n", + "\n", + "def pad_center(data, size, axis=-1, **kwargs):\n", + "\n", + " kwargs.setdefault('mode', 'constant')\n", + "\n", + " n = data.shape[axis]\n", + "\n", + " lpad = int((size - n) // 2)\n", + "\n", + " lengths = [(0, 0)] * data.ndim\n", + " lengths[axis] = (lpad, int(size - n - lpad))\n", + "\n", + " if lpad < 0:\n", + " raise ParameterError(('Target size ({:d}) must be '\n", + " 'at least input size ({:d})').format(size, n))\n", + "\n", + " return np.pad(data, lengths, **kwargs)\n", + "\n", + "\n", + "\n", + "sz_float = 4 # size of a float\n", + "epsilon = 10e-8 # fudge factor for normalization\n", + "\n", + "def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n", + " freq_scale='linear', window='hann', verbose=True):\n", + "\n", + " if freq_bins==None: freq_bins = n_fft//2+1\n", + " if win_length==None: win_length = n_fft\n", + "\n", + " s = np.arange(0, n_fft, 1.)\n", + " wsin = np.empty((freq_bins,1,n_fft))\n", + " wcos = np.empty((freq_bins,1,n_fft))\n", + " start_freq = fmin\n", + " end_freq = fmax\n", + " bins2freq = []\n", + " binslist = []\n", + "\n", + " # num_cycles = start_freq*d/44000.\n", + " # scaling_ind = np.log(end_freq/start_freq)/k\n", + "\n", + " # Choosing window shape\n", + "\n", + " #window_mask = get_window(window, int(win_length), fftbins=True)\n", + " window_mask = np.hamming(int(win_length))\n", + " window_mask = pad_center(window_mask, n_fft)\n", + "\n", + " if freq_scale == 'linear':\n", + " if verbose==True:\n", + " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", + " f\"get a valid freq range\")\n", + " \n", + " start_bin = start_freq*n_fft/sr\n", + " scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n", + "\n", + " for k in range(freq_bins): # Only half of the bins contain useful info\n", + " # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n", + " bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n", + " binslist.append((k*scaling_ind+start_bin))\n", + " wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", + "\n", + " elif freq_scale == 'log':\n", + " if verbose==True:\n", + " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", + " f\"get a valid freq range\")\n", + " start_bin = start_freq*n_fft/sr\n", + " scaling_ind = np.log(end_freq/start_freq)/freq_bins\n", + "\n", + " for k in range(freq_bins): # Only half of the bins contain useful info\n", + " # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n", + " bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n", + " binslist.append((np.exp(k*scaling_ind)*start_bin))\n", + " wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", + "\n", + " elif freq_scale == 'no':\n", + " for k in range(freq_bins): # Only half of the bins contain useful info\n", + " bins2freq.append(k*sr/n_fft)\n", + " binslist.append(k)\n", + " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", + " else:\n", + " print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n", + " return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n", + "\n", + "\n", + "\n", + "def broadcast_dim(x):\n", + " \"\"\"\n", + " Auto broadcast input so that it can fits into a Conv1d\n", + " \"\"\"\n", + "\n", + " if x.dim() == 2:\n", + " x = x[:, None, :]\n", + " elif x.dim() == 1:\n", + " # If nn.DataParallel is used, this broadcast doesn't work\n", + " x = x[None, None, :]\n", + " elif x.dim() == 3:\n", + " pass\n", + " else:\n", + " raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n", + " return x\n", + "\n", + "\n", + "\n", + "### --------------------------- Spectrogram Classes ---------------------------###\n", + "class STFT(torch.nn.Module):\n", + "\n", + " def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n", + " freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n", + " fmin=50, fmax=6000, sr=22050, trainable=False,\n", + " output_format=\"Complex\", verbose=True):\n", + "\n", + " super().__init__()\n", + "\n", + " # Trying to make the default setting same as librosa\n", + " if win_length==None: win_length = n_fft\n", + " if hop_length==None: hop_length = int(win_length // 4)\n", + "\n", + " self.output_format = output_format\n", + " self.trainable = trainable\n", + " self.stride = hop_length\n", + " self.center = center\n", + " self.pad_mode = pad_mode\n", + " self.n_fft = n_fft\n", + " self.freq_bins = freq_bins\n", + " self.trainable = trainable\n", + " self.pad_amount = self.n_fft // 2\n", + " self.window = window\n", + " self.win_length = win_length\n", + " self.iSTFT = iSTFT\n", + " self.trainable = trainable\n", + " start = time()\n", + "\n", + "\n", + "\n", + " # Create filter windows for stft\n", + " kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n", + " win_length=win_length,\n", + " freq_bins=freq_bins,\n", + " window=window,\n", + " freq_scale=freq_scale,\n", + " fmin=fmin,\n", + " fmax=fmax,\n", + " sr=sr,\n", + " verbose=verbose)\n", + "\n", + "\n", + " kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n", + " kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n", + " \n", + " # In this way, the inverse kernel and the forward kernel do not share the same memory...\n", + " kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n", + " kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n", + " \n", + " if iSTFT:\n", + " self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n", + " self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n", + "\n", + " # Applying window functions to the Fourier kernels\n", + " if window:\n", + " window_mask = torch.tensor(window_mask)\n", + " wsin = kernel_sin * window_mask\n", + " wcos = kernel_cos * window_mask\n", + " else:\n", + " wsin = kernel_sin\n", + " wcos = kernel_cos\n", + " \n", + " if self.trainable==False:\n", + " self.register_buffer('wsin', wsin)\n", + " self.register_buffer('wcos', wcos) \n", + " \n", + " if self.trainable==True:\n", + " wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n", + " wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n", + " self.register_parameter('wsin', wsin)\n", + " self.register_parameter('wcos', wcos) \n", + " \n", + " # Prepare the shape of window mask so that it can be used later in inverse\n", + " # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n", + " \n", + " if verbose==True:\n", + " print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n", + " else:\n", + " pass\n", + "\n", + " def forward(self, x, output_format=None):\n", + " \"\"\"\n", + " Convert a batch of waveforms to spectrograms.\n", + " \n", + " Parameters\n", + " ----------\n", + " x : torch tensor\n", + " Input signal should be in either of the following shapes.\\n\n", + " 1. ``(len_audio)``\\n\n", + " 2. ``(num_audio, len_audio)``\\n\n", + " 3. ``(num_audio, 1, len_audio)``\n", + " It will be automatically broadcast to the right shape\n", + " \n", + " output_format : str\n", + " Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n", + " Default value is ``Complex``. \n", + " \n", + " \"\"\"\n", + " output_format = output_format or self.output_format\n", + " self.num_samples = x.shape[-1]\n", + " \n", + " x = broadcast_dim(x)\n", + " if self.center:\n", + " if self.pad_mode == 'constant':\n", + " padding = nn.ConstantPad1d(self.pad_amount, 0)\n", + "\n", + " elif self.pad_mode == 'reflect':\n", + " if self.num_samples < self.pad_amount:\n", + " raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n", + " padding = nn.ReflectionPad1d(self.pad_amount)\n", + "\n", + " x = padding(x)\n", + " spec_imag = conv1d(x, self.wsin, stride=self.stride)\n", + " spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n", + "\n", + " # remove redundant parts\n", + " spec_real = spec_real[:, :self.freq_bins, :]\n", + " spec_imag = spec_imag[:, :self.freq_bins, :]\n", + "\n", + " if output_format=='Magnitude':\n", + " spec = spec_real.pow(2) + spec_imag.pow(2)\n", + " if self.trainable==True:\n", + " return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n", + " else:\n", + " return torch.sqrt(spec)\n", + "\n", + " elif output_format=='Complex':\n", + " return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n", + "\n", + " elif output_format=='Phase':\n", + " return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n", + "\n", + " def inverse(self, X, onesided=True, length=None, refresh_win=True):\n", + " \"\"\"\n", + " This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n", + " which is to convert spectrograms back to waveforms. \n", + " It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n", + " please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n", + " \n", + " Parameters\n", + " ----------\n", + " onesided : bool\n", + " If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n", + " else use ``onesided=False``\n", + "\n", + " length : int\n", + " To make sure the inverse STFT has the same output length of the original waveform, please\n", + " set `length` as your intended waveform length. By default, ``length=None``,\n", + " which will remove ``n_fft//2`` samples from the start and the end of the output.\n", + " \n", + " refresh_win : bool\n", + " Recalculating the window sum square. If you have an input with fixed number of timesteps,\n", + " you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n", + " \n", + " \n", + " \"\"\"\n", + " if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n", + " raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n", + " \n", + " assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n", + " \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n", + " \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n", + " if onesided:\n", + " X = extend_fbins(X) # extend freq\n", + "\n", + " \n", + " X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n", + "\n", + " # broadcast dimensions to support 2D convolution\n", + " X_real_bc = X_real.unsqueeze(1)\n", + " X_imag_bc = X_imag.unsqueeze(1)\n", + " a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n", + " b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n", + " \n", + " # compute real and imag part. signal lies in the real part\n", + " real = a1 - b2\n", + " real = real.squeeze(-2)*self.window_mask\n", + "\n", + " # Normalize the amplitude with n_fft\n", + " real /= (self.n_fft)\n", + "\n", + " # Overlap and Add algorithm to connect all the frames\n", + " real = overlap_add(real, self.stride)\n", + " \n", + " # Prepare the window sumsqure for division\n", + " # Only need to create this window once to save time\n", + " # Unless the input spectrograms have different time steps\n", + " if hasattr(self, 'w_sum')==False or refresh_win==True:\n", + " self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n", + " self.nonzero_indices = (self.w_sum>1e-10) \n", + " else:\n", + " pass\n", + " real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n", + " # Remove padding\n", + " if length is None: \n", + " if self.center:\n", + " real = real[:, self.pad_amount:-self.pad_amount]\n", + "\n", + " else:\n", + " if self.center:\n", + " real = real[:, self.pad_amount:self.pad_amount + length] \n", + " else:\n", + " real = real[:, :length] \n", + " \n", + " return real\n", + " \n", + " def extra_repr(self) -> str:\n", + " return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n", + " self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n", + " ) " + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": "unusual-baker", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "(83792,)\n", + "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", + "STFT kernels created, time used = 0.0153 seconds\n", + "torch.Size([521, 257])\n", + "(522, 257)\n", + "[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n", + " 2.32206573e+01 1.90274487e+01]\n", + " [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n", + " 1.80534729e+02 1.84179596e+02]\n", + " [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n", + " 1.21321175e+02 1.08345497e+02]\n", + " ...\n", + " [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n", + " 1.25670158e+02 1.20691467e+02]\n", + " [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n", + " 9.01831894e+01 9.84055099e+01]\n", + " [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n", + " 8.94473553e+00 9.61348629e+00]]\n", + "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]\n", + " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", + " 6.36197377e+01 4.40000000e+01]]\n" + ] + } + ], + "source": [ + "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", + "print(sr)\n", + "print(wav.shape)\n", + "\n", + "x = wav\n", + "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", + "\n", + "spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n", + " window='', freq_scale='linear', center=False, pad_mode='constant',\n", + " fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n", + "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", + "wav_spec = wav_spec[0].T\n", + "print(wav_spec.shape)\n", + "\n", + "\n", + "spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + " wintype='hamming')\n", + "print(spec.shape)\n", + "\n", + "print(wav_spec.numpy())\n", + "print(rspec)\n", + "# print(spec)\n", + "\n", + "# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n", + "# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + "# dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + "# wintype='hamming')\n", + "# print(rspec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "white-istanbul", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "modern-rescue", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n", + " 0.75 0.41317591 0.11697778 0. ]\n" + ] + }, + { + "data": { + "text/plain": [ + "array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n", + " 0.9045085, 0.6545085, 0.3454915, 0.0954915])" + ] + }, + "execution_count": 129, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(np.hanning(10))\n", + "from scipy.signal import get_window\n", + "get_window('hann', 10, fftbins=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "professional-journalism", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 153, + "id": "involved-motion", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(522, 400)\n", + "[[ 43. 75. 69. ... 46. 46. 45.]\n", + " [ 210. 215. 216. ... -86. -89. -91.]\n", + " [ 128. 128. 128. ... -154. -151. -151.]\n", + " ...\n", + " [ -60. -61. -61. ... 112. 109. 110.]\n", + " [ 20. 22. 24. ... 91. 87. 87.]\n", + " [ 111. 107. 108. ... -6. -4. -8.]]\n", + "torch.Size([1, 1, 83792])\n", + "torch.Size([400, 1, 512])\n", + "torch.Size([1, 400, 521])\n", + "conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n", + " [ 210., 215., 216., ..., -86., -89., -91.],\n", + " [ 128., 128., 128., ..., -154., -151., -151.],\n", + " ...,\n", + " [-143., -141., -142., ..., 96., 101., 101.],\n", + " [ -60., -61., -61., ..., 112., 109., 110.],\n", + " [ 20., 22., 24., ..., 91., 87., 87.]])\n", + "xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", + " 2.4064583e+01 2.2000000e+01]\n", + " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", + " 1.1877571e+02 1.6200000e+02]\n", + " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", + " 9.5781029e+01 1.4200000e+02]\n", + " ...\n", + " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", + " 9.1511757e+01 1.1500000e+02]\n", + " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", + " 7.8405365e+01 9.0000000e+01]\n", + " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", + " 5.1310158e+01 3.5000000e+01]]\n", + "torch.Size([521, 257])\n", + "yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", + " 9.15117270e+01 1.15000000e+02]\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]]\n", + "yy (522, 257)\n", + "[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", + " 2.4064583e+01 2.2000000e+01]\n", + " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", + " 1.1877571e+02 1.6200000e+02]\n", + " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", + " 9.5781029e+01 1.4200000e+02]\n", + " ...\n", + " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", + " 9.1511757e+01 1.1500000e+02]\n", + " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", + " 7.8405365e+01 9.0000000e+01]\n", + " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", + " 5.1310158e+01 3.5000000e+01]]\n", + "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", + " 9.15117270e+01 1.15000000e+02]\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]]\n", + "False\n" + ] + } + ], + "source": [ + "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + " wintype='hamming')\n", + "print(f.shape)\n", + "print(f)\n", + "\n", + "n_fft=512\n", + "freq_bins = n_fft//2+1\n", + "s = np.arange(0, n_fft, 1.)\n", + "wsin = np.empty((freq_bins,1,n_fft))\n", + "wcos = np.empty((freq_bins,1,n_fft))\n", + "for k in range(freq_bins): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", + "\n", + "\n", + "wsin = np.empty((n_fft,1,n_fft))\n", + "wcos = np.empty((n_fft,1,n_fft))\n", + "for k in range(n_fft): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n", + " wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n", + " \n", + " \n", + "wsin = np.empty((400,1,n_fft))\n", + "wcos = np.empty((400,1,n_fft))\n", + "for k in range(400): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.eye(400, n_fft)[k]\n", + " wcos[k,0,:] = np.eye(400, n_fft)[k]\n", + " \n", + "\n", + " \n", + "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", + "x = x[None, None, :]\n", + "print(x.size())\n", + "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", + "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", + "print(kernel_sin.size())\n", + "\n", + "from torch.nn.functional import conv1d, conv2d, fold\n", + "spec_imag = conv1d(x, kernel_sin, stride=160)\n", + "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", + "\n", + "print(spec_imag.size())\n", + "print(\"conv frame\", spec_imag[0].T)\n", + "# print(spec_imag[0].T[:, :400])\n", + "\n", + "# remove redundant parts\n", + "# spec_real = spec_real[:, :freq_bins, :]\n", + "# spec_imag = spec_imag[:, :freq_bins, :]\n", + "# spec = spec_real.pow(2) + spec_imag.pow(2)\n", + "# spec = torch.sqrt(spec)\n", + "# print(spec)\n", + "\n", + "\n", + "\n", + "s = np.arange(0, 512, 1.)\n", + "# s = s[::-1]\n", + "wsin = np.empty((freq_bins, 400))\n", + "wcos = np.empty((freq_bins, 400))\n", + "for k in range(freq_bins): # Only half of the bins contain useful info\n", + " wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", + " wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", + "\n", + "spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n", + "spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n", + "\n", + "\n", + "# remove redundant parts\n", + "spec = spec_real.pow(2) + spec_imag.pow(2)\n", + "spec = torch.sqrt(spec)\n", + "\n", + "print('xx', spec.numpy())\n", + "print(spec.size())\n", + "print('yy', rspec[:521, :])\n", + "print('yy', rspec.shape)\n", + "\n", + "\n", + "x = spec.numpy()\n", + "y = rspec[:-1, :]\n", + "print(x)\n", + "print(y)\n", + "print(np.allclose(x, y))" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "id": "mathematical-traffic", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([257, 1, 400])\n", + "tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n", + " 3.8693e+04, 3.1020e+04],\n", + " [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n", + " 1.3598e+04, 1.5920e+04],\n", + " [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n", + " 6.7707e+03, 4.3020e+03],\n", + " ...,\n", + " [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n", + " 6.1071e+01, 5.3685e+01],\n", + " [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n", + " 5.1310e+01, 6.3620e+01],\n", + " [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n", + " 3.5000e+01, 4.4000e+01]]])\n", + "[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n", + " 2.4064537e+01 2.2000000e+01]\n", + " [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n", + " 1.1877571e+02 1.6200000e+02]\n", + " [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n", + " 9.5781006e+01 1.4200000e+02]\n", + " ...\n", + " [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n", + " 7.8405350e+01 9.0000000e+01]\n", + " [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n", + " 5.1310150e+01 3.5000000e+01]\n", + " [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n", + " 6.3619797e+01 4.4000000e+01]]\n", + "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]\n", + " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", + " 6.36197377e+01 4.40000000e+01]]\n", + "False\n" + ] + } + ], + "source": [ + "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + " wintype='hamming')\n", + "\n", + "n_fft=512\n", + "freq_bins = n_fft//2+1\n", + "s = np.arange(0, n_fft, 1.)\n", + "wsin = np.empty((freq_bins,1,400))\n", + "wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n", + "for k in range(freq_bins): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", + " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", + "\n", + " \n", + "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", + "x = x[None, None, :] #[B, C, T]\n", + "\n", + "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", + "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", + "print(kernel_sin.size())\n", + "\n", + "from torch.nn.functional import conv1d, conv2d, fold\n", + "spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n", + "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", + "\n", + "# remove redundant parts\n", + "spec = spec_real.pow(2) + spec_imag.pow(2)\n", + "spec = torch.sqrt(spec)\n", + "print(spec)\n", + "\n", + "x = spec[0].T.numpy()\n", + "y = rspec[:, :]\n", + "print(x)\n", + "print(y)\n", + "print(np.allclose(x, y))" + ] + }, + { + "cell_type": "code", + "execution_count": 162, + "id": "olive-nicaragua", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n", + " \"\"\"Entry point for launching an IPython kernel.\n" + ] + }, + { + "data": { + "text/plain": [ + "27241" + ] + }, + "execution_count": 162, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.argmax(np.abs(x -y) / np.abs(y))" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "ultimate-assault", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.0" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y[np.unravel_index(27241, y.shape)]" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "institutional-stock", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.2412265e-10" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x[np.unravel_index(27241, y.shape)]" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "id": "integrated-courage", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.allclose(y, x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "different-operation", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/third_party/nnAudio/.gitignore b/third_party/nnAudio/.gitignore new file mode 100644 index 00000000..c09b8573 --- /dev/null +++ b/third_party/nnAudio/.gitignore @@ -0,0 +1,3 @@ +build +dist +*.egg-info/ diff --git a/third_party/nnAudio/nnAudio/Spectrogram.py b/third_party/nnAudio/nnAudio/Spectrogram.py index c92046ee..b5d79845 100755 --- a/third_party/nnAudio/nnAudio/Spectrogram.py +++ b/third_party/nnAudio/nnAudio/Spectrogram.py @@ -165,9 +165,13 @@ class STFT(torch.nn.Module): # self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable) # Applying window functions to the Fourier kernels - window_mask = torch.tensor(window_mask) - wsin = kernel_sin * window_mask - wcos = kernel_cos * window_mask + if window: + window_mask = torch.tensor(window_mask) + wsin = kernel_sin * window_mask + wcos = kernel_cos * window_mask + else: + wsin = kernel_sin + wcos = kernel_cos if self.trainable==False: self.register_buffer('wsin', wsin) @@ -179,7 +183,6 @@ class STFT(torch.nn.Module): self.register_parameter('wsin', wsin) self.register_parameter('wcos', wcos) - # Prepare the shape of window mask so that it can be used later in inverse self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1)) diff --git a/third_party/nnAudio/setup.py b/third_party/nnAudio/setup.py index 9b2f3688..cb69481a 100755 --- a/third_party/nnAudio/setup.py +++ b/third_party/nnAudio/setup.py @@ -2,29 +2,26 @@ import setuptools import codecs import os.path -with open("README.md", "r") as fh: - long_description = fh.read() - def read(rel_path): here = os.path.abspath(os.path.dirname(__file__)) with codecs.open(os.path.join(here, rel_path), 'r') as fp: - return fp.read() - + return fp.read() + def get_version(rel_path): for line in read(rel_path).splitlines(): if line.startswith('__version__'): delim = '"' if '"' in line else "'" return line.split(delim)[1] else: - raise RuntimeError("Unable to find version string.") - + raise RuntimeError("Unable to find version string.") + setuptools.setup( name="nnAudio", # Replace with your own username version=get_version("nnAudio/__init__.py"), author="KinWaiCheuk", author_email="u3500684@connect.hku.hk", description="A fast GPU audio processing toolbox with 1D convolutional neural network", - long_description=long_description, + long_description='', long_description_content_type="text/markdown", url="https://github.com/KinWaiCheuk/nnAudio", packages=setuptools.find_packages(),