From 4ad885f8ca26c8c4fdca7ff6abfc59116814eb7f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 9 Jun 2021 11:31:08 +0000 Subject: [PATCH 01/15] add feature notebook --- .notebook/audio_feature.ipynb | 224 ++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 .notebook/audio_feature.ipynb diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb new file mode 100644 index 000000000..5febb0aef --- /dev/null +++ b/.notebook/audio_feature.ipynb @@ -0,0 +1,224 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 13, + "id": "matched-camera", + "metadata": {}, + "outputs": [], + "source": [ + "from nnAudio import Spectrogram\n", + "from scipy.io import wavfile\n", + "import torch\n", + "import soundfile as sf\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "middle-salem", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "[43 75 69 ... 7 6 3]\n", + "(83792,)\n", + "int16\n", + "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", + "STFT kernels created, time used = 0.2142 seconds\n", + "tensor([[[[-4.0940e+03, 1.2600e+04],\n", + " [ 8.5108e+03, -5.4930e+03],\n", + " [-3.3631e+03, -1.7904e+03],\n", + " ...,\n", + " [ 8.2279e+03, -9.3340e+03],\n", + " [-3.1990e+03, 2.0969e+03],\n", + " [-1.2669e+03, 4.4488e+03]],\n", + "\n", + " [[ 3.4886e+03, -9.9620e+03],\n", + " [-4.5364e+03, 4.1907e+02],\n", + " [ 2.5074e+03, 7.1339e+03],\n", + " ...,\n", + " [-5.4819e+03, 3.9258e+01],\n", + " [ 4.7221e+03, 6.5887e+01],\n", + " [ 9.6492e+02, -3.4386e+03]],\n", + "\n", + " [[-3.4947e+03, 9.2981e+03],\n", + " [-7.5164e+03, 8.1856e+02],\n", + " [-5.3766e+03, -9.0889e+03],\n", + " ...,\n", + " [ 1.4317e+03, 5.7447e+03],\n", + " [-3.1178e+03, 3.0740e+03],\n", + " [-3.4351e+03, 5.6900e+02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 6.7112e+01, -4.5737e+00],\n", + " [-9.6295e+00, 3.5554e+01],\n", + " [ 1.8527e+00, -1.0491e+01],\n", + " ...,\n", + " [-1.1157e+01, 3.4423e+00],\n", + " [ 3.1193e+00, -4.4388e+00],\n", + " [-8.8242e+00, 8.0324e+00]],\n", + "\n", + " [[-6.5080e+01, 2.9543e+00],\n", + " [ 3.9992e+01, -1.3836e+01],\n", + " [-9.2803e+00, 1.0318e+01],\n", + " ...,\n", + " [ 4.2928e+00, 9.2397e+00],\n", + " [ 3.6642e+00, 9.4680e+00],\n", + " [ 4.8932e+00, -2.5199e+01]],\n", + "\n", + " [[ 4.7264e+01, -1.0721e+00],\n", + " [-6.0516e+00, -1.4589e+01],\n", + " [ 1.3127e+01, 1.4995e+00],\n", + " ...,\n", + " [ 1.7333e+01, -1.4380e+01],\n", + " [-3.6046e+00, -6.1019e+00],\n", + " [ 1.3321e+01, 2.3184e+01]]]])\n" + ] + } + ], + "source": [ + "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", + "print(sr)\n", + "print(song)\n", + "print(song.shape)\n", + "print(song.dtype)\n", + "x = song\n", + "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", + "\n", + "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", + " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", + " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", + "\n", + "spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", + "print(spec)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "finished-sterling", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "[43 75 69 ... 
7 6 3]\n", + "(83792,)\n", + "int16\n", + "True\n", + "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", + "STFT kernels created, time used = 0.2495 seconds\n", + "torch.Size([1, 1025, 164, 2])\n", + "tensor([[[[-4.0940e+03, 1.2600e+04],\n", + " [ 8.5108e+03, -5.4930e+03],\n", + " [-3.3631e+03, -1.7904e+03],\n", + " ...,\n", + " [ 8.2279e+03, -9.3340e+03],\n", + " [-3.1990e+03, 2.0969e+03],\n", + " [-1.2669e+03, 4.4488e+03]],\n", + "\n", + " [[ 3.4886e+03, -9.9620e+03],\n", + " [-4.5364e+03, 4.1907e+02],\n", + " [ 2.5074e+03, 7.1339e+03],\n", + " ...,\n", + " [-5.4819e+03, 3.9258e+01],\n", + " [ 4.7221e+03, 6.5887e+01],\n", + " [ 9.6492e+02, -3.4386e+03]],\n", + "\n", + " [[-3.4947e+03, 9.2981e+03],\n", + " [-7.5164e+03, 8.1856e+02],\n", + " [-5.3766e+03, -9.0889e+03],\n", + " ...,\n", + " [ 1.4317e+03, 5.7447e+03],\n", + " [-3.1178e+03, 3.0740e+03],\n", + " [-3.4351e+03, 5.6900e+02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 6.7112e+01, -4.5737e+00],\n", + " [-9.6295e+00, 3.5554e+01],\n", + " [ 1.8527e+00, -1.0491e+01],\n", + " ...,\n", + " [-1.1157e+01, 3.4423e+00],\n", + " [ 3.1193e+00, -4.4388e+00],\n", + " [-8.8242e+00, 8.0324e+00]],\n", + "\n", + " [[-6.5080e+01, 2.9543e+00],\n", + " [ 3.9992e+01, -1.3836e+01],\n", + " [-9.2803e+00, 1.0318e+01],\n", + " ...,\n", + " [ 4.2928e+00, 9.2397e+00],\n", + " [ 3.6642e+00, 9.4680e+00],\n", + " [ 4.8932e+00, -2.5199e+01]],\n", + "\n", + " [[ 4.7264e+01, -1.0721e+00],\n", + " [-6.0516e+00, -1.4589e+01],\n", + " [ 1.3127e+01, 1.4995e+00],\n", + " ...,\n", + " [ 1.7333e+01, -1.4380e+01],\n", + " [-3.6046e+00, -6.1019e+00],\n", + " [ 1.3321e+01, 2.3184e+01]]]])\n", + "True\n" + ] + } + ], + "source": [ + "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", + "print(sr)\n", + "print(wav)\n", + "print(wav.shape)\n", + "print(wav.dtype)\n", + "print(np.allclose(wav, song))\n", + "\n", + "x = wav\n", + "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", + "\n", + "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", + " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", + " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", + "\n", + "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", + "print(wav_spec.shape)\n", + "print(wav_spec)\n", + "print(np.allclose(wav_spec, spec))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "running-technology", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From b08384cd35d89525799dc580760c2b358c982027 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 11 Jun 2021 10:14:03 +0000 Subject: [PATCH 02/15] using conv1d to do fft --- .notebook/audio_feature.ipynb | 995 ++++++++++++++++++++- third_party/nnAudio/.gitignore | 3 + third_party/nnAudio/nnAudio/Spectrogram.py | 11 +- third_party/nnAudio/setup.py | 13 +- 4 files changed, 1004 insertions(+), 18 deletions(-) create mode 100644 third_party/nnAudio/.gitignore diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb index 5febb0aef..04b4a3924 
100644 --- a/.notebook/audio_feature.ipynb +++ b/.notebook/audio_feature.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 13, + "execution_count": 94, "id": "matched-camera", "metadata": {}, "outputs": [], @@ -16,7 +16,34 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 95, + "id": "quarterly-solution", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[43 75 69 ... 7 6 3]\n", + "[43 75 69 ... 7 6 3]\n", + "[43 75 69 ... 7 6 3]\n" + ] + } + ], + "source": [ + "import scipy.io.wavfile as wav\n", + "\n", + "rate,sig = wav.read('./BAC009S0764W0124.wav')\n", + "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", + "sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", + "print(sig)\n", + "print(song)\n", + "print(sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 96, "id": "middle-salem", "metadata": {}, "outputs": [ @@ -29,7 +56,7 @@ "(83792,)\n", "int16\n", "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2142 seconds\n", + "STFT kernels created, time used = 0.2733 seconds\n", "tensor([[[[-4.0940e+03, 1.2600e+04],\n", " [ 8.5108e+03, -5.4930e+03],\n", " [-3.3631e+03, -1.7904e+03],\n", @@ -101,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 97, "id": "finished-sterling", "metadata": {}, "outputs": [ @@ -115,7 +142,7 @@ "int16\n", "True\n", "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2495 seconds\n", + "STFT kernels created, time used = 0.2001 seconds\n", "torch.Size([1, 1025, 164, 2])\n", "tensor([[[[-4.0940e+03, 1.2600e+04],\n", " [ 8.5108e+03, -5.4930e+03],\n", @@ -193,10 +220,966 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 98, "id": "running-technology", "metadata": {}, "outputs": [], + "source": [ + "import decimal\n", + "\n", + "import numpy\n", + "import math\n", + "import logging\n", + "def round_half_up(number):\n", + " return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n", + "\n", + "\n", + "def rolling_window(a, window, step=1):\n", + " # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n", + " shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n", + " strides = a.strides + (a.strides[-1],)\n", + " return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n", + "\n", + "\n", + "def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n", + " \"\"\"Frame a signal into overlapping frames.\n", + "\n", + " :param sig: the audio signal to frame.\n", + " :param frame_len: length of each frame measured in samples.\n", + " :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n", + " :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n", + " :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n", + " :returns: an array of frames. 
Size is NUMFRAMES by frame_len.\n", + " \"\"\"\n", + " slen = len(sig)\n", + " frame_len = int(round_half_up(frame_len))\n", + " frame_step = int(round_half_up(frame_step))\n", + " if slen <= frame_len:\n", + " numframes = 1\n", + " else:\n", + " numframes = 1 + (( slen - frame_len) // frame_step)\n", + "\n", + " # check kaldi/src/feat/feature-window.h\n", + " padsignal = sig[:(numframes-1)*frame_step+frame_len]\n", + " if wintype is 'povey':\n", + " win = numpy.empty(frame_len)\n", + " for i in range(frame_len):\n", + " win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n", + " else: # the hamming window\n", + " win = numpy.hamming(frame_len)\n", + " \n", + " if stride_trick:\n", + " frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n", + " else:\n", + " indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n", + " numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n", + " indices = numpy.array(indices, dtype=numpy.int32)\n", + " frames = padsignal[indices]\n", + " win = numpy.tile(win, (numframes, 1))\n", + " \n", + " frames = frames.astype(numpy.float32)\n", + " raw_frames = numpy.zeros(frames.shape)\n", + " for frm in range(frames.shape[0]):\n", + " raw_frames[frm,:] = frames[frm,:]\n", + " frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n", + " frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n", + " # raw_frames[frm,:] = frames[frm,:]\n", + " frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n", + "\n", + " return frames * win, raw_frames\n", + "\n", + "\n", + "def magspec(frames, NFFT):\n", + " \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", + "\n", + " :param frames: the array of frames. Each row is a frame.\n", + " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", + " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n", + " \"\"\"\n", + " if numpy.shape(frames)[1] > NFFT:\n", + " logging.warn(\n", + " 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n", + " numpy.shape(frames)[1], NFFT)\n", + " complex_spec = numpy.fft.rfft(frames, NFFT)\n", + " return numpy.absolute(complex_spec)\n", + "\n", + "\n", + "def powspec(frames, NFFT):\n", + " \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", + "\n", + " :param frames: the array of frames. Each row is a frame.\n", + " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", + " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n", + " \"\"\"\n", + " return numpy.square(magspec(frames, NFFT))\n", + "\n", + "\n", + "def do_dither(signal, dither_value=1.0):\n", + " signal += numpy.random.normal(size=signal.shape) * dither_value\n", + " return signal\n", + " \n", + "def do_remove_dc_offset(signal):\n", + " signal -= numpy.mean(signal)\n", + " return signal\n", + "\n", + "def do_preemphasis(signal, coeff=0.97):\n", + " \"\"\"perform preemphasis on the input signal.\n", + "\n", + " :param signal: The signal to filter.\n", + " :param coeff: The preemphasis coefficient. 
0 is no filter, default is 0.95.\n", + " :returns: the filtered signal.\n", + " \"\"\"\n", + " return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "ignored-retreat", + "metadata": {}, + "outputs": [], + "source": [ + "def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", + " wintype='hamming'):\n", + " highfreq= highfreq or samplerate/2\n", + " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", + " spec = magspec(frames, nfft) # nearly the same until this part\n", + " rspec = magspec(raw_frames, nfft)\n", + " return spec, rspec\n", + "\n", + "\n", + "\n", + "def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", + " wintype='hamming'):\n", + " highfreq= highfreq or samplerate/2\n", + " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", + " return raw_frames" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "federal-teacher", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.nn.functional import conv1d, conv2d, fold\n", + "import scipy # used only in CFP\n", + "\n", + "import numpy as np\n", + "from time import time\n", + "\n", + "def pad_center(data, size, axis=-1, **kwargs):\n", + "\n", + " kwargs.setdefault('mode', 'constant')\n", + "\n", + " n = data.shape[axis]\n", + "\n", + " lpad = int((size - n) // 2)\n", + "\n", + " lengths = [(0, 0)] * data.ndim\n", + " lengths[axis] = (lpad, int(size - n - lpad))\n", + "\n", + " if lpad < 0:\n", + " raise ParameterError(('Target size ({:d}) must be '\n", + " 'at least input size ({:d})').format(size, n))\n", + "\n", + " return np.pad(data, lengths, **kwargs)\n", + "\n", + "\n", + "\n", + "sz_float = 4 # size of a float\n", + "epsilon = 10e-8 # fudge factor for normalization\n", + "\n", + "def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n", + " freq_scale='linear', window='hann', verbose=True):\n", + "\n", + " if freq_bins==None: freq_bins = n_fft//2+1\n", + " if win_length==None: win_length = n_fft\n", + "\n", + " s = np.arange(0, n_fft, 1.)\n", + " wsin = np.empty((freq_bins,1,n_fft))\n", + " wcos = np.empty((freq_bins,1,n_fft))\n", + " start_freq = fmin\n", + " end_freq = fmax\n", + " bins2freq = []\n", + " binslist = []\n", + "\n", + " # num_cycles = start_freq*d/44000.\n", + " # scaling_ind = np.log(end_freq/start_freq)/k\n", + "\n", + " # Choosing window shape\n", + "\n", + " #window_mask = get_window(window, int(win_length), fftbins=True)\n", + " window_mask = np.hamming(int(win_length))\n", + " window_mask = pad_center(window_mask, n_fft)\n", + "\n", + " if freq_scale == 'linear':\n", + " if verbose==True:\n", + " print(f\"sampling rate = {sr}. 
Please make sure the sampling rate is correct in order to\"\n", + " f\"get a valid freq range\")\n", + " \n", + " start_bin = start_freq*n_fft/sr\n", + " scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n", + "\n", + " for k in range(freq_bins): # Only half of the bins contain useful info\n", + " # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n", + " bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n", + " binslist.append((k*scaling_ind+start_bin))\n", + " wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", + "\n", + " elif freq_scale == 'log':\n", + " if verbose==True:\n", + " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", + " f\"get a valid freq range\")\n", + " start_bin = start_freq*n_fft/sr\n", + " scaling_ind = np.log(end_freq/start_freq)/freq_bins\n", + "\n", + " for k in range(freq_bins): # Only half of the bins contain useful info\n", + " # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n", + " bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n", + " binslist.append((np.exp(k*scaling_ind)*start_bin))\n", + " wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", + "\n", + " elif freq_scale == 'no':\n", + " for k in range(freq_bins): # Only half of the bins contain useful info\n", + " bins2freq.append(k*sr/n_fft)\n", + " binslist.append(k)\n", + " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", + " else:\n", + " print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n", + " return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n", + "\n", + "\n", + "\n", + "def broadcast_dim(x):\n", + " \"\"\"\n", + " Auto broadcast input so that it can fits into a Conv1d\n", + " \"\"\"\n", + "\n", + " if x.dim() == 2:\n", + " x = x[:, None, :]\n", + " elif x.dim() == 1:\n", + " # If nn.DataParallel is used, this broadcast doesn't work\n", + " x = x[None, None, :]\n", + " elif x.dim() == 3:\n", + " pass\n", + " else:\n", + " raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n", + " return x\n", + "\n", + "\n", + "\n", + "### --------------------------- Spectrogram Classes ---------------------------###\n", + "class STFT(torch.nn.Module):\n", + "\n", + " def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n", + " freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n", + " fmin=50, fmax=6000, sr=22050, trainable=False,\n", + " output_format=\"Complex\", verbose=True):\n", + "\n", + " super().__init__()\n", + "\n", + " # Trying to make the default setting same as librosa\n", + " if win_length==None: win_length = n_fft\n", + " if hop_length==None: hop_length = int(win_length // 4)\n", + "\n", + " self.output_format = output_format\n", + " self.trainable = trainable\n", + " self.stride = hop_length\n", + " self.center = center\n", + " self.pad_mode = pad_mode\n", + " self.n_fft = n_fft\n", + " self.freq_bins = freq_bins\n", + " self.trainable = trainable\n", + " self.pad_amount = self.n_fft // 2\n", + " self.window = window\n", + " self.win_length = win_length\n", + " self.iSTFT = iSTFT\n", + " self.trainable = trainable\n", + " start = time()\n", + "\n", + "\n", + "\n", + " # Create filter 
windows for stft\n", + " kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n", + " win_length=win_length,\n", + " freq_bins=freq_bins,\n", + " window=window,\n", + " freq_scale=freq_scale,\n", + " fmin=fmin,\n", + " fmax=fmax,\n", + " sr=sr,\n", + " verbose=verbose)\n", + "\n", + "\n", + " kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n", + " kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n", + " \n", + " # In this way, the inverse kernel and the forward kernel do not share the same memory...\n", + " kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n", + " kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n", + " \n", + " if iSTFT:\n", + " self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n", + " self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n", + "\n", + " # Applying window functions to the Fourier kernels\n", + " if window:\n", + " window_mask = torch.tensor(window_mask)\n", + " wsin = kernel_sin * window_mask\n", + " wcos = kernel_cos * window_mask\n", + " else:\n", + " wsin = kernel_sin\n", + " wcos = kernel_cos\n", + " \n", + " if self.trainable==False:\n", + " self.register_buffer('wsin', wsin)\n", + " self.register_buffer('wcos', wcos) \n", + " \n", + " if self.trainable==True:\n", + " wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n", + " wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n", + " self.register_parameter('wsin', wsin)\n", + " self.register_parameter('wcos', wcos) \n", + " \n", + " # Prepare the shape of window mask so that it can be used later in inverse\n", + " # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n", + " \n", + " if verbose==True:\n", + " print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n", + " else:\n", + " pass\n", + "\n", + " def forward(self, x, output_format=None):\n", + " \"\"\"\n", + " Convert a batch of waveforms to spectrograms.\n", + " \n", + " Parameters\n", + " ----------\n", + " x : torch tensor\n", + " Input signal should be in either of the following shapes.\\n\n", + " 1. ``(len_audio)``\\n\n", + " 2. ``(num_audio, len_audio)``\\n\n", + " 3. ``(num_audio, 1, len_audio)``\n", + " It will be automatically broadcast to the right shape\n", + " \n", + " output_format : str\n", + " Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n", + " Default value is ``Complex``. 
\n", + " \n", + " \"\"\"\n", + " output_format = output_format or self.output_format\n", + " self.num_samples = x.shape[-1]\n", + " \n", + " x = broadcast_dim(x)\n", + " if self.center:\n", + " if self.pad_mode == 'constant':\n", + " padding = nn.ConstantPad1d(self.pad_amount, 0)\n", + "\n", + " elif self.pad_mode == 'reflect':\n", + " if self.num_samples < self.pad_amount:\n", + " raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n", + " padding = nn.ReflectionPad1d(self.pad_amount)\n", + "\n", + " x = padding(x)\n", + " spec_imag = conv1d(x, self.wsin, stride=self.stride)\n", + " spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n", + "\n", + " # remove redundant parts\n", + " spec_real = spec_real[:, :self.freq_bins, :]\n", + " spec_imag = spec_imag[:, :self.freq_bins, :]\n", + "\n", + " if output_format=='Magnitude':\n", + " spec = spec_real.pow(2) + spec_imag.pow(2)\n", + " if self.trainable==True:\n", + " return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n", + " else:\n", + " return torch.sqrt(spec)\n", + "\n", + " elif output_format=='Complex':\n", + " return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n", + "\n", + " elif output_format=='Phase':\n", + " return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n", + "\n", + " def inverse(self, X, onesided=True, length=None, refresh_win=True):\n", + " \"\"\"\n", + " This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n", + " which is to convert spectrograms back to waveforms. \n", + " It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n", + " please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n", + " \n", + " Parameters\n", + " ----------\n", + " onesided : bool\n", + " If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n", + " else use ``onesided=False``\n", + "\n", + " length : int\n", + " To make sure the inverse STFT has the same output length of the original waveform, please\n", + " set `length` as your intended waveform length. By default, ``length=None``,\n", + " which will remove ``n_fft//2`` samples from the start and the end of the output.\n", + " \n", + " refresh_win : bool\n", + " Recalculating the window sum square. If you have an input with fixed number of timesteps,\n", + " you can increase the speed by setting ``refresh_win=False``. 
Else please keep ``refresh_win=True``\n", + " \n", + " \n", + " \"\"\"\n", + " if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n", + " raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n", + " \n", + " assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n", + " \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n", + " \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n", + " if onesided:\n", + " X = extend_fbins(X) # extend freq\n", + "\n", + " \n", + " X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n", + "\n", + " # broadcast dimensions to support 2D convolution\n", + " X_real_bc = X_real.unsqueeze(1)\n", + " X_imag_bc = X_imag.unsqueeze(1)\n", + " a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n", + " b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n", + " \n", + " # compute real and imag part. signal lies in the real part\n", + " real = a1 - b2\n", + " real = real.squeeze(-2)*self.window_mask\n", + "\n", + " # Normalize the amplitude with n_fft\n", + " real /= (self.n_fft)\n", + "\n", + " # Overlap and Add algorithm to connect all the frames\n", + " real = overlap_add(real, self.stride)\n", + " \n", + " # Prepare the window sumsqure for division\n", + " # Only need to create this window once to save time\n", + " # Unless the input spectrograms have different time steps\n", + " if hasattr(self, 'w_sum')==False or refresh_win==True:\n", + " self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n", + " self.nonzero_indices = (self.w_sum>1e-10) \n", + " else:\n", + " pass\n", + " real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n", + " # Remove padding\n", + " if length is None: \n", + " if self.center:\n", + " real = real[:, self.pad_amount:-self.pad_amount]\n", + "\n", + " else:\n", + " if self.center:\n", + " real = real[:, self.pad_amount:self.pad_amount + length] \n", + " else:\n", + " real = real[:, :length] \n", + " \n", + " return real\n", + " \n", + " def extra_repr(self) -> str:\n", + " return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n", + " self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n", + " ) " + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": "unusual-baker", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "(83792,)\n", + "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", + "STFT kernels created, time used = 0.0153 seconds\n", + "torch.Size([521, 257])\n", + "(522, 257)\n", + "[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n", + " 2.32206573e+01 1.90274487e+01]\n", + " [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n", + " 1.80534729e+02 1.84179596e+02]\n", + " [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n", + " 1.21321175e+02 1.08345497e+02]\n", + " ...\n", + " [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n", + " 1.25670158e+02 1.20691467e+02]\n", + " [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n", + " 9.01831894e+01 9.84055099e+01]\n", + " [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n", + " 8.94473553e+00 9.61348629e+00]]\n", + "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 
2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]\n", + " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", + " 6.36197377e+01 4.40000000e+01]]\n" + ] + } + ], + "source": [ + "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", + "print(sr)\n", + "print(wav.shape)\n", + "\n", + "x = wav\n", + "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", + "\n", + "spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n", + " window='', freq_scale='linear', center=False, pad_mode='constant',\n", + " fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n", + "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", + "wav_spec = wav_spec[0].T\n", + "print(wav_spec.shape)\n", + "\n", + "\n", + "spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + " wintype='hamming')\n", + "print(spec.shape)\n", + "\n", + "print(wav_spec.numpy())\n", + "print(rspec)\n", + "# print(spec)\n", + "\n", + "# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n", + "# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + "# dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + "# wintype='hamming')\n", + "# print(rspec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "white-istanbul", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "modern-rescue", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n", + " 0.75 0.41317591 0.11697778 0. ]\n" + ] + }, + { + "data": { + "text/plain": [ + "array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n", + " 0.9045085, 0.6545085, 0.3454915, 0.0954915])" + ] + }, + "execution_count": 129, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(np.hanning(10))\n", + "from scipy.signal import get_window\n", + "get_window('hann', 10, fftbins=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "professional-journalism", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 153, + "id": "involved-motion", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(522, 400)\n", + "[[ 43. 75. 69. ... 46. 46. 45.]\n", + " [ 210. 215. 216. ... -86. -89. -91.]\n", + " [ 128. 128. 128. ... -154. -151. -151.]\n", + " ...\n", + " [ -60. -61. -61. ... 112. 109. 110.]\n", + " [ 20. 22. 24. ... 91. 87. 87.]\n", + " [ 111. 107. 108. ... -6. -4. 
-8.]]\n", + "torch.Size([1, 1, 83792])\n", + "torch.Size([400, 1, 512])\n", + "torch.Size([1, 400, 521])\n", + "conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n", + " [ 210., 215., 216., ..., -86., -89., -91.],\n", + " [ 128., 128., 128., ..., -154., -151., -151.],\n", + " ...,\n", + " [-143., -141., -142., ..., 96., 101., 101.],\n", + " [ -60., -61., -61., ..., 112., 109., 110.],\n", + " [ 20., 22., 24., ..., 91., 87., 87.]])\n", + "xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", + " 2.4064583e+01 2.2000000e+01]\n", + " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", + " 1.1877571e+02 1.6200000e+02]\n", + " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", + " 9.5781029e+01 1.4200000e+02]\n", + " ...\n", + " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", + " 9.1511757e+01 1.1500000e+02]\n", + " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", + " 7.8405365e+01 9.0000000e+01]\n", + " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", + " 5.1310158e+01 3.5000000e+01]]\n", + "torch.Size([521, 257])\n", + "yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", + " 9.15117270e+01 1.15000000e+02]\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]]\n", + "yy (522, 257)\n", + "[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", + " 2.4064583e+01 2.2000000e+01]\n", + " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", + " 1.1877571e+02 1.6200000e+02]\n", + " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", + " 9.5781029e+01 1.4200000e+02]\n", + " ...\n", + " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", + " 9.1511757e+01 1.1500000e+02]\n", + " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", + " 7.8405365e+01 9.0000000e+01]\n", + " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", + " 5.1310158e+01 3.5000000e+01]]\n", + "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", + " 9.15117270e+01 1.15000000e+02]\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 
6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]]\n", + "False\n" + ] + } + ], + "source": [ + "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + " wintype='hamming')\n", + "print(f.shape)\n", + "print(f)\n", + "\n", + "n_fft=512\n", + "freq_bins = n_fft//2+1\n", + "s = np.arange(0, n_fft, 1.)\n", + "wsin = np.empty((freq_bins,1,n_fft))\n", + "wcos = np.empty((freq_bins,1,n_fft))\n", + "for k in range(freq_bins): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", + " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", + "\n", + "\n", + "wsin = np.empty((n_fft,1,n_fft))\n", + "wcos = np.empty((n_fft,1,n_fft))\n", + "for k in range(n_fft): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n", + " wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n", + " \n", + " \n", + "wsin = np.empty((400,1,n_fft))\n", + "wcos = np.empty((400,1,n_fft))\n", + "for k in range(400): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.eye(400, n_fft)[k]\n", + " wcos[k,0,:] = np.eye(400, n_fft)[k]\n", + " \n", + "\n", + " \n", + "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", + "x = x[None, None, :]\n", + "print(x.size())\n", + "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", + "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", + "print(kernel_sin.size())\n", + "\n", + "from torch.nn.functional import conv1d, conv2d, fold\n", + "spec_imag = conv1d(x, kernel_sin, stride=160)\n", + "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", + "\n", + "print(spec_imag.size())\n", + "print(\"conv frame\", spec_imag[0].T)\n", + "# print(spec_imag[0].T[:, :400])\n", + "\n", + "# remove redundant parts\n", + "# spec_real = spec_real[:, :freq_bins, :]\n", + "# spec_imag = spec_imag[:, :freq_bins, :]\n", + "# spec = spec_real.pow(2) + spec_imag.pow(2)\n", + "# spec = torch.sqrt(spec)\n", + "# print(spec)\n", + "\n", + "\n", + "\n", + "s = np.arange(0, 512, 1.)\n", + "# s = s[::-1]\n", + "wsin = np.empty((freq_bins, 400))\n", + "wcos = np.empty((freq_bins, 400))\n", + "for k in range(freq_bins): # Only half of the bins contain useful info\n", + " wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", + " wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", + "\n", + "spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n", + "spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n", + "\n", + "\n", + "# remove redundant parts\n", + "spec = spec_real.pow(2) + spec_imag.pow(2)\n", + "spec = torch.sqrt(spec)\n", + "\n", + "print('xx', spec.numpy())\n", + "print(spec.size())\n", + "print('yy', rspec[:521, :])\n", + "print('yy', rspec.shape)\n", + "\n", + "\n", + "x = spec.numpy()\n", + "y = rspec[:-1, :]\n", + "print(x)\n", + "print(y)\n", + "print(np.allclose(x, y))" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "id": "mathematical-traffic", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([257, 1, 400])\n", + "tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n", + " 3.8693e+04, 3.1020e+04],\n", + " [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n", + " 1.3598e+04, 1.5920e+04],\n", + " [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n", + " 6.7707e+03, 4.3020e+03],\n", + " ...,\n", + " [2.0282e+01, 6.6927e+01, 
2.8501e+01, ..., 2.6012e+01,\n", + " 6.1071e+01, 5.3685e+01],\n", + " [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n", + " 5.1310e+01, 6.3620e+01],\n", + " [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n", + " 3.5000e+01, 4.4000e+01]]])\n", + "[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n", + " 2.4064537e+01 2.2000000e+01]\n", + " [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n", + " 1.1877571e+02 1.6200000e+02]\n", + " [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n", + " 9.5781006e+01 1.4200000e+02]\n", + " ...\n", + " [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n", + " 7.8405350e+01 9.0000000e+01]\n", + " [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n", + " 5.1310150e+01 3.5000000e+01]\n", + " [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n", + " 6.3619797e+01 4.4000000e+01]]\n", + "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", + " 2.40645984e+01 2.20000000e+01]\n", + " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", + " 1.18775735e+02 1.62000000e+02]\n", + " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", + " 9.57810428e+01 1.42000000e+02]\n", + " ...\n", + " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", + " 7.84053656e+01 9.00000000e+01]\n", + " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", + " 5.13101944e+01 3.50000000e+01]\n", + " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", + " 6.36197377e+01 4.40000000e+01]]\n", + "False\n" + ] + } + ], + "source": [ + "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", + " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", + " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", + " wintype='hamming')\n", + "\n", + "n_fft=512\n", + "freq_bins = n_fft//2+1\n", + "s = np.arange(0, n_fft, 1.)\n", + "wsin = np.empty((freq_bins,1,400))\n", + "wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n", + "for k in range(freq_bins): # Only half of the bins contain useful info\n", + " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", + " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", + "\n", + " \n", + "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", + "x = x[None, None, :] #[B, C, T]\n", + "\n", + "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", + "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", + "print(kernel_sin.size())\n", + "\n", + "from torch.nn.functional import conv1d, conv2d, fold\n", + "spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n", + "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", + "\n", + "# remove redundant parts\n", + "spec = spec_real.pow(2) + spec_imag.pow(2)\n", + "spec = torch.sqrt(spec)\n", + "print(spec)\n", + "\n", + "x = spec[0].T.numpy()\n", + "y = rspec[:, :]\n", + "print(x)\n", + "print(y)\n", + "print(np.allclose(x, y))" + ] + }, + { + "cell_type": "code", + "execution_count": 162, + "id": "olive-nicaragua", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n", + " \"\"\"Entry point for launching an IPython kernel.\n" + ] + }, + { + "data": { + "text/plain": [ + "27241" + ] + }, + "execution_count": 162, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"np.argmax(np.abs(x -y) / np.abs(y))" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "ultimate-assault", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.0" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y[np.unravel_index(27241, y.shape)]" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "institutional-stock", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.2412265e-10" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x[np.unravel_index(27241, y.shape)]" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "id": "integrated-courage", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.allclose(y, x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "different-operation", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/third_party/nnAudio/.gitignore b/third_party/nnAudio/.gitignore new file mode 100644 index 000000000..c09b85733 --- /dev/null +++ b/third_party/nnAudio/.gitignore @@ -0,0 +1,3 @@ +build +dist +*.egg-info/ diff --git a/third_party/nnAudio/nnAudio/Spectrogram.py b/third_party/nnAudio/nnAudio/Spectrogram.py index c92046ee4..b5d798457 100755 --- a/third_party/nnAudio/nnAudio/Spectrogram.py +++ b/third_party/nnAudio/nnAudio/Spectrogram.py @@ -165,9 +165,13 @@ class STFT(torch.nn.Module): # self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable) # Applying window functions to the Fourier kernels - window_mask = torch.tensor(window_mask) - wsin = kernel_sin * window_mask - wcos = kernel_cos * window_mask + if window: + window_mask = torch.tensor(window_mask) + wsin = kernel_sin * window_mask + wcos = kernel_cos * window_mask + else: + wsin = kernel_sin + wcos = kernel_cos if self.trainable==False: self.register_buffer('wsin', wsin) @@ -179,7 +183,6 @@ class STFT(torch.nn.Module): self.register_parameter('wsin', wsin) self.register_parameter('wcos', wcos) - # Prepare the shape of window mask so that it can be used later in inverse self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1)) diff --git a/third_party/nnAudio/setup.py b/third_party/nnAudio/setup.py index 9b2f36884..cb69481af 100755 --- a/third_party/nnAudio/setup.py +++ b/third_party/nnAudio/setup.py @@ -2,29 +2,26 @@ import setuptools import codecs import os.path -with open("README.md", "r") as fh: - long_description = fh.read() - def read(rel_path): here = os.path.abspath(os.path.dirname(__file__)) with codecs.open(os.path.join(here, rel_path), 'r') as fp: - return fp.read() - + return fp.read() + def get_version(rel_path): for line in read(rel_path).splitlines(): if line.startswith('__version__'): delim = '"' if '"' in line else "'" return line.split(delim)[1] else: - raise RuntimeError("Unable to find version string.") - + raise RuntimeError("Unable to find version string.") + setuptools.setup( name="nnAudio", # Replace with your own username version=get_version("nnAudio/__init__.py"), author="KinWaiCheuk", author_email="u3500684@connect.hku.hk", description="A fast GPU audio processing toolbox with 1D convolutional neural network", - long_description=long_description, + long_description='', long_description_content_type="text/markdown", 
url="https://github.com/KinWaiCheuk/nnAudio", packages=setuptools.find_packages(), From abccbb5c7dfd4a5e183dd4d3ce35fa1761028ae4 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Fri, 11 Jun 2021 20:18:33 +0800 Subject: [PATCH 03/15] WIP: add kaldi-style frame and stft --- third_party/paddle_audio/frontend.py | 146 +++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 third_party/paddle_audio/frontend.py diff --git a/third_party/paddle_audio/frontend.py b/third_party/paddle_audio/frontend.py new file mode 100644 index 000000000..1b337732e --- /dev/null +++ b/third_party/paddle_audio/frontend.py @@ -0,0 +1,146 @@ +from typing import Tuple +import numpy as np +import paddle +from paddle import Tensor +from paddle import nn +from paddle.nn import functional as F + + +def frame(x: Tensor, + num_samples: Tensor, + win_length: int, + hop_length: int, + clip: bool = True) -> Tuple[Tensor, Tensor]: + """Extract frames from audio. + + Parameters + ---------- + x : Tensor + Shape (N, T), batched waveform. + num_samples : Tensor + Shape (N, ), number of samples of each waveform. + win_length : int + Window length. + hop_length : int + Number of samples shifted between ajancent frames. + clip : bool, optional + Whether to clip audio that does not fit into the last frame, by + default True + + Returns + ------- + frames : Tensor + Shape (N, T', win_length). + num_frames : Tensor + Shape (N, ) number of valid frames + """ + assert hop_length <= win_length + num_frames = (num_samples - win_length) // hop_length + padding = (0, 0) + if not clip: + num_frames += 1 + # NOTE: pad hop_length - 1 to the right to ensure that there is at most + # one frame dangling to the righe edge + padding = (0, hop_length - 1) + + weight = paddle.eye(win_length).unsqueeze(1) + + frames = F.conv1d(x.unsqueeze(1), + weight, + padding=padding, + stride=(hop_length, )) + return frames, num_frames + + +class STFT(nn.Layer): + """A module for computing stft transformation in a differentiable way. + + Parameters + ------------ + n_fft : int + Number of samples in a frame. + + hop_length : int + Number of samples shifted between adjacent frames. + + win_length : int + Length of the window. + + clip: bool + Whether to clip audio is necesaary. + """ + def __init__(self, + n_fft: int, + hop_length: int, + win_length: int, + window_type: str = None, + clip: bool = True): + super().__init__() + + self.hop_length = hop_length + self.n_bin = 1 + n_fft // 2 + self.n_fft = n_fft + self.clip = clip + + # calculate window + if window_type is None: + window = np.ones(win_length) + elif window_type == "hann": + window = np.hanning(win_length) + elif window_type == "hamming": + window = np.hamming(win_length) + else: + raise ValueError("Not supported yet!") + + if win_length < n_fft: + window = F.pad(window, (0, n_fft - win_length)) + elif win_length > n_fft: + window = window[:n_fft] + + # (n_bins, n_fft) complex + kernel_size = min(n_fft, win_length) + weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size] + w_real = weight.real + w_imag = weight.imag + + # (2 * n_bins, kernel_size) + w = np.concatenate([w_real, w_imag], axis=0) + w = w * window + + # (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size) + w = np.expand_dims(w, 1) + weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype()) + self.register_buffer("weight", weight) + + def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]: + """Compute the stft transform. 
+ Parameters + ------------ + x : Tensor [shape=(B, T)] + The input waveform. + num_samples : Tensor + Number of samples of each waveform. + Returns + ------------ + D : Tensor + Shape(N, T', n_bins, 2) Spectrogram. + + num_frames: Tensor + Shape (N,) number of samples of each spectrogram + """ + num_frames = (num_samples - self.win_length) // self.hop_length + padding = (0, 0) + if not self.clip: + num_frames += 1 + padding = (0, self.hop_length - 1) + + batch_size, _, _ = paddle.shape(x) + x = x.unsqueeze(-1) + D = F.conv1d(self.weight, + x, + stride=(self.hop_length, ), + padding=padding, + data_format="NLC") + D = paddle.reshape(D, [batch_size, -1, self.n_bin, 2]) + return D, num_frames + From d179fc92d94ec8b89a6a7f0175171dcb3aa732cd Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 16 Jun 2021 09:08:33 +0000 Subject: [PATCH 04/15] speech deployment architecture --- speechnn/CMakeLists.txt | 0 speechnn/core/CMakeLists.txt | 0 speechnn/core/decoder/CMakeLists.txt | 0 speechnn/core/frontend/CMakeLists.txt | 0 speechnn/core/frontend/audio/CMakeLists.txt | 0 speechnn/core/frontend/text/CMakeLists.txt | 0 speechnn/core/model/CMakeLists.txt | 0 speechnn/core/protocol/CMakeLists.txt | 0 speechnn/core/utils/CMakeLists.txt | 0 speechnn/third_party/CMakeLists.txt | 0 10 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 speechnn/CMakeLists.txt create mode 100644 speechnn/core/CMakeLists.txt create mode 100644 speechnn/core/decoder/CMakeLists.txt create mode 100644 speechnn/core/frontend/CMakeLists.txt create mode 100644 speechnn/core/frontend/audio/CMakeLists.txt create mode 100644 speechnn/core/frontend/text/CMakeLists.txt create mode 100644 speechnn/core/model/CMakeLists.txt create mode 100644 speechnn/core/protocol/CMakeLists.txt create mode 100644 speechnn/core/utils/CMakeLists.txt create mode 100644 speechnn/third_party/CMakeLists.txt diff --git a/speechnn/CMakeLists.txt b/speechnn/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/CMakeLists.txt b/speechnn/core/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/decoder/CMakeLists.txt b/speechnn/core/decoder/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/frontend/CMakeLists.txt b/speechnn/core/frontend/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/frontend/audio/CMakeLists.txt b/speechnn/core/frontend/audio/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/frontend/text/CMakeLists.txt b/speechnn/core/frontend/text/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/model/CMakeLists.txt b/speechnn/core/model/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/protocol/CMakeLists.txt b/speechnn/core/protocol/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/core/utils/CMakeLists.txt b/speechnn/core/utils/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/speechnn/third_party/CMakeLists.txt b/speechnn/third_party/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb From 9ddae26a362998c5e3404b2c4cd69962bb098948 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 16 Jun 2021 09:12:05 +0000 Subject: [PATCH 05/15] add delpoy mergify label --- .mergify.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.mergify.yml b/.mergify.yml index b11fd5c1f..03e57e14b 100644 --- 
a/.mergify.yml +++ b/.mergify.yml @@ -87,3 +87,9 @@ pull_request_rules: actions: label: add: ["Docker"] + - name: "auto add label=Deployment" + conditions: + - files~=^speechnn/ + actions: + label: + add: ["Deployment"] From 698d7a9bdb3de1a763ed8ba7a71b68241e3eea17 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Thu, 17 Jun 2021 07:16:52 +0000 Subject: [PATCH 06/15] move batch_size, work_nums, shuffle_method, sortagrad to collator --- deepspeech/exps/deepspeech2/config.py | 20 +++++------------ deepspeech/exps/deepspeech2/model.py | 18 +++++++-------- deepspeech/exps/u2/config.py | 6 ++++- .../frontend/featurizer/speech_featurizer.py | 10 --------- deepspeech/io/collator.py | 22 ------------------- examples/aishell/s0/conf/deepspeech2.yaml | 9 ++++---- examples/tiny/s0/conf/deepspeech2.yaml | 9 ++++---- 7 files changed, 29 insertions(+), 65 deletions(-) diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 1ce5346f6..faaff1aad 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -28,20 +28,6 @@ _C.data = CN( augmentation_config="", max_duration=float('inf'), min_duration=0.0, - stride_ms=10.0, # ms - window_ms=20.0, # ms - n_fft=None, # fft points - max_freq=None, # None for samplerate/2 - specgram_type='linear', # 'linear', 'mfcc', 'fbank' - feat_dim=0, # 'mfcc', 'fbank' - delat_delta=False, # 'mfcc', 'fbank' - target_sample_rate=16000, # target sample rate - use_dB_normalization=True, - target_dB=-20, - batch_size=32, # batch size - num_workers=0, # data loader workers - sortagrad=False, # sorted in first epoch when True - shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' )) _C.model = CN( @@ -72,7 +58,11 @@ _C.collator =CN( use_dB_normalization=True, target_dB=-20, dither=1.0, # feature dither - keep_transcription_text=False + keep_transcription_text=False, + batch_size=32, # batch size + num_workers=0, # data loader workers + sortagrad=False, # sorted in first epoch when True + shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' )) DeepSpeech2Model.params(_C.model) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 5833382a4..b54192dd3 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -55,7 +55,7 @@ class DeepSpeech2Trainer(Trainer): 'train_loss': float(loss), } msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.data.batch_size) + msg += "batch size: {}, ".format(self.config.collator.batch_size) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items()) logger.info(msg) @@ -149,31 +149,31 @@ class DeepSpeech2Trainer(Trainer): if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) else: batch_sampler = SortagradBatchSampler( train_dataset, shuffle=True, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) collate_fn = SpeechCollator.from_config(config) 
self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, - num_workers=config.data.num_workers) + num_workers=config.collator.num_workers) self.valid_loader = DataLoader( dev_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn) diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py index 19080be76..42725c74f 100644 --- a/deepspeech/exps/u2/config.py +++ b/deepspeech/exps/u2/config.py @@ -26,7 +26,11 @@ _C.collator =CfgNode( dict( augmentation_config="", unit_type="char", - keep_transcription_text=False + keep_transcription_text=False, + batch_size=32, # batch size + num_workers=0, # data loader workers + sortagrad=False, # sorted in first epoch when True + shuffle_method="batch_shuffle" # 'batch_shuffle', 'instance_shuffle' )) _C.model = U2Model.params() diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 852d26c9a..0fbbc5648 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -151,13 +151,3 @@ class SpeechFeaturizer(object): TextFeaturizer: object. """ return self._text_featurizer - - - # @property - # def text_feature(self): - # """Return the text feature object. - - # Returns: - # TextFeaturizer: object. - # """ - # return self._text_featurizer diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 8b8575dbd..ac817a192 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -203,34 +203,22 @@ class SpeechCollator(): where transcription part could be token ids or text. :rtype: tuple of (2darray, list) """ - start_time = time.time() if isinstance(audio_file, str) and audio_file.startswith('tar:'): speech_segment = SpeechSegment.from_file( self._subfile_from_tar(audio_file), transcript) else: speech_segment = SpeechSegment.from_file(audio_file, transcript) - load_wav_time = time.time() - start_time - #logger.debug(f"load wav time: {load_wav_time}") # audio augment - start_time = time.time() self._augmentation_pipeline.transform_audio(speech_segment) - audio_aug_time = time.time() - start_time - #logger.debug(f"audio augmentation time: {audio_aug_time}") - start_time = time.time() specgram, transcript_part = self._speech_featurizer.featurize( speech_segment, self._keep_transcription_text) if self._normalizer: specgram = self._normalizer.apply(specgram) - feature_time = time.time() - start_time - #logger.debug(f"audio & test feature time: {feature_time}") # specgram augment - start_time = time.time() specgram = self._augmentation_pipeline.transform_feature(specgram) - feature_aug_time = time.time() - start_time - #logger.debug(f"audio feature augmentation time: {feature_aug_time}") return specgram, transcript_part def __call__(self, batch): @@ -283,16 +271,6 @@ class SpeechCollator(): return utts, padded_audios, audio_lens, padded_texts, text_lens - # @property - # def text_feature(self): - # return self._speech_featurizer.text_feature - - - # @property - # def stride_ms(self): - # return self._speech_featurizer.stride_ms - -########### @property def manifest(self): diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index e5ab8e046..54ce240e7 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -5,16 +5,13 @@ data: test_manifest: data/manifest.test mean_std_filepath: 
data/mean_std.json vocab_filepath: data/vocab.txt - batch_size: 64 # one gpu min_input_len: 0.0 max_input_len: 27.0 # second min_output_len: 0.0 max_output_len: .inf min_output_input_ratio: 0.00 max_output_input_ratio: .inf - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 0 + collator: augmentation_config: conf/augmentation.json @@ -32,6 +29,10 @@ collator: target_dB: -20 dither: 1.0 keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + batch_size: 64 # one gpu model: num_conv_layers: 2 diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 6680e5686..434cf264c 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -6,16 +6,13 @@ data: mean_std_filepath: data/mean_std.json unit_type: char vocab_filepath: data/vocab.txt - batch_size: 4 min_input_len: 0.0 max_input_len: 27.0 min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 0 + collator: augmentation_config: conf/augmentation.json @@ -33,6 +30,10 @@ collator: target_dB: -20 dither: 1.0 keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + batch_size: 4 model: num_conv_layers: 2 From 557427736e9f2fba6715cc3ce18b3175a3c42cd8 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 18 Jun 2021 06:41:28 +0000 Subject: [PATCH 07/15] move redundant params --- deepspeech/exps/deepspeech2/config.py | 30 +++---- deepspeech/exps/deepspeech2/model.py | 14 ++-- deepspeech/exps/u2/config.py | 12 +-- deepspeech/exps/u2/model.py | 35 ++++---- deepspeech/io/collator.py | 36 ++++++-- deepspeech/io/dataset.py | 105 +----------------------- examples/aishell/s1/conf/conformer.yaml | 14 ++-- examples/tiny/s0/conf/deepspeech2.yaml | 10 +-- examples/tiny/s1/conf/transformer.yaml | 22 ++--- 9 files changed, 96 insertions(+), 182 deletions(-) diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index faaff1aad..050a50b00 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -21,32 +21,18 @@ _C.data = CN( train_manifest="", dev_manifest="", test_manifest="", - unit_type="char", - vocab_filepath="", - spm_model_prefix="", - mean_std_filepath="", - augmentation_config="", max_duration=float('inf'), min_duration=0.0, )) -_C.model = CN( - dict( - num_conv_layers=2, #Number of stacking convolution layers. - num_rnn_layers=3, #Number of stacking RNN layers. - rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) - _C.collator =CN( dict( - augmentation_config="", - random_seed=0, - mean_std_filepath="", unit_type="char", vocab_filepath="", spm_model_prefix="", + mean_std_filepath="", + augmentation_config="", + random_seed=0, specgram_type='linear', # 'linear', 'mfcc', 'fbank' feat_dim=0, # 'mfcc', 'fbank' delta_delta=False, # 'mfcc', 'fbank' @@ -65,6 +51,16 @@ _C.collator =CN( shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' )) +_C.model = CN( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. 
+ rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. + share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + )) + + DeepSpeech2Model.params(_C.model) _C.training = CN( diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index b54192dd3..1eefc871b 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -143,7 +143,6 @@ class DeepSpeech2Trainer(Trainer): train_dataset = ManifestDataset.from_config(config) config.data.manifest = config.data.dev_manifest - config.data.augmentation_config = "" dev_dataset = ManifestDataset.from_config(config) if self.parallel: @@ -165,18 +164,22 @@ class DeepSpeech2Trainer(Trainer): sortagrad=config.collator.sortagrad, shuffle_method=config.collator.shuffle_method) - collate_fn = SpeechCollator.from_config(config) + collate_fn_train = SpeechCollator.from_config(config) + + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, - collate_fn=collate_fn, + collate_fn=collate_fn_train, num_workers=config.collator.num_workers) self.valid_loader = DataLoader( dev_dataset, batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn) + collate_fn=collate_fn_dev) logger.info("Setup train/valid Dataloader!") @@ -324,8 +327,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): # return raw text config.data.manifest = config.data.test_manifest - config.data.keep_transcription_text = True - config.data.augmentation_config = "" # filter test examples, will cause less examples, but no mismatch with training # and can use large batch size , save training time, so filter test egs now. 
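# Illustration only, not part of the patch: the collator construction pattern used
# by the setup_dataloader methods above, gathered into one hypothetical helper.
# SpeechCollator.from_config and the config.collator switches are taken from this
# change; the helper function and its name are a sketch.
from deepspeech.io.collator import SpeechCollator

def build_collators(config):
    config = config.clone()
    config.defrost()
    # train: augmentation enabled, transcripts returned as token ids
    config.collator.keep_transcription_text = False
    train_collator = SpeechCollator.from_config(config)
    # dev: same featurization, augmentation cleared so validation is deterministic
    config.collator.augmentation_config = ""
    dev_collator = SpeechCollator.from_config(config)
    # test: additionally keep raw transcription text for WER/CER scoring
    config.collator.keep_transcription_text = True
    test_collator = SpeechCollator.from_config(config)
    return train_collator, dev_collator, test_collator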
# config.data.min_input_len = 0.0 # second @@ -337,6 +338,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): test_dataset = ManifestDataset.from_config(config) config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" # return text ord id self.test_loader = DataLoader( test_dataset, diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py index 42725c74f..d8735453c 100644 --- a/deepspeech/exps/u2/config.py +++ b/deepspeech/exps/u2/config.py @@ -17,21 +17,13 @@ from deepspeech.exps.u2.model import U2Tester from deepspeech.exps.u2.model import U2Trainer from deepspeech.io.dataset import ManifestDataset from deepspeech.models.u2 import U2Model +from deepspeech.io.collator import SpeechCollator _C = CfgNode() _C.data = ManifestDataset.params() -_C.collator =CfgNode( - dict( - augmentation_config="", - unit_type="char", - keep_transcription_text=False, - batch_size=32, # batch size - num_workers=0, # data loader workers - sortagrad=False, # sorted in first epoch when True - shuffle_method="batch_shuffle" # 'batch_shuffle', 'instance_shuffle' - )) +_C.collator = SpeechCollator.params() _C.model = U2Model.params() diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 164903e69..836afa361 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -100,7 +100,7 @@ class U2Trainer(Trainer): if (batch_index + 1) % train_conf.log_interval == 0: msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.data.batch_size) + msg += "batch size: {}, ".format(self.config.collator.batch_size) msg += "accum: {}, ".format(train_conf.accum_grad) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items()) @@ -211,51 +211,52 @@ class U2Trainer(Trainer): def setup_dataloader(self): config = self.config.clone() config.defrost() - config.data.keep_transcription_text = False + config.collator.keep_transcription_text = False # train/valid dataset, return token ids config.data.manifest = config.data.train_manifest train_dataset = ManifestDataset.from_config(config) config.data.manifest = config.data.dev_manifest - config.data.augmentation_config = "" dev_dataset = ManifestDataset.from_config(config) - collate_fn = SpeechCollator.from_config(config) + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) + if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) else: batch_sampler = SortagradBatchSampler( train_dataset, shuffle=True, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, drop_last=True, - sortagrad=config.data.sortagrad, - shuffle_method=config.data.shuffle_method) + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=config.data.num_workers, ) + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers, ) self.valid_loader = DataLoader( dev_dataset, - batch_size=config.data.batch_size, + batch_size=config.collator.batch_size, shuffle=False, 
drop_last=False, - collate_fn=collate_fn) + collate_fn=collate_fn_dev) # test dataset, return raw text config.data.manifest = config.data.test_manifest - config.data.keep_transcription_text = True - config.data.augmentation_config = "" # filter test examples, will cause less examples, but no mismatch with training # and can use large batch size , save training time, so filter test egs now. # config.data.min_input_len = 0.0 # second @@ -264,9 +265,11 @@ class U2Trainer(Trainer): # config.data.max_output_len = float('inf') # tokens # config.data.min_output_input_ratio = 0.00 # config.data.max_output_input_ratio = float('inf') + test_dataset = ManifestDataset.from_config(config) # return text ord id config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" self.test_loader = DataLoader( test_dataset, batch_size=config.decoding.batch_size, diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index ac817a192..ab1e91652 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -75,8 +75,8 @@ class SpeechCollator(): """ assert 'augmentation_config' in config.collator assert 'keep_transcription_text' in config.collator - assert 'mean_std_filepath' in config.data - assert 'vocab_filepath' in config.data + assert 'mean_std_filepath' in config.collator + assert 'vocab_filepath' in config.collator assert 'specgram_type' in config.collator assert 'n_fft' in config.collator assert config.collator @@ -94,9 +94,9 @@ class SpeechCollator(): speech_collator = cls( aug_file=aug_file, random_seed=0, - mean_std_filepath=config.data.mean_std_filepath, + mean_std_filepath=config.collator.mean_std_filepath, unit_type=config.collator.unit_type, - vocab_filepath=config.data.vocab_filepath, + vocab_filepath=config.collator.vocab_filepath, spm_model_prefix=config.collator.spm_model_prefix, specgram_type=config.collator.specgram_type, feat_dim=config.collator.feat_dim, @@ -129,11 +129,31 @@ class SpeechCollator(): target_dB=-20, dither=1.0, keep_transcription_text=True): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. + """SpeechCollator Collator - if ``keep_transcription_text`` is False, text is token ids else is raw string. + Args: + unit_type(str): token unit type, e.g. char, word, spm + vocab_filepath (str): vocab file path. + mean_std_filepath (str): mean and std file path, which suffix is *.npy + spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. + augmentation_config (str, optional): augmentation json str. Defaults to '{}'. + stride_ms (float, optional): stride size in ms. Defaults to 10.0. + window_ms (float, optional): window size in ms. Defaults to 20.0. + n_fft (int, optional): fft points for rfft. Defaults to None. + max_freq (int, optional): max cut freq. Defaults to None. + target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000. + specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. + feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None. + delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False. + use_dB_normalization (bool, optional): do dB normalization. Defaults to True. + target_dB (int, optional): target dB. Defaults to -20. + random_seed (int, optional): for random generator. Defaults to 0. 
+ keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. + if ``keep_transcription_text`` is False, text is token ids else is raw string. + + Do augmentations + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch. """ self._keep_transcription_text = keep_transcription_text diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 24d8486a8..70383b4da 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -40,15 +40,7 @@ class ManifestDataset(Dataset): def params(cls, config: Optional[CfgNode]=None) -> CfgNode: default = CfgNode( dict( - train_manifest="", - dev_manifest="", - test_manifest="", manifest="", - unit_type="char", - vocab_filepath="", - spm_model_prefix="", - mean_std_filepath="", - augmentation_config="", max_input_len=27.0, min_input_len=0.0, max_output_len=float('inf'), @@ -73,25 +65,10 @@ class ManifestDataset(Dataset): """ assert 'manifest' in config.data assert config.data.manifest - assert 'keep_transcription_text' in config.collator - - if isinstance(config.data.augmentation_config, (str, bytes)): - if config.data.augmentation_config: - aug_file = io.open( - config.data.augmentation_config, mode='r', encoding='utf8') - else: - aug_file = io.StringIO(initial_value='{}', newline='') - else: - aug_file = config.data.augmentation_config - assert isinstance(aug_file, io.StringIO) + dataset = cls( manifest_path=config.data.manifest, - unit_type=config.data.unit_type, - vocab_filepath=config.data.vocab_filepath, - mean_std_filepath=config.data.mean_std_filepath, - spm_model_prefix=config.data.spm_model_prefix, - augmentation_config=aug_file.read(), max_input_len=config.data.max_input_len, min_input_len=config.data.min_input_len, max_output_len=config.data.max_output_len, @@ -101,23 +78,8 @@ class ManifestDataset(Dataset): ) return dataset - - def _read_vocab(self, vocab_filepath): - """Load vocabulary from file.""" - vocab_lines = [] - with open(vocab_filepath, 'r', encoding='utf-8') as file: - vocab_lines.extend(file.readlines()) - vocab_list = [line[:-1] for line in vocab_lines] - return vocab_list - - def __init__(self, manifest_path, - unit_type, - vocab_filepath, - mean_std_filepath, - spm_model_prefix=None, - augmentation_config='{}', max_input_len=float('inf'), min_input_len=0.0, max_output_len=float('inf'), @@ -128,34 +90,16 @@ class ManifestDataset(Dataset): Args: manifest_path (str): manifest josn file path - unit_type(str): token unit type, e.g. char, word, spm - vocab_filepath (str): vocab file path. - mean_std_filepath (str): mean and std file path, which suffix is *.npy - spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. - augmentation_config (str, optional): augmentation json str. Defaults to '{}'. max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0. min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0. max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0. min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05. 
- stride_ms (float, optional): stride size in ms. Defaults to 10.0. - window_ms (float, optional): window size in ms. Defaults to 20.0. - n_fft (int, optional): fft points for rfft. Defaults to None. - max_freq (int, optional): max cut freq. Defaults to None. - target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000. - specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. - feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None. - delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False. - use_dB_normalization (bool, optional): do dB normalization. Defaults to True. - target_dB (int, optional): target dB. Defaults to -20. - random_seed (int, optional): for random generator. Defaults to 0. - keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. + """ super().__init__() - # self._rng = np.random.RandomState(random_seed) - # read manifest self._manifest = read_manifest( manifest_path=manifest_path, @@ -167,51 +111,6 @@ class ManifestDataset(Dataset): min_output_input_ratio=min_output_input_ratio) self._manifest.sort(key=lambda x: x["feat_shape"][0]) - # self._vocab_list = self._read_vocab(vocab_filepath) - - - # @property - # def manifest(self): - # return self._manifest - - # @property - # def vocab_size(self): - # """Return the vocabulary size. - - # Returns: - # int: Vocabulary size. - # """ - # return len(self._vocab_list) - - # @property - # def vocab_list(self): - # """Return the vocabulary in list. - - # Returns: - # List[str]: - # """ - # return self._vocab_list - - # @property - # def vocab_dict(self): - # """Return the vocabulary in dict. - - # Returns: - # Dict[str, int]: - # """ - # vocab_dict = dict( - # [(token, idx) for (idx, token) in enumerate(self._vocab_list)]) - # return vocab_dict - - # @property - # def feature_size(self): - # """Return the audio feature size. - - # Returns: - # int: audio feature size. 
- # """ - # return self._manifest[0]["feat_shape"][-1] - def __len__(self): return len(self._manifest) diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index b880f8587..116c91927 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -3,17 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - vocab_filepath: data/vocab.txt - unit_type: 'char' - spm_model_prefix: '' - augmentation_config: conf/augmentation.json - batch_size: 64 min_input_len: 0.5 max_input_len: 20.0 # second min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'char' + spm_model_prefix: '' + augmentation_config: conf/augmentation.json + batch_size: 64 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -32,7 +35,6 @@ data: shuffle_method: batch_shuffle num_workers: 2 - # network architecture model: cmvn_file: "data/mean_std.json" diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 434cf264c..6737d1b75 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -2,10 +2,7 @@ data: train_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny - test_manifest: data/manifest.tiny - mean_std_filepath: data/mean_std.json - unit_type: char - vocab_filepath: data/vocab.txt + test_manifest: data/manifest.tiny min_input_len: 0.0 max_input_len: 27.0 min_output_len: 0.0 @@ -15,6 +12,9 @@ data: collator: + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt augmentation_config: conf/augmentation.json random_seed: 0 spm_model_prefix: @@ -43,7 +43,7 @@ model: share_rnn_weights: True training: - n_epoch: 23 + n_epoch: 24 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 5e28e4e87..250995faa 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -3,26 +3,20 @@ data: train_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny test_manifest: data/manifest.tiny - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_200' - mean_std_filepath: "" - batch_size: 4 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 - raw_wav: True # use raw_wav or kaldi feature - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 0 #2 - + collator: + vocab_filepath: data/vocab.txt + mean_std_filepath: "" augmentation_config: conf/augmentation.json random_seed: 0 - spm_model_prefix: + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_200' specgram_type: fbank feat_dim: 80 delta_delta: False @@ -35,6 +29,12 @@ collator: target_dB: -20 dither: 1.0 keep_transcription_text: False + batch_size: 4 + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 #2 + raw_wav: True # use raw_wav or kaldi feature + # network architecture model: From 089a8ed602721acf43c676b37249987ebd8bfa3b Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 18 Jun 2021 09:47:53 +0000 Subject: [PATCH 08/15] fix deepspeech2/model.py and deepspeech2/config.py --- deepspeech/exps/deepspeech2/config.py | 76 
++++----------------------- deepspeech/exps/deepspeech2/model.py | 39 ++++++++++++++ 2 files changed, 50 insertions(+), 65 deletions(-) diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 050a50b00..7d2250fc7 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -11,80 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from yacs.config import CfgNode as CN +from yacs.config import CfgNode from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.collator import SpeechCollator +from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer +from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester -_C = CN() -_C.data = CN( - dict( - train_manifest="", - dev_manifest="", - test_manifest="", - max_duration=float('inf'), - min_duration=0.0, - )) -_C.collator =CN( - dict( - unit_type="char", - vocab_filepath="", - spm_model_prefix="", - mean_std_filepath="", - augmentation_config="", - random_seed=0, - specgram_type='linear', # 'linear', 'mfcc', 'fbank' - feat_dim=0, # 'mfcc', 'fbank' - delta_delta=False, # 'mfcc', 'fbank' - stride_ms=10.0, # ms - window_ms=20.0, # ms - n_fft=None, # fft points - max_freq=None, # None for samplerate/2 - target_sample_rate=16000, # target sample rate - use_dB_normalization=True, - target_dB=-20, - dither=1.0, # feature dither - keep_transcription_text=False, - batch_size=32, # batch size - num_workers=0, # data loader workers - sortagrad=False, # sorted in first epoch when True - shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' - )) +_C = CfgNode() -_C.model = CN( - dict( - num_conv_layers=2, #Number of stacking convolution layers. - num_rnn_layers=3, #Number of stacking RNN layers. - rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) +_C.data = ManifestDataset.params() +_C.collator = SpeechCollator.params() -DeepSpeech2Model.params(_C.model) +_C.model = DeepSpeech2Model.params() -_C.training = CN( - dict( - lr=5e-4, # learning rate - lr_decay=1.0, # learning rate decay - weight_decay=1e-6, # the coeff of weight decay - global_grad_clip=5.0, # the global norm clip - n_epoch=50, # train epochs - )) +_C.training = DeepSpeech2Trainer.params() -_C.decoding = CN( - dict( - alpha=2.5, # Coef of LM for beam search. - beta=0.3, # Coef of WC for beam search. - cutoff_prob=1.0, # Cutoff probability for pruning. - cutoff_top_n=40, # Cutoff number for pruning. - lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' - num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. 
- batch_size=128, # decoding batch size - )) +_C.decoding = DeepSpeech2Tester.params() def get_cfg_defaults(): diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 1eefc871b..c11d1e259 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -34,10 +34,28 @@ from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Log +from typing import Optional +from yacs.config import CfgNode logger = Log(__name__).getlog() class DeepSpeech2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args) @@ -184,6 +202,27 @@ class DeepSpeech2Trainer(Trainer): class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args) From 3a743f3717f692ff9cdbbcb24244fbc8ae5ce93b Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 18 Jun 2021 10:09:35 +0000 Subject: [PATCH 09/15] fix pre-commit --- deepspeech/exps/deepspeech2/config.py | 9 +-- deepspeech/exps/deepspeech2/model.py | 58 +++++++------- deepspeech/exps/u2/config.py | 2 +- deepspeech/exps/u2/model.py | 19 +++-- deepspeech/io/collator.py | 108 +++++++++++++------------- deepspeech/io/dataset.py | 16 +--- deepspeech/models/u2.py | 1 - 7 files changed, 108 insertions(+), 105 deletions(-) diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 7d2250fc7..2f0f5c24b 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -13,12 +13,11 @@ # limitations under the License. 
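# Illustration only, not applied by this patch: with params() supplying per-component
# defaults and get_cfg_defaults() composing them, a run merges an experiment YAML on
# top through the yacs API; the YAML path below is just an example.
from deepspeech.exps.deepspeech2.config import get_cfg_defaults

config = get_cfg_defaults()
config.merge_from_file("examples/tiny/s0/conf/deepspeech2.yaml")
config.freeze()
print(config.collator.batch_size)  # loader knobs now live under `collator`, not `data`
print(config.training.n_epoch)     # default from DeepSpeech2Trainer.params(), overridable in YAML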
from yacs.config import CfgNode -from deepspeech.models.deepspeech2 import DeepSpeech2Model -from deepspeech.io.dataset import ManifestDataset -from deepspeech.io.collator import SpeechCollator -from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester - +from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.models.deepspeech2 import DeepSpeech2Model _C = CfgNode() diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index c11d1e259..deb8752b7 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -15,11 +15,13 @@ import time from collections import defaultdict from pathlib import Path +from typing import Optional import numpy as np import paddle from paddle import distributed as dist from paddle.io import DataLoader +from yacs.config import CfgNode from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset @@ -33,9 +35,6 @@ from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Log - -from typing import Optional -from yacs.config import CfgNode logger = Log(__name__).getlog() @@ -44,13 +43,13 @@ class DeepSpeech2Trainer(Trainer): def params(cls, config: Optional[CfgNode]=None) -> CfgNode: # training config default = CfgNode( - dict( - lr=5e-4, # learning rate - lr_decay=1.0, # learning rate decay - weight_decay=1e-6, # the coeff of weight decay - global_grad_clip=5.0, # the global norm clip - n_epoch=50, # train epochs - )) + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) if config is not None: config.merge_from_other_cfg(default) @@ -184,7 +183,6 @@ class DeepSpeech2Trainer(Trainer): collate_fn_train = SpeechCollator.from_config(config) - config.collator.augmentation_config = "" collate_fn_dev = SpeechCollator.from_config(config) self.train_loader = DataLoader( @@ -206,18 +204,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def params(cls, config: Optional[CfgNode]=None) -> CfgNode: # testing config default = CfgNode( - dict( - alpha=2.5, # Coef of LM for beam search. - beta=0.3, # Coef of WC for beam search. - cutoff_prob=1.0, # Cutoff probability for pruning. - cutoff_top_n=40, # Cutoff number for pruning. - lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' - num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. - batch_size=128, # decoding batch size - )) + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. 
+ beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) if config is not None: config.merge_from_other_cfg(default) @@ -235,7 +233,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -257,7 +261,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref @@ -287,7 +292,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout) + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) errors_sum += metrics['errors_sum'] len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py index d8735453c..4ec7bd190 100644 --- a/deepspeech/exps/u2/config.py +++ b/deepspeech/exps/u2/config.py @@ -15,9 +15,9 @@ from yacs.config import CfgNode from deepspeech.exps.u2.model import U2Tester from deepspeech.exps.u2.model import U2Trainer +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.u2 import U2Model -from deepspeech.io.collator import SpeechCollator _C = CfgNode() diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 836afa361..055518755 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -78,7 +78,8 @@ class U2Trainer(Trainer): start = time.time() utt, audio, audio_len, text, text_len = batch_data - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -121,7 +122,8 @@ class U2Trainer(Trainer): total_loss = 0.0 for i, batch in enumerate(self.valid_loader): utt, audio, audio_len, text, text_len = batch - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) if paddle.isfinite(loss): num_utts = batch[1].shape[0] num_seen_utts += num_utts @@ -221,7 +223,7 @@ class U2Trainer(Trainer): dev_dataset = ManifestDataset.from_config(config) collate_fn_train = SpeechCollator.from_config(config) - + config.collator.augmentation_config = "" collate_fn_dev = SpeechCollator.from_config(config) @@ -372,7 +374,13 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = 
error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -399,7 +407,8 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index ab1e91652..ecf7024c1 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -11,21 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import io +import time +from collections import namedtuple +from typing import Optional + import numpy as np +from yacs.config import CfgNode -from deepspeech.frontend.utility import IGNORE_ID -from deepspeech.io.utility import pad_sequence -from deepspeech.utils.log import Log from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer from deepspeech.frontend.normalizer import FeatureNormalizer from deepspeech.frontend.speech import SpeechSegment -import io -import time -from yacs.config import CfgNode -from typing import Optional - -from collections import namedtuple +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.io.utility import pad_sequence +from deepspeech.utils.log import Log __all__ = ["SpeechCollator"] @@ -34,6 +34,7 @@ logger = Log(__name__).getlog() # namedtupe need global for pickle. TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + class SpeechCollator(): @classmethod def params(cls, config: Optional[CfgNode]=None) -> CfgNode: @@ -56,8 +57,7 @@ class SpeechCollator(): use_dB_normalization=True, target_dB=-20, dither=1.0, # feature dither - keep_transcription_text=False - )) + keep_transcription_text=False)) if config is not None: config.merge_from_other_cfg(default) @@ -84,7 +84,9 @@ class SpeechCollator(): if isinstance(config.collator.augmentation_config, (str, bytes)): if config.collator.augmentation_config: aug_file = io.open( - config.collator.augmentation_config, mode='r', encoding='utf8') + config.collator.augmentation_config, + mode='r', + encoding='utf8') else: aug_file = io.StringIO(initial_value='{}', newline='') else: @@ -92,43 +94,46 @@ class SpeechCollator(): assert isinstance(aug_file, io.StringIO) speech_collator = cls( - aug_file=aug_file, - random_seed=0, - mean_std_filepath=config.collator.mean_std_filepath, - unit_type=config.collator.unit_type, - vocab_filepath=config.collator.vocab_filepath, - spm_model_prefix=config.collator.spm_model_prefix, - specgram_type=config.collator.specgram_type, - feat_dim=config.collator.feat_dim, - delta_delta=config.collator.delta_delta, - stride_ms=config.collator.stride_ms, - window_ms=config.collator.window_ms, - n_fft=config.collator.n_fft, - max_freq=config.collator.max_freq, - target_sample_rate=config.collator.target_sample_rate, - use_dB_normalization=config.collator.use_dB_normalization, - target_dB=config.collator.target_dB, - dither=config.collator.dither, - keep_transcription_text=config.collator.keep_transcription_text - ) + aug_file=aug_file, + random_seed=0, + mean_std_filepath=config.collator.mean_std_filepath, + unit_type=config.collator.unit_type, + 
vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix, + specgram_type=config.collator.specgram_type, + feat_dim=config.collator.feat_dim, + delta_delta=config.collator.delta_delta, + stride_ms=config.collator.stride_ms, + window_ms=config.collator.window_ms, + n_fft=config.collator.n_fft, + max_freq=config.collator.max_freq, + target_sample_rate=config.collator.target_sample_rate, + use_dB_normalization=config.collator.use_dB_normalization, + target_dB=config.collator.target_dB, + dither=config.collator.dither, + keep_transcription_text=config.collator.keep_transcription_text) return speech_collator - def __init__(self, aug_file, mean_std_filepath, - vocab_filepath, spm_model_prefix, - random_seed=0, - unit_type="char", - specgram_type='linear', # 'linear', 'mfcc', 'fbank' - feat_dim=0, # 'mfcc', 'fbank' - delta_delta=False, # 'mfcc', 'fbank' - stride_ms=10.0, # ms - window_ms=20.0, # ms - n_fft=None, # fft points - max_freq=None, # None for samplerate/2 - target_sample_rate=16000, # target sample rate - use_dB_normalization=True, - target_dB=-20, - dither=1.0, - keep_transcription_text=True): + def __init__( + self, + aug_file, + mean_std_filepath, + vocab_filepath, + spm_model_prefix, + random_seed=0, + unit_type="char", + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delta_delta=False, # 'mfcc', 'fbank' + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + dither=1.0, + keep_transcription_text=True): """SpeechCollator Collator Args: @@ -159,9 +164,8 @@ class SpeechCollator(): self._local_data = TarLocalData(tar2info={}, tar2object={}) self._augmentation_pipeline = AugmentationPipeline( - augmentation_config=aug_file.read(), - random_seed=random_seed) - + augmentation_config=aug_file.read(), random_seed=random_seed) + self._normalizer = FeatureNormalizer( mean_std_filepath) if mean_std_filepath else None @@ -290,8 +294,6 @@ class SpeechCollator(): text_lens = np.array(text_lens).astype(np.int64) return utts, padded_audios, audio_lens, padded_texts, text_lens - - @property def manifest(self): return self._manifest @@ -318,4 +320,4 @@ class SpeechCollator(): @property def stride_ms(self): - return self._speech_featurizer.stride_ms \ No newline at end of file + return self._speech_featurizer.stride_ms diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 70383b4da..92c60f35c 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -12,19 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
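# Illustration only (hypothetical helper): the zero-padding SpeechCollator.__call__
# performs for each batch, reduced to plain numpy. The real collator goes through
# deepspeech.io.utility.pad_sequence and also returns utterance ids and text lengths.
import numpy as np

def pad_features(feats):
    """Pad a list of [T_i, D] float32 spectrograms into one [B, T_max, D] array."""
    lens = np.array([f.shape[0] for f in feats], dtype=np.int64)
    padded = np.zeros((len(feats), int(lens.max()), feats[0].shape[1]), dtype=np.float32)
    for i, f in enumerate(feats):
        padded[i, :f.shape[0], :] = f  # copy the utterance, leave the tail zero-padded
    return padded, lens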
import io -import tarfile -import time -from collections import namedtuple from typing import Optional -import numpy as np from paddle.io import Dataset from yacs.config import CfgNode -from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline -from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer -from deepspeech.frontend.normalizer import FeatureNormalizer -from deepspeech.frontend.speech import SpeechSegment from deepspeech.frontend.utility import read_manifest from deepspeech.utils.log import Log @@ -46,8 +38,7 @@ class ManifestDataset(Dataset): max_output_len=float('inf'), min_output_len=0.0, max_output_input_ratio=float('inf'), - min_output_input_ratio=0.0, - )) + min_output_input_ratio=0.0, )) if config is not None: config.merge_from_other_cfg(default) @@ -66,7 +57,6 @@ class ManifestDataset(Dataset): assert 'manifest' in config.data assert config.data.manifest - dataset = cls( manifest_path=config.data.manifest, max_input_len=config.data.max_input_len, @@ -74,8 +64,7 @@ class ManifestDataset(Dataset): max_output_len=config.data.max_output_len, min_output_len=config.data.min_output_len, max_output_input_ratio=config.data.max_output_input_ratio, - min_output_input_ratio=config.data.min_output_input_ratio, - ) + min_output_input_ratio=config.data.min_output_input_ratio, ) return dataset def __init__(self, @@ -111,7 +100,6 @@ class ManifestDataset(Dataset): min_output_input_ratio=min_output_input_ratio) self._manifest.sort(key=lambda x: x["feat_shape"][0]) - def __len__(self): return len(self._manifest) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index bcfddaef0..238e2d35c 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -905,7 +905,6 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): super().__init__(configs) - def forward(self, feats, feats_lengths, From 3652b87f33877d4b64b75398f9f99c34b1e5b02e Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 18 Jun 2021 10:11:17 +0000 Subject: [PATCH 10/15] fix --- deepspeech/io/collator.py | 1 - deepspeech/io/dataset.py | 1 - 2 files changed, 2 deletions(-) diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index ecf7024c1..1061f97cf 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -import time from collections import namedtuple from typing import Optional diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 92c60f35c..3fc4e9887 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
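# Illustration only (simplified, hypothetical helper): what the ManifestDataset
# construction above amounts to - read_manifest drops utterances outside the
# configured length limits (it also applies the output-length and output/input-ratio
# bounds), then the dataset sorts by feat_shape[0] (seconds for raw wav, frames for
# precomputed features), which the sortagrad samplers rely on.
def filter_and_sort(entries, min_input_len=0.0, max_input_len=27.0):
    kept = [e for e in entries
            if min_input_len <= e["feat_shape"][0] <= max_input_len]
    kept.sort(key=lambda e: e["feat_shape"][0])
    return kept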
-import io from typing import Optional from paddle.io import Dataset From 8c1bf1a730de9bd6a2a0d8393fd99be4bb8b9657 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 21 Jun 2021 03:10:14 +0000 Subject: [PATCH 11/15] fix ds2 conf for new data pipeline --- examples/aishell/s0/conf/deepspeech2.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 54ce240e7..8cc4c4c9c 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -3,8 +3,6 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - mean_std_filepath: data/mean_std.json - vocab_filepath: data/vocab.txt min_input_len: 0.0 max_input_len: 27.0 # second min_output_len: 0.0 @@ -14,6 +12,9 @@ data: collator: + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt augmentation_config: conf/augmentation.json random_seed: 0 spm_model_prefix: From 5a3a9e1f5055260f966d24680d5bb2e83f1d5b54 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 22 Jun 2021 02:58:58 +0000 Subject: [PATCH 12/15] fix chunk default config; tarball ckpt prfix dir; --- examples/aishell/s1/README.md | 10 ++++++++++ examples/aishell/s1/conf/chunk_conformer.yaml | 2 +- utils/tarball.sh | 3 ++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/aishell/s1/README.md b/examples/aishell/s1/README.md index 2048c4d58..c306f8aa1 100644 --- a/examples/aishell/s1/README.md +++ b/examples/aishell/s1/README.md @@ -9,6 +9,16 @@ | conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 | | conformer | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 | +## Chunk Conformer + +| Model | Config | Augmentation| Test set | Decode method | Chunk | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16 | - | 0.061939 | +| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16 | - | 0.070806 | +| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16 | - | 0.070739 | +| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16 | - | 0.059400 | + + ## Transformer | Model | Config | Augmentation| Test set | Decode method | Loss | WER | diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 904624c3c..e626e1064 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -78,7 +78,7 @@ model: training: - n_epoch: 180 + n_epoch: 240 accum_grad: 4 global_grad_clip: 5.0 optim: adam diff --git a/utils/tarball.sh b/utils/tarball.sh index 100b4719e..224b740cd 100755 --- a/utils/tarball.sh +++ b/utils/tarball.sh @@ -18,7 +18,8 @@ function clean() { } trap clean EXIT -cp ${ckpt_prefix}.* ${output} +# ckpt_prfix.{json,...} and ckpt_prfix dir +cp -r ${ckpt_prefix}* ${output} cp ${model_config} ${mean_std} ${vocab} ${output} tar zcvf release.tar.gz ${output} From 68149cb9a7d39c14e95ada2979a4b7200eaf4902 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 22 Jun 2021 03:25:26 +0000 Subject: [PATCH 13/15] fix config for new datapipeline --- examples/aishell/s1/README.md | 2 +- examples/aishell/s1/conf/chunk_conformer.yaml | 15 +++++++++------ 2 files changed, 10 
insertions(+), 7 deletions(-) diff --git a/examples/aishell/s1/README.md b/examples/aishell/s1/README.md index c306f8aa1..601b0a8d0 100644 --- a/examples/aishell/s1/README.md +++ b/examples/aishell/s1/README.md @@ -12,7 +12,7 @@ ## Chunk Conformer | Model | Config | Augmentation| Test set | Decode method | Chunk | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | +| --- | --- | --- | --- | --- | --- | --- | --- | | conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16 | - | 0.061939 | | conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16 | - | 0.070806 | | conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16 | - | 0.070739 | diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index e626e1064..0e5b8699f 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -3,17 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - vocab_filepath: data/vocab.txt - unit_type: 'char' - spm_model_prefix: '' - augmentation_config: conf/augmentation.json - batch_size: 32 min_input_len: 0.5 max_input_len: 20.0 # second min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'char' + spm_model_prefix: '' + augmentation_config: conf/augmentation.json + batch_size: 32 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -30,7 +33,7 @@ data: keep_transcription_text: False sortagrad: True shuffle_method: batch_shuffle - num_workers: 0 + num_workers: 2 # network architecture From 1b84f21ccfda2794e323a69a163411ab15c17288 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 22 Jun 2021 06:27:19 +0000 Subject: [PATCH 14/15] fix miss match --- utils/tarball.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/tarball.sh b/utils/tarball.sh index 224b740cd..a4611c75b 100755 --- a/utils/tarball.sh +++ b/utils/tarball.sh @@ -18,8 +18,11 @@ function clean() { } trap clean EXIT -# ckpt_prfix.{json,...} and ckpt_prfix dir -cp -r ${ckpt_prefix}* ${output} +# ckpt_prfix dir +cp -r ${ckpt_prefix} ${output} +# ckpt_prfix.{json,...} +cp ${ckpt_prefix}.* ${output} +# model config, mean std, vocab cp ${model_config} ${mean_std} ${vocab} ${output} tar zcvf release.tar.gz ${output} From 3c6eea077b2b077b9c3f5cc7baf339c545053d35 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 22 Jun 2021 07:27:43 +0000 Subject: [PATCH 15/15] cp dir when it exits --- utils/tarball.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/utils/tarball.sh b/utils/tarball.sh index a4611c75b..5f7c21a34 100755 --- a/utils/tarball.sh +++ b/utils/tarball.sh @@ -19,7 +19,9 @@ function clean() { trap clean EXIT # ckpt_prfix dir -cp -r ${ckpt_prefix} ${output} +if [ -d ${ckpt_prefix} ];then + cp -r ${ckpt_prefix} ${output} +fi # ckpt_prfix.{json,...} cp ${ckpt_prefix}.* ${output} # model config, mean std, vocab
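For clarity, the packaging flow that utils/tarball.sh ends up with after the last two commits can be restated in Python. The function below is only an illustration: the paths and the release filename mirror the script, but nothing here is part of the patch, and it assumes Python 3.8+ for dirs_exist_ok.

import shutil
import tarfile
from pathlib import Path

def package_release(ckpt_prefix, model_config, mean_std, vocab, output):
    out = Path(output)
    out.mkdir(parents=True, exist_ok=True)
    ckpt = Path(ckpt_prefix)
    if ckpt.is_dir():  # cp -r ${ckpt_prefix} ${output}, only when the dir exists
        shutil.copytree(ckpt, out / ckpt.name, dirs_exist_ok=True)
    for f in ckpt.parent.glob(ckpt.name + ".*"):  # ckpt_prefix.{json,...}
        shutil.copy(f, out)
    for f in (model_config, mean_std, vocab):  # model config, mean/std stats, vocab
        shutil.copy(f, out)
    with tarfile.open("release.tar.gz", "w:gz") as tar:  # tar zcvf release.tar.gz ${output}
        tar.add(str(out), arcname=out.name)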