|
|
@ -165,9 +165,13 @@ class STFT(torch.nn.Module):
|
|
|
|
# self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
|
|
|
|
# self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
|
|
|
|
|
|
|
|
|
|
|
|
# Applying window functions to the Fourier kernels
|
|
|
|
# Applying window functions to the Fourier kernels
|
|
|
|
|
|
|
|
if window:
|
|
|
|
window_mask = torch.tensor(window_mask)
|
|
|
|
window_mask = torch.tensor(window_mask)
|
|
|
|
wsin = kernel_sin * window_mask
|
|
|
|
wsin = kernel_sin * window_mask
|
|
|
|
wcos = kernel_cos * window_mask
|
|
|
|
wcos = kernel_cos * window_mask
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
wsin = kernel_sin
|
|
|
|
|
|
|
|
wcos = kernel_cos
|
|
|
|
|
|
|
|
|
|
|
|
if self.trainable==False:
|
|
|
|
if self.trainable==False:
|
|
|
|
self.register_buffer('wsin', wsin)
|
|
|
|
self.register_buffer('wsin', wsin)
|
|
|
@ -179,7 +183,6 @@ class STFT(torch.nn.Module):
|
|
|
|
self.register_parameter('wsin', wsin)
|
|
|
|
self.register_parameter('wsin', wsin)
|
|
|
|
self.register_parameter('wcos', wcos)
|
|
|
|
self.register_parameter('wcos', wcos)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prepare the shape of window mask so that it can be used later in inverse
|
|
|
|
# Prepare the shape of window mask so that it can be used later in inverse
|
|
|
|
self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
|
|
|
|
self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
|
|
|
|
|
|
|
|
|
|
|
|