From 855e4c379239462ab92f1538845831d7a0a9a4f3 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 15 Jun 2021 11:05:36 +0000 Subject: [PATCH] add common utils --- third_party/paddle_audio/frontend/common.py | 103 +++++++++++ third_party/paddle_audio/frontend/kaldi.py | 46 +---- .../paddle_audio/frontend/kaldi_test.py | 166 +++++++++++++++++- 3 files changed, 276 insertions(+), 39 deletions(-) create mode 100644 third_party/paddle_audio/frontend/common.py diff --git a/third_party/paddle_audio/frontend/common.py b/third_party/paddle_audio/frontend/common.py new file mode 100644 index 000000000..ad74c67d0 --- /dev/null +++ b/third_party/paddle_audio/frontend/common.py @@ -0,0 +1,103 @@ +import paddle +import numpy as np +from typing import Tuple + + +# https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/src/feat/feature-window.cc#L109 +def povey_window(frame_len:int) -> np.ndarray: + win = np.empty(frame_len) + a = 2 * np.pi / (frame_len -1) + for i in range(frame_len): + win[i] = (0.5 - 0.5 * np.cos(a * i) )**0.85 + return win + +def hann_window(frame_len:int) -> np.ndarray: + win = np.empty(frame_len) + a = 2 * np.pi / (frame_len -1) + for i in range(frame_len): + win[i] = 0.5 - 0.5 * np.cos(a * i) + return win + +def sine_window(frame_len:int) -> np.ndarray: + win = np.empty(frame_len) + a = 2 * np.pi / (frame_len -1) + for i in range(frame_len): + win[i] = np.sin(0.5 * a * i) + return win + +def hamm_window(frame_len:int) -> np.ndarray: + win = np.empty(frame_len) + a = 2 * np.pi / (frame_len -1) + for i in range(frame_len): + win[i] = 0.54 - 0.46 * np.cos(a * i) + return win + +def get_window(wintype:str, winlen:int) -> np.ndarray: + # calculate window + if not wintype or wintype == 'rectangular': + window = np.ones(winlen) + elif wintype == "hann": + window = hann_window(winlen) + elif wintype == "hamm": + window = hamm_window(winlen) + elif wintype == "povey": + window = povey_window(winlen) + else: + msg = f"{wintype} Not supported yet!" + raise ValueError(msg) + return window + + +def dft_matrix(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]: + # https://en.wikipedia.org/wiki/Discrete_Fourier_transform + # (n_bins, n_fft) complex + if n_bin is None: + n_bin = 1 + n_fft // 2 + if winlen is None: + winlen = n_bin + # https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49 + kernel_size = min(n_fft, winlen) + + n = np.arange(0, n_fft, 1.) + wsin = np.empty((n_bin, kernel_size)) #[Cout, kernel_size] + wcos = np.empty((n_bin, kernel_size)) #[Cout, kernel_size] + for k in range(n_bin): # Only half of the bins contain useful info + wsin[k,:] = -np.sin(2*np.pi*k*n/n_fft)[:kernel_size] + wcos[k,:] = np.cos(2*np.pi*k*n/n_fft)[:kernel_size] + w_real = wcos + w_imag = wsin + return w_real, w_imag, kernel_size + + +def dft_matrix_fast(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]: + # (n_bins, n_fft) complex + if n_bin is None: + n_bin = 1 + n_fft // 2 + if winlen is None: + winlen = n_bin + # https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49 + kernel_size = min(n_fft, winlen) + + # https://en.wikipedia.org/wiki/DFT_matrix + # https://ccrma.stanford.edu/~jos/st/Matrix_Formulation_DFT.html + weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size] + w_real = weight.real + w_imag = weight.imag + return w_real, w_imag, kernel_size + + +def hz2mel(hz): + """Convert a value in Hertz to Mels + + :param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise. + :returns: a value in Mels. If an array was passed in, an identical sized array is returned. + """ + return 1127 * np.log(1+hz/700.0) + +def mel2hz(mel): + """Convert a value in Mels to Hertz + + :param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise. + :returns: a value in Hertz. If an array was passed in, an identical sized array is returned. + """ + return 700 * (np.exp(mel/1127.0)-1) \ No newline at end of file diff --git a/third_party/paddle_audio/frontend/kaldi.py b/third_party/paddle_audio/frontend/kaldi.py index 999a8fbc4..8cb3f8c6c 100644 --- a/third_party/paddle_audio/frontend/kaldi.py +++ b/third_party/paddle_audio/frontend/kaldi.py @@ -6,6 +6,9 @@ from paddle import nn from paddle.nn import functional as F import soundfile as sf +from .common import get_window, dft_matrix + + def read(wavpath:str, sr:int = None, dtype='int16')->Tuple[int, np.ndarray]: wav, r_sr = sf.read(wavpath, dtype=dtype) if sr: @@ -65,13 +68,6 @@ def frames(x: Tensor, return frames, num_frames -def _povey_window(frame_len:int) -> np.ndarray: - win = np.empty(frame_len) - for i in range(frame_len): - win[i] = (0.5 - 0.5 * np.cos(2 * np.pi * i / (frame_len - 1)) )**0.85 - return win - - class STFT(nn.Layer): """A module for computing stft transformation in a differentiable way. @@ -110,38 +106,10 @@ class STFT(nn.Layer): self.n_fft = n_fft self.n_bin = 1 + n_fft // 2 - # https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49 - kernel_size = min(self.n_fft, self.win_length) + w_real, w_imag, kernel_size = dft_matrix(self.n_fft, self.win_length, self.n_bin) # calculate window - if not window_type: - window = np.ones(kernel_size) - elif window_type == "hann": - window = np.hanning(kernel_size) - elif window_type == "hamm": - window = np.hamming(kernel_size) - elif window_type == "povey": - window = _povey_window(kernel_size) - else: - msg = f"{window_type} Not supported yet!" - raise ValueError(msg) - - # https://en.wikipedia.org/wiki/Discrete_Fourier_transform - # (n_bins, n_fft) complex - n = np.arange(0, self.n_fft, 1.) - wsin = np.empty((self.n_bin, kernel_size)) #[Cout, kernel_size] - wcos = np.empty((self.n_bin, kernel_size)) #[Cout, kernel_size] - for k in range(self.n_bin): # Only half of the bins contain useful info - wsin[k,:] = -np.sin(2*np.pi*k*n/self.n_fft)[:kernel_size] - wcos[k,:] = np.cos(2*np.pi*k*n/self.n_fft)[:kernel_size] - w_real = wcos - w_imag = wsin - - # https://en.wikipedia.org/wiki/DFT_matrix - # https://ccrma.stanford.edu/~jos/st/Matrix_Formulation_DFT.html - # weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size] - # w_real = weight.real - # w_imag = weight.imag + window = get_window(window_type, kernel_size) # (2 * n_bins, kernel_size) w = np.concatenate([w_real, w_imag], axis=0) @@ -210,4 +178,6 @@ def magspec(C: Tensor, eps=1e-10) -> Tensor: Tensor: [B, T, C] """ pspec = powspec(C) - return paddle.sqrt(pspec + eps) \ No newline at end of file + return paddle.sqrt(pspec + eps) + + diff --git a/third_party/paddle_audio/frontend/kaldi_test.py b/third_party/paddle_audio/frontend/kaldi_test.py index 4e306c0f7..0e609aebf 100644 --- a/third_party/paddle_audio/frontend/kaldi_test.py +++ b/third_party/paddle_audio/frontend/kaldi_test.py @@ -9,8 +9,11 @@ import math import logging from pathlib import Path +from scipy.fftpack import dct + from third_party.paddle_audio.frontend import kaldi + def round_half_up(number): return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP)) @@ -111,6 +114,168 @@ def powspec(frames, NFFT): return numpy.square(magspec(frames, NFFT)) + +def mfcc(signal,samplerate=16000,winlen=0.025,winstep=0.01,numcep=13, + nfilt=23,nfft=512,lowfreq=20,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97, + ceplifter=22,useEnergy=True,wintype='povey'): + """Compute MFCC features from an audio signal. + + :param signal: the audio signal from which to compute features. Should be an N*1 array + :param samplerate: the samplerate of the signal we are working with. + :param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds) + :param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds) + :param numcep: the number of cepstrum to return, default 13 + :param nfilt: the number of filters in the filterbank, default 26. + :param nfft: the FFT size. Default is 512. + :param lowfreq: lowest band edge of mel filters. In Hz, default is 0. + :param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2 + :param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97. + :param ceplifter: apply a lifter to final cepstral coefficients. 0 is no lifter. Default is 22. + :param appendEnergy: if this is true, the zeroth cepstral coefficient is replaced with the log of the total frame energy. + :param winfunc: the analysis window to apply to each frame. By default no window is applied. You can use numpy window functions here e.g. winfunc=numpy.hamming + :returns: A numpy array of size (NUMFRAMES by numcep) containing features. Each row holds 1 feature vector. + """ + feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither,remove_dc_offset,preemph,wintype) + feat = numpy.log(feat) + feat = dct(feat, type=2, axis=1, norm='ortho')[:,:numcep] + feat = lifter(feat,ceplifter) + if useEnergy: feat[:,0] = numpy.log(energy) # replace first cepstral coefficient with log of frame energy + return feat + +def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01, + nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, + wintype='hamming'): + """Compute Mel-filterbank energy features from an audio signal. + + :param signal: the audio signal from which to compute features. Should be an N*1 array + :param samplerate: the samplerate of the signal we are working with. + :param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds) + :param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds) + :param nfilt: the number of filters in the filterbank, default 26. + :param nfft: the FFT size. Default is 512. + :param lowfreq: lowest band edge of mel filters. In Hz, default is 0. + :param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2 + :param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97. + :param winfunc: the analysis window to apply to each frame. By default no window is applied. You can use numpy window functions here e.g. winfunc=numpy.hamming + winfunc=lambda x:numpy.ones((x,)) + :returns: 2 values. The first is a numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector. The + second return value is the energy in each frame (total energy, unwindowed) + """ + highfreq= highfreq or samplerate/2 + frames,raw_frames = sigproc.framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype) + pspec = sigproc.powspec(frames,nfft) # nearly the same until this part + energy = numpy.sum(raw_frames**2,1) # this stores the raw energy in each frame + energy = numpy.where(energy == 0,numpy.finfo(float).eps,energy) # if energy is zero, we get problems with log + + fb = get_filterbanks(nfilt,nfft,samplerate,lowfreq,highfreq) + feat = numpy.dot(pspec,fb.T) # compute the filterbank energies + feat = numpy.where(feat == 0,numpy.finfo(float).eps,feat) # if feat is zero, we get problems with log + + return feat,energy + +def logfbank(signal,samplerate=16000,winlen=0.025,winstep=0.01, + nfilt=40,nfft=512,lowfreq=64,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,wintype='hamming'): + """Compute log Mel-filterbank energy features from an audio signal. + + :param signal: the audio signal from which to compute features. Should be an N*1 array + :param samplerate: the samplerate of the signal we are working with. + :param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds) + :param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds) + :param nfilt: the number of filters in the filterbank, default 26. + :param nfft: the FFT size. Default is 512. + :param lowfreq: lowest band edge of mel filters. In Hz, default is 0. + :param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2 + :param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97. + :returns: A numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector. + """ + feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither, remove_dc_offset,preemph,wintype) + return numpy.log(feat) + +def hz2mel(hz): + """Convert a value in Hertz to Mels + + :param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise. + :returns: a value in Mels. If an array was passed in, an identical sized array is returned. + """ + return 1127 * numpy.log(1+hz/700.0) + +def mel2hz(mel): + """Convert a value in Mels to Hertz + + :param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise. + :returns: a value in Hertz. If an array was passed in, an identical sized array is returned. + """ + return 700 * (numpy.exp(mel/1127.0)-1) + +def get_filterbanks(nfilt=26,nfft=512,samplerate=16000,lowfreq=0,highfreq=None): + """Compute a Mel-filterbank. The filters are stored in the rows, the columns correspond + to fft bins. The filters are returned as an array of size nfilt * (nfft/2 + 1) + + :param nfilt: the number of filters in the filterbank, default 20. + :param nfft: the FFT size. Default is 512. + :param samplerate: the samplerate of the signal we are working with. Affects mel spacing. + :param lowfreq: lowest band edge of mel filters, default 0 Hz + :param highfreq: highest band edge of mel filters, default samplerate/2 + :returns: A numpy array of size nfilt * (nfft/2 + 1) containing filterbank. Each row holds 1 filter. + """ + highfreq= highfreq or samplerate/2 + assert highfreq <= samplerate/2, "highfreq is greater than samplerate/2" + + # compute points evenly spaced in mels + lowmel = hz2mel(lowfreq) + highmel = hz2mel(highfreq) + + # check kaldi/src/feat/Mel-computations.h + fbank = numpy.zeros([nfilt,nfft//2+1]) + mel_freq_delta = (highmel-lowmel)/(nfilt+1) + for j in range(0,nfilt): + leftmel = lowmel+j*mel_freq_delta + centermel = lowmel+(j+1)*mel_freq_delta + rightmel = lowmel+(j+2)*mel_freq_delta + for i in range(0,nfft//2): + mel=hz2mel(i*samplerate/nfft) + if mel>leftmel and mel 0: + nframes,ncoeff = numpy.shape(cepstra) + n = numpy.arange(ncoeff) + lift = 1 + (L/2.)*numpy.sin(numpy.pi*n/L) + return lift*cepstra + else: + # values of L <= 0, do nothing + return cepstra + +def delta(feat, N): + """Compute delta features from a feature vector sequence. + + :param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector. + :param N: For each frame, calculate delta features based on preceding and following N frames + :returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector. + """ + if N < 1: + raise ValueError('N must be an integer >= 1') + NUMFRAMES = len(feat) + denominator = 2 * sum([i**2 for i in range(1, N+1)]) + delta_feat = numpy.empty_like(feat) + padded = numpy.pad(feat, ((N, N), (0, 0)), mode='edge') # padded version of feat + for t in range(NUMFRAMES): + delta_feat[t] = numpy.dot(numpy.arange(-N, N+1), padded[t : t+2*N+1]) / denominator # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1] + return delta_feat + +##### modify for test ###### + def framesig_without_dither_dc_preemphasize(sig, frame_len, frame_step, wintype='hamming', stride_trick=True): """Frame a signal into overlapping frames. @@ -228,7 +393,6 @@ class TestKaldiFE(unittest.TestCase): self.assertEqual(t_nframe.item(), fs.shape[0]) self.assertTrue(np.allclose(t_fs.numpy(), fs)) - def test_stft(self): sr, wav = kaldi.read(self.wavpath)