|
|
|
@ -71,15 +71,17 @@ class Spectrogram(nn.Layer):
|
|
|
|
|
if win_length is None:
|
|
|
|
|
win_length = n_fft
|
|
|
|
|
|
|
|
|
|
fft_window = get_window(window, win_length, fftbins=True, dtype=dtype)
|
|
|
|
|
self.fft_window = get_window(
|
|
|
|
|
window, win_length, fftbins=True, dtype=dtype)
|
|
|
|
|
self._stft = partial(
|
|
|
|
|
paddle.signal.stft,
|
|
|
|
|
n_fft=n_fft,
|
|
|
|
|
hop_length=hop_length,
|
|
|
|
|
win_length=win_length,
|
|
|
|
|
window=fft_window,
|
|
|
|
|
window=self.fft_window,
|
|
|
|
|
center=center,
|
|
|
|
|
pad_mode=pad_mode)
|
|
|
|
|
self.register_buffer('fft_window', self.fft_window)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
stft = self._stft(x)
|
|
|
|
|