From 220fe2038b34b7c0f35c7f6265b43cf3a2015f94 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 15 Jun 2021 14:45:30 +0000 Subject: [PATCH] test with dither, remove dc offset, preermphs --- third_party/paddle_audio/frontend/kaldi.py | 49 ++++++++++--------- .../paddle_audio/frontend/kaldi_test.py | 44 ++++++++++------- 2 files changed, 50 insertions(+), 43 deletions(-) diff --git a/third_party/paddle_audio/frontend/kaldi.py b/third_party/paddle_audio/frontend/kaldi.py index fa89a80e9..154148081 100644 --- a/third_party/paddle_audio/frontend/kaldi.py +++ b/third_party/paddle_audio/frontend/kaldi.py @@ -28,8 +28,8 @@ def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2 def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'): sf.write(wavpath, wav, sr, subtype=dtype) - - + + def frames(x: Tensor, num_samples: Tensor, sr: int, @@ -51,7 +51,7 @@ def frames(x: Tensor, stride_length : float Stride length in ms. clip : bool, optional - Whether to clip audio that does not fit into the last frame, by + Whether to clip audio that does not fit into the last frame, by default True Returns @@ -64,7 +64,7 @@ def frames(x: Tensor, assert stride_length <= win_length stride_length = int(stride_length * sr) win_length = int(win_length * sr) - + num_frames = (num_samples - win_length) // stride_length padding = (0, 0) if not clip: @@ -92,10 +92,11 @@ def dither(signal:Tensor, dither_value=1.0)->Tensor: Returns: Tensor: [B, T, D] """ - signal += paddle.normal(shape=[1, 1, signal.shape[-1]]) * dither_value + D = paddle.shape(signal)[-1] + signal += paddle.normal(shape=[1, 1, D]) * dither_value return signal - - + + def remove_dc_offset(signal:Tensor)->Tensor: """remove dc. @@ -105,7 +106,7 @@ def remove_dc_offset(signal:Tensor)->Tensor: Returns: Tensor: [B, T, D] """ - signal -= paddle.mean(signal, axis=-1) + signal -= paddle.mean(signal, axis=-1, keepdim=True) return signal def preemphasis(signal:Tensor, coeff=0.97)->Tensor: @@ -125,21 +126,21 @@ def preemphasis(signal:Tensor, coeff=0.97)->Tensor: class STFT(nn.Layer): - """A module for computing stft transformation in a differentiable way. - + """A module for computing stft transformation in a differentiable way. + http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/ - + Parameters - ------------ + ------------ n_fft : int Number of samples in a frame. - + sr: int Number of Samplilng rate. - + stride_length : float Number of samples shifted between adjacent frames. - + win_length : float Length of the window. @@ -151,7 +152,7 @@ class STFT(nn.Layer): sr: int, win_length: float, stride_length: float, - dither:float=1.0, + dither:float=0.0, preemph_coeff:float=0.97, remove_dc_offset:bool=True, window_type: str = 'povey', @@ -165,17 +166,17 @@ class STFT(nn.Layer): self.remove_dc_offset = remove_dc_offset self.window_type = window_type self.clip = clip - + self.n_fft = n_fft self.n_bin = 1 + n_fft // 2 w_real, w_imag, kernel_size = dft_matrix( self.n_fft, int(self.win_length * self.sr), self.n_bin ) - + # calculate window window = get_window(window_type, kernel_size) - + # (2 * n_bins, kernel_size) w = np.concatenate([w_real, w_imag], axis=0) w = w * window @@ -203,7 +204,7 @@ class STFT(nn.Layer): batch_size = paddle.shape(num_samples) F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip) if self.dither: - F = dither(F, dither) + F = dither(F, self.dither) if self.remove_dc_offset: F = remove_dc_offset(F) if self.preemph_coeff: @@ -215,7 +216,7 @@ class STFT(nn.Layer): def powspec(C:Tensor) -> Tensor: - """Compute the power spectrum. + """Compute the power spectrum. Args: C (Tensor): [B, T, C, 2] @@ -225,10 +226,10 @@ def powspec(C:Tensor) -> Tensor: """ real, imag = paddle.chunk(C, 2, axis=-1) return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1)) - - + + def magspec(C: Tensor, eps=1e-10) -> Tensor: - """Compute the magnitude spectrum. + """Compute the magnitude spectrum. Args: C (Tensor): [B, T, C, 2] diff --git a/third_party/paddle_audio/frontend/kaldi_test.py b/third_party/paddle_audio/frontend/kaldi_test.py index 7b9788ee8..2f729244a 100644 --- a/third_party/paddle_audio/frontend/kaldi_test.py +++ b/third_party/paddle_audio/frontend/kaldi_test.py @@ -397,20 +397,18 @@ class TestKaldiFE(unittest.TestCase): self.assertEqual(t_nframe.item(), fs.shape[0]) self.assertTrue(np.allclose(t_fs.numpy(), fs)) + def test_stft(self): sr, wav = kaldi.read(self.wavpath) wav = wav[:, 0] for wintype in ['', 'hamm', 'hann', 'povey']: - print(wintype) self.wintype=wintype _, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr, winlen=self.winlen, winstep=self.winstep, nfilt=self.nfilt, nfft=self.nfft, lowfreq=self.lowfreq, highfreq=self.highfreq, wintype=self.wintype) - print('py', stft_c_win.real) - print('py', stft_c_win.imag) t_wav = paddle.to_tensor([wav], dtype='float32') t_wavlen = paddle.to_tensor([len(wav)]) @@ -420,33 +418,26 @@ class TestKaldiFE(unittest.TestCase): t_stft = t_stft.astype(stft_c_win.real.dtype)[0] t_real = t_stft[:, :, 0] t_imag = t_stft[:, :, 1] - print('pd', t_real.numpy()) - print('pd', t_imag.numpy()) self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0]) self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1) - print(np.sum(t_real.numpy())) - print(np.sum(stft_c_win.real)) self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1)) self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1) - print(np.sum(t_imag.numpy())) - print(np.sum(stft_c_win.imag)) self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1)) + def test_magspec(self): sr, wav = kaldi.read(self.wavpath) wav = wav[:, 0] for wintype in ['', 'hamm', 'hann', 'povey']: - print(wintype) self.wintype=wintype stft_win, _, _, _ = stft_with_window(wav, samplerate=sr, winlen=self.winlen, winstep=self.winstep, nfilt=self.nfilt, nfft=self.nfft, lowfreq=self.lowfreq, highfreq=self.highfreq, wintype=self.wintype) - print('py', stft_win) t_wav = paddle.to_tensor([wav], dtype='float32') t_wavlen = paddle.to_tensor([len(wav)]) @@ -455,20 +446,39 @@ class TestKaldiFE(unittest.TestCase): t_stft, t_nframe = stft_class(t_wav, t_wavlen) t_stft = t_stft.astype(stft_win.dtype) t_spec = kaldi.magspec(t_stft)[0] - print('pd', t_spec.numpy()) self.assertEqual(t_nframe.item(), stft_win.shape[0]) self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1) - print(np.sum(t_spec.numpy())) - print(np.sum(stft_win)) self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1)) + + def test_magsepc_winprocess(self): + sr, wav = kaldi.read(self.wavpath) + wav = wav[:, 0] + fs, _= framesig(wav, self.winlen*sr, self.winstep*sr, + dither=0.0, preemph=0.97, remove_dc_offset=True, wintype='povey', stride_trick=True) + spec = magspec(fs, self.nfft) # nearly the same until this part + + t_wav = paddle.to_tensor([wav], dtype='float32') + t_wavlen = paddle.to_tensor([len(wav)]) + stft_class = kaldi.STFT( + self.nfft, sr, self.winlen, self.winstep, + window_type='povey', dither=0.0, preemph_coeff=0.97, remove_dc_offset=True, clip=False) + t_stft, t_nframe = stft_class(t_wav, t_wavlen) + t_stft = t_stft.astype(spec.dtype) + t_spec = kaldi.magspec(t_stft)[0] + + self.assertEqual(t_nframe.item(), fs.shape[0]) + + self.assertLess(np.sum(t_spec.numpy()) - np.sum(spec), 1) + self.assertTrue(np.allclose(t_spec.numpy(), spec, atol=1e-1)) + + def test_powspec(self): sr, wav = kaldi.read(self.wavpath) wav = wav[:, 0] for wintype in ['', 'hamm', 'hann', 'povey']: - print(wintype) self.wintype=wintype stft_win, _, _, _ = stft_with_window(wav, samplerate=sr, winlen=self.winlen, winstep=self.winstep, @@ -476,7 +486,6 @@ class TestKaldiFE(unittest.TestCase): lowfreq=self.lowfreq, highfreq=self.highfreq, wintype=self.wintype) stft_win = np.square(stft_win) - print('py', stft_win) t_wav = paddle.to_tensor([wav], dtype='float32') t_wavlen = paddle.to_tensor([len(wav)]) @@ -485,13 +494,10 @@ class TestKaldiFE(unittest.TestCase): t_stft, t_nframe = stft_class(t_wav, t_wavlen) t_stft = t_stft.astype(stft_win.dtype) t_spec = kaldi.powspec(t_stft)[0] - print('pd', t_spec.numpy()) self.assertEqual(t_nframe.item(), stft_win.shape[0]) self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4) - print(np.sum(t_spec.numpy())) - print(np.sum(stft_win)) self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))