147 lines
4.2 KiB
147 lines
4.2 KiB
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
def frame(x: Tensor,
          num_samples: Tensor,
          win_length: int,
          hop_length: int,
          clip: bool = True) -> Tuple[Tensor, Tensor]:
    """Extract frames from audio.

    Frames are extracted with a strided ``conv1d`` whose kernel is an
    identity matrix, so output channel ``i`` holds the ``i``-th sample of
    every frame.

    Parameters
    ----------
    x : Tensor
        Shape (N, T), batched waveform.
    num_samples : Tensor
        Shape (N, ), number of samples of each waveform.
    win_length : int
        Window length.
    hop_length : int
        Number of samples shifted between adjacent frames. Must not exceed
        ``win_length``.
    clip : bool, optional
        Whether to clip audio that does not fit into the last frame, by
        default True

    Returns
    -------
    frames : Tensor
        Shape (N, win_length, T'), the layout produced by ``conv1d``
        (frame index on the last axis).
    num_frames : Tensor
        Shape (N, ) number of valid frames of each waveform.
    """
    assert hop_length <= win_length
    # NOTE(review): counts are not clamped for waveforms shorter than the
    # window — confirm callers never pass those.
    if clip:
        # Frames lying entirely inside the waveform:
        # (n - win_length) // hop_length + 1, same as the conv output length.
        num_frames = (num_samples - win_length) // hop_length + 1
        padding = (0, 0)
    else:
        # NOTE: pad hop_length - 1 to the right to ensure that there is at most
        # one frame dangling to the right edge; that frame is counted too,
        # which matches the conv output length on the padded signal.
        num_frames = ((num_samples - win_length + hop_length - 1) // hop_length
                      + 1)
        padding = (0, hop_length - 1)

    # (win_length, 1, win_length) == (C_out, C_in, kernel_size):
    # output channel i copies the input sample at offset i of each frame.
    # Built in x's dtype so conv1d does not see mixed dtypes.
    weight = paddle.eye(win_length, dtype=x.dtype).unsqueeze(1)

    # (N, T) -> (N, 1, T) so the waveform is a single NCL input channel.
    frames = F.conv1d(x.unsqueeze(1),
                      weight,
                      padding=padding,
                      stride=(hop_length, ))
    return frames, num_frames
class STFT(nn.Layer):
    """A module for computing stft transformation in a differentiable way.

    The (optionally windowed) DFT basis is precomputed with numpy and stored
    as a ``conv1d`` kernel, so the transform is one strided convolution over
    the waveform.

    Parameters
    ----------
    n_fft : int
        Number of samples in a frame (DFT size).
    hop_length : int
        Number of samples shifted between adjacent frames.
    win_length : int
        Length of the window. When shorter than ``n_fft``, each frame of
        ``win_length`` samples is implicitly zero-padded to ``n_fft`` before
        the DFT; when longer, the window is truncated to ``n_fft`` samples.
    window_type : str, optional
        ``None`` (rectangular window), ``"hann"`` or ``"hamming"``, by
        default None.
    clip : bool, optional
        Whether to clip audio that does not fit into the last frame, by
        default True.

    Raises
    ------
    ValueError
        If ``window_type`` is not one of the supported values.
    """

    def __init__(self,
                 n_fft: int,
                 hop_length: int,
                 win_length: int,
                 window_type: str = None,
                 clip: bool = True):
        # nn.Layer must be initialized before buffers can be registered.
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_bin = 1 + n_fft // 2
        self.n_fft = n_fft
        self.clip = clip
        # Number of input samples each frame (and the conv kernel) spans.
        self.kernel_size = min(n_fft, win_length)

        # calculate window
        if window_type is None:
            window = np.ones(win_length)
        elif window_type == "hann":
            window = np.hanning(win_length)
        elif window_type == "hamming":
            window = np.hamming(win_length)
        else:
            raise ValueError("Not supported yet!")
        # Match the kernel width: truncates when win_length > n_fft and is a
        # no-op otherwise (the zero-padding to n_fft is implicit, see below).
        window = window[:self.kernel_size]

        # (n_bins, kernel_size) complex: leading columns of the n_fft-point
        # DFT matrix. Dropping the trailing columns is equivalent to
        # zero-padding every frame to n_fft samples.
        basis = np.fft.fft(np.eye(n_fft))[:self.n_bin, :self.kernel_size]
        # (2 * n_bins, kernel_size): real rows first, then imaginary rows.
        w = np.concatenate([basis.real, basis.imag], axis=0)
        w = w * window
        # (2 * n_bins, 1, kernel_size)  # (C_out, C_in, kernel_size)
        w = np.expand_dims(w, 1)
        weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
        self.register_buffer("weight", weight)

    def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the stft transform.

        Parameters
        ----------
        x : Tensor
            Shape (N, T), the batched waveform.
        num_samples : Tensor
            Shape (N, ), number of samples of each waveform.

        Returns
        -------
        D : Tensor
            Shape (N, T', n_bins, 2) spectrogram; ``D[..., 0]`` is the real
            part and ``D[..., 1]`` the imaginary part.
        num_frames : Tensor
            Shape (N, ), number of valid frames of each spectrogram.
        """
        # NOTE(review): counts are not clamped for waveforms shorter than the
        # kernel — confirm callers never pass those.
        if self.clip:
            # Frames lying entirely inside the waveform (== conv output len).
            num_frames = ((num_samples - self.kernel_size) // self.hop_length
                          + 1)
            padding = (0, 0)
        else:
            # Right-pad hop_length - 1 so at most one frame dangles past the
            # end; that frame is counted too.
            num_frames = ((num_samples - self.kernel_size + self.hop_length - 1)
                          // self.hop_length + 1)
            padding = (0, self.hop_length - 1)

        # (N, T) -> (N, 1, T): one input channel in NCL layout.
        x = x.unsqueeze(1)
        # (N, 2 * n_bins, T')
        D = F.conv1d(x,
                     self.weight,
                     stride=(self.hop_length, ),
                     padding=padding)
        # Split the channel axis into (real/imag, bin). 0 copies N from D.
        D = paddle.reshape(D, [0, 2, self.n_bin, -1])  # (N, 2, n_bins, T')
        D = paddle.transpose(D, [0, 3, 2, 1])  # (N, T', n_bins, 2)
        return D, num_frames