from typing import Tuple import numpy as np import paddle from paddle import Tensor from paddle import nn from paddle.nn import functional as F def frame(x: Tensor, num_samples: Tensor, win_length: int, hop_length: int, clip: bool = True) -> Tuple[Tensor, Tensor]: """Extract frames from audio. Parameters ---------- x : Tensor Shape (N, T), batched waveform. num_samples : Tensor Shape (N, ), number of samples of each waveform. win_length : int Window length. hop_length : int Number of samples shifted between ajancent frames. clip : bool, optional Whether to clip audio that does not fit into the last frame, by default True Returns ------- frames : Tensor Shape (N, T', win_length). num_frames : Tensor Shape (N, ) number of valid frames """ assert hop_length <= win_length num_frames = (num_samples - win_length) // hop_length padding = (0, 0) if not clip: num_frames += 1 # NOTE: pad hop_length - 1 to the right to ensure that there is at most # one frame dangling to the righe edge padding = (0, hop_length - 1) weight = paddle.eye(win_length).unsqueeze(1) frames = F.conv1d(x.unsqueeze(1), weight, padding=padding, stride=(hop_length, )) return frames, num_frames class STFT(nn.Layer): """A module for computing stft transformation in a differentiable way. Parameters ------------ n_fft : int Number of samples in a frame. hop_length : int Number of samples shifted between adjacent frames. win_length : int Length of the window. clip: bool Whether to clip audio is necesaary. """ def __init__(self, n_fft: int, hop_length: int, win_length: int, window_type: str = None, clip: bool = True): super().__init__() self.hop_length = hop_length self.n_bin = 1 + n_fft // 2 self.n_fft = n_fft self.clip = clip # calculate window if window_type is None: window = np.ones(win_length) elif window_type == "hann": window = np.hanning(win_length) elif window_type == "hamming": window = np.hamming(win_length) else: raise ValueError("Not supported yet!") if win_length < n_fft: window = F.pad(window, (0, n_fft - win_length)) elif win_length > n_fft: window = window[:n_fft] # (n_bins, n_fft) complex kernel_size = min(n_fft, win_length) weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size] w_real = weight.real w_imag = weight.imag # (2 * n_bins, kernel_size) w = np.concatenate([w_real, w_imag], axis=0) w = w * window # (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size) w = np.expand_dims(w, 1) weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype()) self.register_buffer("weight", weight) def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]: """Compute the stft transform. Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. num_samples : Tensor Number of samples of each waveform. Returns ------------ D : Tensor Shape(N, T', n_bins, 2) Spectrogram. num_frames: Tensor Shape (N,) number of samples of each spectrogram """ num_frames = (num_samples - self.win_length) // self.hop_length padding = (0, 0) if not self.clip: num_frames += 1 padding = (0, self.hop_length - 1) batch_size, _, _ = paddle.shape(x) x = x.unsqueeze(-1) D = F.conv1d(self.weight, x, stride=(self.hop_length, ), padding=padding, data_format="NLC") D = paddle.reshape(D, [batch_size, -1, self.n_bin, 2]) return D, num_frames