test with dither, remove dc offset, preermphs

pull/670/head
Hui Zhang 4 years ago
parent 42f93b2cb6
commit 220fe2038b

@ -28,8 +28,8 @@ def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2
def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'):
sf.write(wavpath, wav, sr, subtype=dtype)
def frames(x: Tensor,
num_samples: Tensor,
sr: int,
@ -51,7 +51,7 @@ def frames(x: Tensor,
stride_length : float
Stride length in ms.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
Whether to clip audio that does not fit into the last frame, by
default True
Returns
@ -64,7 +64,7 @@ def frames(x: Tensor,
assert stride_length <= win_length
stride_length = int(stride_length * sr)
win_length = int(win_length * sr)
num_frames = (num_samples - win_length) // stride_length
padding = (0, 0)
if not clip:
@ -92,10 +92,11 @@ def dither(signal:Tensor, dither_value=1.0)->Tensor:
Returns:
Tensor: [B, T, D]
"""
signal += paddle.normal(shape=[1, 1, signal.shape[-1]]) * dither_value
D = paddle.shape(signal)[-1]
signal += paddle.normal(shape=[1, 1, D]) * dither_value
return signal
def remove_dc_offset(signal:Tensor)->Tensor:
"""remove dc.
@ -105,7 +106,7 @@ def remove_dc_offset(signal:Tensor)->Tensor:
Returns:
Tensor: [B, T, D]
"""
signal -= paddle.mean(signal, axis=-1)
signal -= paddle.mean(signal, axis=-1, keepdim=True)
return signal
def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
@ -125,21 +126,21 @@ def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
"""A module for computing stft transformation in a differentiable way.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Parameters
------------
------------
n_fft : int
Number of samples in a frame.
sr: int
Number of Samplilng rate.
stride_length : float
Number of samples shifted between adjacent frames.
win_length : float
Length of the window.
@ -151,7 +152,7 @@ class STFT(nn.Layer):
sr: int,
win_length: float,
stride_length: float,
dither:float=1.0,
dither:float=0.0,
preemph_coeff:float=0.97,
remove_dc_offset:bool=True,
window_type: str = 'povey',
@ -165,17 +166,17 @@ class STFT(nn.Layer):
self.remove_dc_offset = remove_dc_offset
self.window_type = window_type
self.clip = clip
self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix(
self.n_fft, int(self.win_length * self.sr), self.n_bin
)
# calculate window
window = get_window(window_type, kernel_size)
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
@ -203,7 +204,7 @@ class STFT(nn.Layer):
batch_size = paddle.shape(num_samples)
F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if self.dither:
F = dither(F, dither)
F = dither(F, self.dither)
if self.remove_dc_offset:
F = remove_dc_offset(F)
if self.preemph_coeff:
@ -215,7 +216,7 @@ class STFT(nn.Layer):
def powspec(C:Tensor) -> Tensor:
"""Compute the power spectrum.
"""Compute the power spectrum.
Args:
C (Tensor): [B, T, C, 2]
@ -225,10 +226,10 @@ def powspec(C:Tensor) -> Tensor:
"""
real, imag = paddle.chunk(C, 2, axis=-1)
return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1))
def magspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute the magnitude spectrum.
"""Compute the magnitude spectrum.
Args:
C (Tensor): [B, T, C, 2]

@ -397,20 +397,18 @@ class TestKaldiFE(unittest.TestCase):
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertTrue(np.allclose(t_fs.numpy(), fs))
def test_stft(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
print('py', stft_c_win.real)
print('py', stft_c_win.imag)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
@ -420,33 +418,26 @@ class TestKaldiFE(unittest.TestCase):
t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0]
t_imag = t_stft[:, :, 1]
print('pd', t_real.numpy())
print('pd', t_imag.numpy())
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
print(np.sum(t_real.numpy()))
print(np.sum(stft_c_win.real))
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
print(np.sum(t_imag.numpy()))
print(np.sum(stft_c_win.imag))
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
print('py', stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
@ -455,20 +446,39 @@ class TestKaldiFE(unittest.TestCase):
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.magspec(t_stft)[0]
print('pd', t_spec.numpy())
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1)
print(np.sum(t_spec.numpy()))
print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1))
def test_magsepc_winprocess(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
fs, _= framesig(wav, self.winlen*sr, self.winstep*sr,
dither=0.0, preemph=0.97, remove_dc_offset=True, wintype='povey', stride_trick=True)
spec = magspec(fs, self.nfft) # nearly the same until this part
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(
self.nfft, sr, self.winlen, self.winstep,
window_type='povey', dither=0.0, preemph_coeff=0.97, remove_dc_offset=True, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(spec.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(spec), 1)
self.assertTrue(np.allclose(t_spec.numpy(), spec, atol=1e-1))
def test_powspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
@ -476,7 +486,6 @@ class TestKaldiFE(unittest.TestCase):
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
stft_win = np.square(stft_win)
print('py', stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
@ -485,13 +494,10 @@ class TestKaldiFE(unittest.TestCase):
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.powspec(t_stft)[0]
print('pd', t_spec.numpy())
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
print(np.sum(t_spec.numpy()))
print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))

Loading…
Cancel
Save