test with dither, remove dc offset, preermphs

pull/670/head
Hui Zhang 4 years ago
parent 42f93b2cb6
commit 220fe2038b

@ -28,8 +28,8 @@ def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2
def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'): def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'):
sf.write(wavpath, wav, sr, subtype=dtype) sf.write(wavpath, wav, sr, subtype=dtype)
def frames(x: Tensor, def frames(x: Tensor,
num_samples: Tensor, num_samples: Tensor,
sr: int, sr: int,
@ -51,7 +51,7 @@ def frames(x: Tensor,
stride_length : float stride_length : float
Stride length in ms. Stride length in ms.
clip : bool, optional clip : bool, optional
Whether to clip audio that does not fit into the last frame, by Whether to clip audio that does not fit into the last frame, by
default True default True
Returns Returns
@ -64,7 +64,7 @@ def frames(x: Tensor,
assert stride_length <= win_length assert stride_length <= win_length
stride_length = int(stride_length * sr) stride_length = int(stride_length * sr)
win_length = int(win_length * sr) win_length = int(win_length * sr)
num_frames = (num_samples - win_length) // stride_length num_frames = (num_samples - win_length) // stride_length
padding = (0, 0) padding = (0, 0)
if not clip: if not clip:
@ -92,10 +92,11 @@ def dither(signal:Tensor, dither_value=1.0)->Tensor:
Returns: Returns:
Tensor: [B, T, D] Tensor: [B, T, D]
""" """
signal += paddle.normal(shape=[1, 1, signal.shape[-1]]) * dither_value D = paddle.shape(signal)[-1]
signal += paddle.normal(shape=[1, 1, D]) * dither_value
return signal return signal
def remove_dc_offset(signal:Tensor)->Tensor: def remove_dc_offset(signal:Tensor)->Tensor:
"""remove dc. """remove dc.
@ -105,7 +106,7 @@ def remove_dc_offset(signal:Tensor)->Tensor:
Returns: Returns:
Tensor: [B, T, D] Tensor: [B, T, D]
""" """
signal -= paddle.mean(signal, axis=-1) signal -= paddle.mean(signal, axis=-1, keepdim=True)
return signal return signal
def preemphasis(signal:Tensor, coeff=0.97)->Tensor: def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
@ -125,21 +126,21 @@ def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
class STFT(nn.Layer): class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way. """A module for computing stft transformation in a differentiable way.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/ http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Parameters Parameters
------------ ------------
n_fft : int n_fft : int
Number of samples in a frame. Number of samples in a frame.
sr: int sr: int
Number of Samplilng rate. Number of Samplilng rate.
stride_length : float stride_length : float
Number of samples shifted between adjacent frames. Number of samples shifted between adjacent frames.
win_length : float win_length : float
Length of the window. Length of the window.
@ -151,7 +152,7 @@ class STFT(nn.Layer):
sr: int, sr: int,
win_length: float, win_length: float,
stride_length: float, stride_length: float,
dither:float=1.0, dither:float=0.0,
preemph_coeff:float=0.97, preemph_coeff:float=0.97,
remove_dc_offset:bool=True, remove_dc_offset:bool=True,
window_type: str = 'povey', window_type: str = 'povey',
@ -165,17 +166,17 @@ class STFT(nn.Layer):
self.remove_dc_offset = remove_dc_offset self.remove_dc_offset = remove_dc_offset
self.window_type = window_type self.window_type = window_type
self.clip = clip self.clip = clip
self.n_fft = n_fft self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2 self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix( w_real, w_imag, kernel_size = dft_matrix(
self.n_fft, int(self.win_length * self.sr), self.n_bin self.n_fft, int(self.win_length * self.sr), self.n_bin
) )
# calculate window # calculate window
window = get_window(window_type, kernel_size) window = get_window(window_type, kernel_size)
# (2 * n_bins, kernel_size) # (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0) w = np.concatenate([w_real, w_imag], axis=0)
w = w * window w = w * window
@ -203,7 +204,7 @@ class STFT(nn.Layer):
batch_size = paddle.shape(num_samples) batch_size = paddle.shape(num_samples)
F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip) F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if self.dither: if self.dither:
F = dither(F, dither) F = dither(F, self.dither)
if self.remove_dc_offset: if self.remove_dc_offset:
F = remove_dc_offset(F) F = remove_dc_offset(F)
if self.preemph_coeff: if self.preemph_coeff:
@ -215,7 +216,7 @@ class STFT(nn.Layer):
def powspec(C:Tensor) -> Tensor: def powspec(C:Tensor) -> Tensor:
"""Compute the power spectrum. """Compute the power spectrum.
Args: Args:
C (Tensor): [B, T, C, 2] C (Tensor): [B, T, C, 2]
@ -225,10 +226,10 @@ def powspec(C:Tensor) -> Tensor:
""" """
real, imag = paddle.chunk(C, 2, axis=-1) real, imag = paddle.chunk(C, 2, axis=-1)
return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1)) return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1))
def magspec(C: Tensor, eps=1e-10) -> Tensor: def magspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute the magnitude spectrum. """Compute the magnitude spectrum.
Args: Args:
C (Tensor): [B, T, C, 2] C (Tensor): [B, T, C, 2]

@ -397,20 +397,18 @@ class TestKaldiFE(unittest.TestCase):
self.assertEqual(t_nframe.item(), fs.shape[0]) self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertTrue(np.allclose(t_fs.numpy(), fs)) self.assertTrue(np.allclose(t_fs.numpy(), fs))
def test_stft(self): def test_stft(self):
sr, wav = kaldi.read(self.wavpath) sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0] wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']: for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype self.wintype=wintype
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr, _, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep, winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft, nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq, lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype) wintype=self.wintype)
print('py', stft_c_win.real)
print('py', stft_c_win.imag)
t_wav = paddle.to_tensor([wav], dtype='float32') t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)]) t_wavlen = paddle.to_tensor([len(wav)])
@ -420,33 +418,26 @@ class TestKaldiFE(unittest.TestCase):
t_stft = t_stft.astype(stft_c_win.real.dtype)[0] t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0] t_real = t_stft[:, :, 0]
t_imag = t_stft[:, :, 1] t_imag = t_stft[:, :, 1]
print('pd', t_real.numpy())
print('pd', t_imag.numpy())
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0]) self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1) self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
print(np.sum(t_real.numpy()))
print(np.sum(stft_c_win.real))
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1)) self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1) self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
print(np.sum(t_imag.numpy()))
print(np.sum(stft_c_win.imag))
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1)) self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self): def test_magspec(self):
sr, wav = kaldi.read(self.wavpath) sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0] wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']: for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr, stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep, winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft, nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq, lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype) wintype=self.wintype)
print('py', stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32') t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)]) t_wavlen = paddle.to_tensor([len(wav)])
@ -455,20 +446,39 @@ class TestKaldiFE(unittest.TestCase):
t_stft, t_nframe = stft_class(t_wav, t_wavlen) t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype) t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.magspec(t_stft)[0] t_spec = kaldi.magspec(t_stft)[0]
print('pd', t_spec.numpy())
self.assertEqual(t_nframe.item(), stft_win.shape[0]) self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1) self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1)
print(np.sum(t_spec.numpy()))
print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1)) self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1))
def test_magsepc_winprocess(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
fs, _= framesig(wav, self.winlen*sr, self.winstep*sr,
dither=0.0, preemph=0.97, remove_dc_offset=True, wintype='povey', stride_trick=True)
spec = magspec(fs, self.nfft) # nearly the same until this part
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(
self.nfft, sr, self.winlen, self.winstep,
window_type='povey', dither=0.0, preemph_coeff=0.97, remove_dc_offset=True, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(spec.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(spec), 1)
self.assertTrue(np.allclose(t_spec.numpy(), spec, atol=1e-1))
def test_powspec(self): def test_powspec(self):
sr, wav = kaldi.read(self.wavpath) sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0] wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']: for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr, stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep, winlen=self.winlen, winstep=self.winstep,
@ -476,7 +486,6 @@ class TestKaldiFE(unittest.TestCase):
lowfreq=self.lowfreq, highfreq=self.highfreq, lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype) wintype=self.wintype)
stft_win = np.square(stft_win) stft_win = np.square(stft_win)
print('py', stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32') t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)]) t_wavlen = paddle.to_tensor([len(wav)])
@ -485,13 +494,10 @@ class TestKaldiFE(unittest.TestCase):
t_stft, t_nframe = stft_class(t_wav, t_wavlen) t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype) t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.powspec(t_stft)[0] t_spec = kaldi.powspec(t_stft)[0]
print('pd', t_spec.numpy())
self.assertEqual(t_nframe.item(), stft_win.shape[0]) self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4) self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
print(np.sum(t_spec.numpy()))
print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2)) self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))

Loading…
Cancel
Save