{
"cells": [
{
"cell_type": "code",
"execution_count": 94,
"id": "matched-camera",
"metadata": {},
"outputs": [],
"source": [
"from nnAudio import Spectrogram\n",
"from scipy.io import wavfile\n",
"import torch\n",
"import soundfile as sf\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "quarterly-solution",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n"
]
}
],
"source": [
"import scipy.io.wavfile as wav\n",
"\n",
"rate,sig = wav.read('./BAC009S0764W0124.wav')\n",
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sig)\n",
"print(song)\n",
"print(sample)"
]
},
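{
"cell_type": "markdown",
"id": "editor-note-loaders",
"metadata": {},
"source": [
"All three readers return the identical int16 PCM samples: `wav.read` and `wavfile.read` are both `scipy.io.wavfile.read` under different aliases, and `soundfile.read` matches them once `dtype='int16'` is passed (its default is normalized float64 in [-1, 1])."
]
},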
{
"cell_type": "code",
"execution_count": 96,
"id": "middle-salem",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2733 seconds\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n"
]
}
],
"source": [
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"print(sr)\n",
"print(song)\n",
"print(song.shape)\n",
"print(song.dtype)\n",
"x = song\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(spec)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "finished-sterling",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"True\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2001 seconds\n",
"torch.Size([1, 1025, 164, 2])\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n",
"True\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav)\n",
"print(wav.shape)\n",
"print(wav.dtype)\n",
"print(np.allclose(wav, song))\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(wav_spec.shape)\n",
"print(wav_spec)\n",
"print(np.allclose(wav_spec, spec))"
]
},
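{
"cell_type": "markdown",
"id": "editor-note-complex-shape",
"metadata": {},
"source": [
"Both loaders feed the same waveform through nnAudio, so the two spectrograms agree (`allclose` prints `True`). With `output_format='Complex'` the result has shape `(batch, freq_bins, n_frames, 2)` = `(1, 1025, 164, 2)` here: `freq_bins = n_fft//2 + 1 = 1025`, and the last axis stacks the real and imaginary parts."
]
},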
{
"cell_type": "code",
"execution_count": 98,
"id": "running-technology",
"metadata": {},
"outputs": [],
"source": [
"import decimal\n",
"\n",
"import numpy\n",
"import math\n",
"import logging\n",
"def round_half_up(number):\n",
" return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n",
"\n",
"\n",
"def rolling_window(a, window, step=1):\n",
" # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n",
" shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n",
" strides = a.strides + (a.strides[-1],)\n",
" return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n",
"\n",
"\n",
"def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n",
" \"\"\"Frame a signal into overlapping frames.\n",
"\n",
" :param sig: the audio signal to frame.\n",
" :param frame_len: length of each frame measured in samples.\n",
" :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n",
" :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n",
" :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n",
" :returns: an array of frames. Size is NUMFRAMES by frame_len.\n",
" \"\"\"\n",
" slen = len(sig)\n",
" frame_len = int(round_half_up(frame_len))\n",
" frame_step = int(round_half_up(frame_step))\n",
" if slen <= frame_len:\n",
" numframes = 1\n",
" else:\n",
" numframes = 1 + (( slen - frame_len) // frame_step)\n",
"\n",
" # check kaldi/src/feat/feature-window.h\n",
" padsignal = sig[:(numframes-1)*frame_step+frame_len]\n",
" if wintype is 'povey':\n",
" win = numpy.empty(frame_len)\n",
" for i in range(frame_len):\n",
" win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n",
" else: # the hamming window\n",
" win = numpy.hamming(frame_len)\n",
" \n",
" if stride_trick:\n",
" frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n",
" else:\n",
" indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n",
" numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n",
" indices = numpy.array(indices, dtype=numpy.int32)\n",
" frames = padsignal[indices]\n",
" win = numpy.tile(win, (numframes, 1))\n",
" \n",
" frames = frames.astype(numpy.float32)\n",
" raw_frames = numpy.zeros(frames.shape)\n",
" for frm in range(frames.shape[0]):\n",
" raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n",
" frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n",
" # raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n",
"\n",
" return frames * win, raw_frames\n",
"\n",
"\n",
"def magspec(frames, NFFT):\n",
" \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n",
" \"\"\"\n",
" if numpy.shape(frames)[1] > NFFT:\n",
" logging.warn(\n",
" 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n",
" numpy.shape(frames)[1], NFFT)\n",
" complex_spec = numpy.fft.rfft(frames, NFFT)\n",
" return numpy.absolute(complex_spec)\n",
"\n",
"\n",
"def powspec(frames, NFFT):\n",
" \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n",
" \"\"\"\n",
" return numpy.square(magspec(frames, NFFT))\n",
"\n",
"\n",
"def do_dither(signal, dither_value=1.0):\n",
" signal += numpy.random.normal(size=signal.shape) * dither_value\n",
" return signal\n",
" \n",
"def do_remove_dc_offset(signal):\n",
" signal -= numpy.mean(signal)\n",
" return signal\n",
"\n",
"def do_preemphasis(signal, coeff=0.97):\n",
" \"\"\"perform preemphasis on the input signal.\n",
"\n",
" :param signal: The signal to filter.\n",
" :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n",
" :returns: the filtered signal.\n",
" \"\"\"\n",
" return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])"
]
},
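{
"cell_type": "markdown",
"id": "editor-note-framesig",
"metadata": {},
"source": [
"A quick sanity check of the framing code above (an added sketch, not part of the original experiment; the tone is synthetic): a 16 kHz signal of 16000 samples with 25 ms frames and a 10 ms step should give `1 + (16000 - 400)//160 = 98` frames of 400 samples each."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-check-framesig",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check for framesig on a synthetic 440 Hz tone (1 s at 16 kHz).\n",
"# dither=0.0 keeps the check deterministic.\n",
"_sig = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000).astype(np.float32)\n",
"_win_frames, _raw_frames = framesig(_sig, 0.025 * 16000, 0.01 * 16000,\n",
"                                    dither=0.0, preemph=0.97, wintype='hamming')\n",
"print(_win_frames.shape, _raw_frames.shape)  # both (98, 400)"
]
},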
{
"cell_type": "code",
"execution_count": 99,
"id": "ignored-retreat",
"metadata": {},
"outputs": [],
"source": [
"def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" spec = magspec(frames, nfft) # nearly the same until this part\n",
" rspec = magspec(raw_frames, nfft)\n",
" return spec, rspec\n",
"\n",
"\n",
"\n",
"def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" return raw_frames"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "federal-teacher",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"import scipy # used only in CFP\n",
"\n",
"import numpy as np\n",
"from time import time\n",
"\n",
"def pad_center(data, size, axis=-1, **kwargs):\n",
"\n",
" kwargs.setdefault('mode', 'constant')\n",
"\n",
" n = data.shape[axis]\n",
"\n",
" lpad = int((size - n) // 2)\n",
"\n",
" lengths = [(0, 0)] * data.ndim\n",
" lengths[axis] = (lpad, int(size - n - lpad))\n",
"\n",
" if lpad < 0:\n",
" raise ParameterError(('Target size ({:d}) must be '\n",
" 'at least input size ({:d})').format(size, n))\n",
"\n",
" return np.pad(data, lengths, **kwargs)\n",
"\n",
"\n",
"\n",
"sz_float = 4 # size of a float\n",
"epsilon = 10e-8 # fudge factor for normalization\n",
"\n",
"def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n",
" freq_scale='linear', window='hann', verbose=True):\n",
"\n",
" if freq_bins==None: freq_bins = n_fft//2+1\n",
" if win_length==None: win_length = n_fft\n",
"\n",
" s = np.arange(0, n_fft, 1.)\n",
" wsin = np.empty((freq_bins,1,n_fft))\n",
" wcos = np.empty((freq_bins,1,n_fft))\n",
" start_freq = fmin\n",
" end_freq = fmax\n",
" bins2freq = []\n",
" binslist = []\n",
"\n",
" # num_cycles = start_freq*d/44000.\n",
" # scaling_ind = np.log(end_freq/start_freq)/k\n",
"\n",
" # Choosing window shape\n",
"\n",
" #window_mask = get_window(window, int(win_length), fftbins=True)\n",
" window_mask = np.hamming(int(win_length))\n",
" window_mask = pad_center(window_mask, n_fft)\n",
"\n",
" if freq_scale == 'linear':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" \n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n",
" bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n",
" binslist.append((k*scaling_ind+start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'log':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = np.log(end_freq/start_freq)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n",
" bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n",
" binslist.append((np.exp(k*scaling_ind)*start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'no':\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" bins2freq.append(k*sr/n_fft)\n",
" binslist.append(k)\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
" else:\n",
" print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n",
" return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n",
"\n",
"\n",
"\n",
"def broadcast_dim(x):\n",
" \"\"\"\n",
" Auto broadcast input so that it can fits into a Conv1d\n",
" \"\"\"\n",
"\n",
" if x.dim() == 2:\n",
" x = x[:, None, :]\n",
" elif x.dim() == 1:\n",
" # If nn.DataParallel is used, this broadcast doesn't work\n",
" x = x[None, None, :]\n",
" elif x.dim() == 3:\n",
" pass\n",
" else:\n",
" raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n",
" return x\n",
"\n",
"\n",
"\n",
"### --------------------------- Spectrogram Classes ---------------------------###\n",
"class STFT(torch.nn.Module):\n",
"\n",
" def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n",
" freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n",
" fmin=50, fmax=6000, sr=22050, trainable=False,\n",
" output_format=\"Complex\", verbose=True):\n",
"\n",
" super().__init__()\n",
"\n",
" # Trying to make the default setting same as librosa\n",
" if win_length==None: win_length = n_fft\n",
" if hop_length==None: hop_length = int(win_length // 4)\n",
"\n",
" self.output_format = output_format\n",
" self.trainable = trainable\n",
" self.stride = hop_length\n",
" self.center = center\n",
" self.pad_mode = pad_mode\n",
" self.n_fft = n_fft\n",
" self.freq_bins = freq_bins\n",
" self.trainable = trainable\n",
" self.pad_amount = self.n_fft // 2\n",
" self.window = window\n",
" self.win_length = win_length\n",
" self.iSTFT = iSTFT\n",
" self.trainable = trainable\n",
" start = time()\n",
"\n",
"\n",
"\n",
" # Create filter windows for stft\n",
" kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n",
" win_length=win_length,\n",
" freq_bins=freq_bins,\n",
" window=window,\n",
" freq_scale=freq_scale,\n",
" fmin=fmin,\n",
" fmax=fmax,\n",
" sr=sr,\n",
" verbose=verbose)\n",
"\n",
"\n",
" kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n",
" kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n",
" \n",
" # In this way, the inverse kernel and the forward kernel do not share the same memory...\n",
" kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n",
" kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n",
" \n",
" if iSTFT:\n",
" self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n",
" self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n",
"\n",
" # Applying window functions to the Fourier kernels\n",
" if window:\n",
" window_mask = torch.tensor(window_mask)\n",
" wsin = kernel_sin * window_mask\n",
" wcos = kernel_cos * window_mask\n",
" else:\n",
" wsin = kernel_sin\n",
" wcos = kernel_cos\n",
" \n",
" if self.trainable==False:\n",
" self.register_buffer('wsin', wsin)\n",
" self.register_buffer('wcos', wcos) \n",
" \n",
" if self.trainable==True:\n",
" wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n",
" wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n",
" self.register_parameter('wsin', wsin)\n",
" self.register_parameter('wcos', wcos) \n",
" \n",
" # Prepare the shape of window mask so that it can be used later in inverse\n",
" # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n",
" \n",
" if verbose==True:\n",
" print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n",
" else:\n",
" pass\n",
"\n",
" def forward(self, x, output_format=None):\n",
" \"\"\"\n",
" Convert a batch of waveforms to spectrograms.\n",
" \n",
" Parameters\n",
" ----------\n",
" x : torch tensor\n",
" Input signal should be in either of the following shapes.\\n\n",
" 1. ``(len_audio)``\\n\n",
" 2. ``(num_audio, len_audio)``\\n\n",
" 3. ``(num_audio, 1, len_audio)``\n",
" It will be automatically broadcast to the right shape\n",
" \n",
" output_format : str\n",
" Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n",
" Default value is ``Complex``. \n",
" \n",
" \"\"\"\n",
" output_format = output_format or self.output_format\n",
" self.num_samples = x.shape[-1]\n",
" \n",
" x = broadcast_dim(x)\n",
" if self.center:\n",
" if self.pad_mode == 'constant':\n",
" padding = nn.ConstantPad1d(self.pad_amount, 0)\n",
"\n",
" elif self.pad_mode == 'reflect':\n",
" if self.num_samples < self.pad_amount:\n",
" raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n",
" padding = nn.ReflectionPad1d(self.pad_amount)\n",
"\n",
" x = padding(x)\n",
" spec_imag = conv1d(x, self.wsin, stride=self.stride)\n",
" spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n",
"\n",
" # remove redundant parts\n",
" spec_real = spec_real[:, :self.freq_bins, :]\n",
" spec_imag = spec_imag[:, :self.freq_bins, :]\n",
"\n",
" if output_format=='Magnitude':\n",
" spec = spec_real.pow(2) + spec_imag.pow(2)\n",
" if self.trainable==True:\n",
" return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n",
" else:\n",
" return torch.sqrt(spec)\n",
"\n",
" elif output_format=='Complex':\n",
" return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n",
"\n",
" elif output_format=='Phase':\n",
" return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n",
"\n",
" def inverse(self, X, onesided=True, length=None, refresh_win=True):\n",
" \"\"\"\n",
" This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n",
" which is to convert spectrograms back to waveforms. \n",
" It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n",
" please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n",
" \n",
" Parameters\n",
" ----------\n",
" onesided : bool\n",
" If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n",
" else use ``onesided=False``\n",
"\n",
" length : int\n",
" To make sure the inverse STFT has the same output length of the original waveform, please\n",
" set `length` as your intended waveform length. By default, ``length=None``,\n",
" which will remove ``n_fft//2`` samples from the start and the end of the output.\n",
" \n",
" refresh_win : bool\n",
" Recalculating the window sum square. If you have an input with fixed number of timesteps,\n",
" you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n",
" \n",
" \n",
" \"\"\"\n",
" if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n",
" raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n",
" \n",
" assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n",
" \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n",
" \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n",
" if onesided:\n",
" X = extend_fbins(X) # extend freq\n",
"\n",
" \n",
" X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n",
"\n",
" # broadcast dimensions to support 2D convolution\n",
" X_real_bc = X_real.unsqueeze(1)\n",
" X_imag_bc = X_imag.unsqueeze(1)\n",
" a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n",
" b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n",
" \n",
" # compute real and imag part. signal lies in the real part\n",
" real = a1 - b2\n",
" real = real.squeeze(-2)*self.window_mask\n",
"\n",
" # Normalize the amplitude with n_fft\n",
" real /= (self.n_fft)\n",
"\n",
" # Overlap and Add algorithm to connect all the frames\n",
" real = overlap_add(real, self.stride)\n",
" \n",
" # Prepare the window sumsqure for division\n",
" # Only need to create this window once to save time\n",
" # Unless the input spectrograms have different time steps\n",
" if hasattr(self, 'w_sum')==False or refresh_win==True:\n",
" self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n",
" self.nonzero_indices = (self.w_sum>1e-10) \n",
" else:\n",
" pass\n",
" real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n",
" # Remove padding\n",
" if length is None: \n",
" if self.center:\n",
" real = real[:, self.pad_amount:-self.pad_amount]\n",
"\n",
" else:\n",
" if self.center:\n",
" real = real[:, self.pad_amount:self.pad_amount + length] \n",
" else:\n",
" real = real[:, :length] \n",
" \n",
" return real\n",
" \n",
" def extra_repr(self) -> str:\n",
" return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n",
" self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n",
" ) "
]
},
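{
"cell_type": "markdown",
"id": "editor-note-conv-stft",
"metadata": {},
"source": [
"The class above implements the STFT as a pair of `conv1d` ops: each output channel correlates the signal with one cosine/sine basis vector, so `sqrt(real**2 + imag**2)` is the DFT magnitude. A hedged sanity check (added here, not part of the original run): with `freq_scale='no'`, no window and no centering, the result should match `np.fft.rfft` on manually sliced frames up to float32 rounding."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-check-conv-stft",
"metadata": {},
"outputs": [],
"source": [
"# Compare the conv1d STFT against numpy's rfft on a random signal.\n",
"# window='' disables windowing; center=False avoids padding, so frame i\n",
"# covers samples [i*hop, i*hop + n_fft).\n",
"_x = np.random.randn(4096).astype(np.float32)\n",
"_layer = STFT(n_fft=512, hop_length=160, window='', freq_scale='no',\n",
"              center=False, sr=16000, output_format='Magnitude', verbose=False)\n",
"_mag = _layer(torch.from_numpy(_x))[0].numpy()  # (257, n_frames)\n",
"_frames = np.stack([_x[i*160:i*160+512] for i in range(_mag.shape[1])])\n",
"_ref = np.abs(np.fft.rfft(_frames, 512)).T  # (257, n_frames)\n",
"print(np.allclose(_mag, _ref, rtol=1e-3, atol=1e-2))  # expect True at float32 precision"
]
},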
{
"cell_type": "code",
"execution_count": 128,
"id": "unusual-baker",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"(83792,)\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.0153 seconds\n",
"torch.Size([521, 257])\n",
"(522, 257)\n",
"[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n",
" 2.32206573e+01 1.90274487e+01]\n",
" [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n",
" 1.80534729e+02 1.84179596e+02]\n",
" [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n",
" 1.21321175e+02 1.08345497e+02]\n",
" ...\n",
" [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n",
" 1.25670158e+02 1.20691467e+02]\n",
" [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n",
" 9.01831894e+01 9.84055099e+01]\n",
" [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n",
" 8.94473553e+00 9.61348629e+00]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav.shape)\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n",
" window='', freq_scale='linear', center=False, pad_mode='constant',\n",
" fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"wav_spec = wav_spec[0].T\n",
"print(wav_spec.shape)\n",
"\n",
"\n",
"spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(spec.shape)\n",
"\n",
"print(wav_spec.numpy())\n",
"print(rspec)\n",
"# print(spec)\n",
"\n",
"# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n",
"# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
"# dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
"# wintype='hamming')\n",
"# print(rspec)"
]
},
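{
"cell_type": "markdown",
"id": "editor-note-window-mismatch",
"metadata": {},
"source": [
"The two magnitude spectra above are close but not identical: the conv-based STFT correlates each frame with full 512-sample Fourier kernels, while the Kaldi-style `fbank` takes 400-sample frames (25 ms at 16 kHz) and zero-pads them to `nfft=512` inside `rfft`. The cells below reconcile the two by truncating the Fourier kernels to their first 400 samples, which is equivalent to zero-padding 400-sample frames."
]
},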
{
"cell_type": "code",
"execution_count": null,
"id": "white-istanbul",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 129,
"id": "modern-rescue",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n",
" 0.75 0.41317591 0.11697778 0. ]\n"
]
},
{
"data": {
"text/plain": [
"array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n",
" 0.9045085, 0.6545085, 0.3454915, 0.0954915])"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(np.hanning(10))\n",
"from scipy.signal import get_window\n",
"get_window('hann', 10, fftbins=True)"
]
},
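{
"cell_type": "markdown",
"id": "editor-note-hann",
"metadata": {},
"source": [
"The two windows differ because `np.hanning(10)` is symmetric (zero at both endpoints) while `get_window('hann', 10, fftbins=True)` is the periodic (DFT-even) variant used for spectral analysis. A periodic window of length N equals a symmetric window of length N+1 with the last sample dropped:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-check-hann",
"metadata": {},
"outputs": [],
"source": [
"# Periodic hann of length N == symmetric hann of length N+1 minus its last sample.\n",
"print(np.allclose(get_window('hann', 10, fftbins=True), np.hanning(11)[:-1]))"
]
},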
{
"cell_type": "code",
"execution_count": null,
"id": "professional-journalism",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 153,
"id": "involved-motion",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(522, 400)\n",
"[[ 43. 75. 69. ... 46. 46. 45.]\n",
" [ 210. 215. 216. ... -86. -89. -91.]\n",
" [ 128. 128. 128. ... -154. -151. -151.]\n",
" ...\n",
" [ -60. -61. -61. ... 112. 109. 110.]\n",
" [ 20. 22. 24. ... 91. 87. 87.]\n",
" [ 111. 107. 108. ... -6. -4. -8.]]\n",
"torch.Size([1, 1, 83792])\n",
"torch.Size([400, 1, 512])\n",
"torch.Size([1, 400, 521])\n",
"conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n",
" [ 210., 215., 216., ..., -86., -89., -91.],\n",
" [ 128., 128., 128., ..., -154., -151., -151.],\n",
" ...,\n",
" [-143., -141., -142., ..., 96., 101., 101.],\n",
" [ -60., -61., -61., ..., 112., 109., 110.],\n",
" [ 20., 22., 24., ..., 91., 87., 87.]])\n",
"xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"torch.Size([521, 257])\n",
"yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"yy (522, 257)\n",
"[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(f.shape)\n",
"print(f)\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,n_fft))\n",
"wcos = np.empty((freq_bins,1,n_fft))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
"\n",
"\n",
"wsin = np.empty((n_fft,1,n_fft))\n",
"wcos = np.empty((n_fft,1,n_fft))\n",
"for k in range(n_fft): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" \n",
" \n",
"wsin = np.empty((400,1,n_fft))\n",
"wcos = np.empty((400,1,n_fft))\n",
"for k in range(400): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(400, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(400, n_fft)[k]\n",
" \n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :]\n",
"print(x.size())\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160)\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"print(spec_imag.size())\n",
"print(\"conv frame\", spec_imag[0].T)\n",
"# print(spec_imag[0].T[:, :400])\n",
"\n",
"# remove redundant parts\n",
"# spec_real = spec_real[:, :freq_bins, :]\n",
"# spec_imag = spec_imag[:, :freq_bins, :]\n",
"# spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"# spec = torch.sqrt(spec)\n",
"# print(spec)\n",
"\n",
"\n",
"\n",
"s = np.arange(0, 512, 1.)\n",
"# s = s[::-1]\n",
"wsin = np.empty((freq_bins, 400))\n",
"wcos = np.empty((freq_bins, 400))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
"spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n",
"spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n",
"\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"\n",
"print('xx', spec.numpy())\n",
"print(spec.size())\n",
"print('yy', rspec[:521, :])\n",
"print('yy', rspec.shape)\n",
"\n",
"\n",
"x = spec.numpy()\n",
"y = rspec[:-1, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
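{
"cell_type": "markdown",
"id": "editor-note-identity-kernels",
"metadata": {},
"source": [
"The cell above first checks the framing itself: replacing the Fourier kernels with one-hot (identity) rows turns `conv1d` with `stride=160` into a pure frame extractor, and the extracted frames match `framesig`'s raw frames. It then applies the DFT as a matrix product with cos/sin bases truncated to 400 samples. The conv path yields 521 frames versus `framesig`'s 522, since its kernels span 512 samples even though only the first 400 are non-zero, and `allclose` still reports `False`; the next cells trace the largest relative error to a bin where the reference is exactly zero, which also explains the divide-by-zero warning below."
]
},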
{
"cell_type": "code",
"execution_count": 160,
"id": "mathematical-traffic",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([257, 1, 400])\n",
"tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n",
" 3.8693e+04, 3.1020e+04],\n",
" [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n",
" 1.3598e+04, 1.5920e+04],\n",
" [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n",
" 6.7707e+03, 4.3020e+03],\n",
" ...,\n",
" [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n",
" 6.1071e+01, 5.3685e+01],\n",
" [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n",
" 5.1310e+01, 6.3620e+01],\n",
" [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n",
" 3.5000e+01, 4.4000e+01]]])\n",
"[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n",
" 2.4064537e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n",
" 9.5781006e+01 1.4200000e+02]\n",
" ...\n",
" [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n",
" 7.8405350e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n",
" 5.1310150e+01 3.5000000e+01]\n",
" [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n",
" 6.3619797e+01 4.4000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,400))\n",
"wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :] #[B, C, T]\n",
"\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"print(spec)\n",
"\n",
"x = spec[0].T.numpy()\n",
"y = rspec[:, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "olive-nicaragua",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
},
{
"data": {
"text/plain": [
"27241"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(np.abs(x -y) / np.abs(y))"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "ultimate-assault",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "institutional-stock",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.2412265e-10"
]
},
"execution_count": 166,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "integrated-courage",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y, x)"
]
},
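{
"cell_type": "markdown",
"id": "editor-note-allclose",
"metadata": {},
"source": [
"`np.allclose` tests `|x - y| <= atol + rtol*|y|` elementwise with defaults `rtol=1e-5`, `atol=1e-8`. The near-zero bin found above actually passes that test (`4.2e-10 < 1e-8`); the overall `False` most likely comes from ordinary float32 rounding in the conv path versus the float64 `rfft` reference. A hedged check (added here, not in the original run): guarding the zero bins and loosening the tolerances should accept the match."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-check-allclose",
"metadata": {},
"outputs": [],
"source": [
"# Guard against zero reference bins when computing a relative error, then\n",
"# compare with float32-appropriate tolerances.\n",
"_rel = np.abs(x - y) / np.maximum(np.abs(y), 1e-8)\n",
"print(_rel.max())\n",
"print(np.allclose(x, y, rtol=1e-3, atol=1e-3))  # likely True at float32 precision"
]
},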
{
"cell_type": "code",
"execution_count": null,
"id": "different-operation",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}