using frames and matmul as stft

pull/670/head
Hui Zhang 4 years ago
parent 08f5d153e7
commit 58f540c8a2

@ -32,7 +32,19 @@ def hamm_window(frame_len:int) -> np.ndarray:
win[i] = 0.54 - 0.46 * np.cos(a * i) win[i] = 0.54 - 0.46 * np.cos(a * i)
return win return win
def get_window(wintype:Optional[None, str], winlen:int) -> np.ndarray: def get_window(wintype:Optional[str], winlen:int) -> np.ndarray:
"""get window function
Args:
wintype (Optional[str]): window type.
winlen (int): window length in samples.
Raises:
ValueError: unsupported window type.
Returns:
np.ndarray: window coeffs.
"""
# calculate window # calculate window
if not wintype or wintype == 'rectangular': if not wintype or wintype == 'rectangular':
window = np.ones(winlen) window = np.ones(winlen)

@ -68,17 +68,17 @@ def frames(x: Tensor,
return frames, num_frames return frames, num_frames
def do_dither(signal, dither_value=1.0): def dither(signal, dither_value=1.0):
signal += paddle.normal(shape=signal.shape) * dither_value signal += paddle.normal(shape=signal.shape) * dither_value
return signal return signal
def do_remove_dc_offset(signal): def remove_dc_offset(signal):
signal -= paddle.mean(signal) signal -= paddle.mean(signal)
return signal return signal
def do_preemphasis(signal, coeff=0.97): def preemphasis(signal, coeff=0.97):
"""perform preemphasis on the input signal. """perform preemphasis on the input signal.
:param signal: The signal to filter. :param signal: The signal to filter.
@ -118,18 +118,24 @@ class STFT(nn.Layer):
sr: int, sr: int,
win_length: float, win_length: float,
stride_length: float, stride_length: float,
window_type: str = None, dither:float=1.0,
preemph_coeff:float=0.97,
remove_dc_offset:bool=True,
window_type: str = 'povey',
clip: bool = False): clip: bool = False):
super().__init__() super().__init__()
self.sr = sr self.sr = sr
self.win_length = int(win_length * sr) self.win_length = win_length
self.stride_length = int(stride_length * sr) self.stride_length = stride_length
self.dither = dither
self.preemph_coeff = preemph_coeff
self.remove_dc_offset = remove_dc_offset
self.clip = clip self.clip = clip
self.n_fft = n_fft self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2 self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix(self.n_fft, self.win_length, self.n_bin) w_real, w_imag, kernel_size = dft_matrix(self.n_fft, int(self.win_length * sr), self.n_bin)
# calculate window # calculate window
window = get_window(window_type, kernel_size) window = get_window(window_type, kernel_size)
@ -137,9 +143,8 @@ class STFT(nn.Layer):
# (2 * n_bins, kernel_size) # (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0) w = np.concatenate([w_real, w_imag], axis=0)
w = w * window w = w * window
# (kernel_size, 2 * n_bins)
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size) w = np.transpose(w)
w = np.expand_dims(w, 1)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype()) weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight) self.register_buffer("weight", weight)
@ -159,22 +164,12 @@ class STFT(nn.Layer):
num_frames: Tensor num_frames: Tensor
Shape (B,) number of frames of each spectrogram Shape (B,) number of frames of each spectrogram
""" """
num_frames = (num_samples - self.win_length) // self.stride_length batch_size = paddle.shape(num_samples)
padding = (0, 0) F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if not self.clip: C = paddle.matmul(F, self.weight) # [B, T, K] [K, 2 * n_bins]
num_frames += 1
need_samples = num_frames * self.stride_length + self.win_length
padding = (0, need_samples - num_samples - 1)
batch_size, _ = paddle.shape(x)
x = x.unsqueeze(-1)
C = F.conv1d(x, self.weight,
stride=(self.stride_length, ),
padding=padding,
data_format="NLC")
C = paddle.reshape(C, [batch_size, -1, 2, self.n_bin]) C = paddle.reshape(C, [batch_size, -1, 2, self.n_bin])
C = C.transpose([0, 1, 3, 2]) C = C.transpose([0, 1, 3, 2])
return C, num_frames return C, nframe
def powspec(C:Tensor) -> Tensor: def powspec(C:Tensor) -> Tensor:

@ -369,6 +369,8 @@ class TestKaldiFE(unittest.TestCase):
self.wintype='hamm' self.wintype='hamm'
self.nfilt=40 self.nfilt=40
paddle.set_device('cpu')
def test_read(self): def test_read(self):
import scipy.io.wavfile as wav import scipy.io.wavfile as wav
@ -484,7 +486,7 @@ class TestKaldiFE(unittest.TestCase):
self.assertEqual(t_nframe.item(), stft_win.shape[0]) self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 2e4) self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
print(np.sum(t_spec.numpy())) print(np.sum(t_spec.numpy()))
print(np.sum(stft_win)) print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2)) self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))

Loading…
Cancel
Save