from typing import Optional, Tuple

import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F


def frame(x: Tensor,
          num_samples: Tensor,
          win_length: int,
          hop_length: int,
          clip: bool = True) -> Tuple[Tensor, Tensor]:
    """Extract sliding-window frames from batched audio.

    Parameters
    ----------
    x : Tensor
        Shape (N, T), batched waveform.
    num_samples : Tensor
        Shape (N, ), number of valid samples of each waveform.
    win_length : int
        Window length.
    hop_length : int
        Number of samples shifted between adjacent frames.
    clip : bool, optional
        Whether to clip audio that does not fit into the last frame, by
        default True.

    Returns
    -------
    frames : Tensor
        Shape (N, win_length, T'). NOTE(review): conv1d with an identity
        kernel puts the frame index on the LAST axis, i.e. the result is
        (N, win_length, T'), not (N, T', win_length) as the original
        docstring claimed — confirm against downstream callers.
    num_frames : Tensor
        Shape (N, ), number of valid frames of each waveform.
        NOTE(review): for clip=True this is (n - win) // hop without the
        customary "+ 1"; kept as-is since callers may rely on it — verify.
    """
    assert hop_length <= win_length
    num_frames = (num_samples - win_length) // hop_length
    padding = (0, 0)
    if not clip:
        num_frames += 1
        # Pad hop_length - 1 samples on the right so at most one frame
        # dangles past the right edge.
        padding = (0, hop_length - 1)

    # Identity kernel, shape (win_length, 1, win_length): output channel k
    # selects sample k of each window, so the strided conv enumerates frames.
    weight = paddle.eye(win_length).unsqueeze(1)

    frames = F.conv1d(x.unsqueeze(1),
                      weight,
                      padding=padding,
                      stride=(hop_length, ))
    return frames, num_frames


class STFT(nn.Layer):
    """Compute the short-time Fourier transform in a differentiable way,
    implemented as a strided 1-D convolution with a fixed DFT kernel.

    Parameters
    ----------
    n_fft : int
        Number of FFT points per frame.
    hop_length : int
        Number of samples shifted between adjacent frames.
    win_length : int
        Length of the analysis window.
    window_type : Optional[str]
        One of None (rectangular), "hann", "hamming".
    clip : bool
        Whether to clip audio that does not fit into the last frame.
    """

    def __init__(self,
                 n_fft: int,
                 hop_length: int,
                 win_length: int,
                 window_type: Optional[str] = None,
                 clip: bool = True):
        super().__init__()

        self.hop_length = hop_length
        self.n_bin = 1 + n_fft // 2
        self.n_fft = n_fft
        # BUG FIX: forward() reads self.win_length, but the original
        # __init__ never stored it (AttributeError on first call).
        self.win_length = win_length
        self.clip = clip

        # Build the analysis window as a numpy array; it is baked into the
        # convolution kernel below, not applied at run time.
        if window_type is None:
            window = np.ones(win_length)
        elif window_type == "hann":
            window = np.hanning(win_length)
        elif window_type == "hamming":
            window = np.hamming(win_length)
        else:
            raise ValueError("Not supported yet!")

        # The conv kernel covers min(n_fft, win_length) samples.
        kernel_size = min(n_fft, win_length)
        if win_length > n_fft:
            window = window[:n_fft]
        # BUG FIX: the original called F.pad (a paddle op) on a numpy array
        # when win_length < n_fft — wrong API, and padding the window to
        # n_fft would also mismatch the (.., kernel_size) kernel. No pad is
        # needed: taking only the first win_length columns of the DFT matrix
        # is mathematically equivalent to zero-padding each windowed frame
        # to n_fft before the FFT.

        # (n_bin, kernel_size) complex DFT matrix restricted to the
        # non-redundant bins.
        weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size]
        w_real = weight.real
        w_imag = weight.imag

        # (2 * n_bin, kernel_size): real parts stacked above imaginary parts.
        w = np.concatenate([w_real, w_imag], axis=0)
        w = w * window

        # (2 * n_bin, 1, kernel_size) == (C_out, C_in, kernel_size).
        w = np.expand_dims(w, 1)
        weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
        self.register_buffer("weight", weight)

    def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the STFT of a batch of waveforms.

        Parameters
        ----------
        x : Tensor
            Shape (N, T), the input waveforms.
        num_samples : Tensor
            Shape (N, ), number of valid samples of each waveform.

        Returns
        -------
        D : Tensor
            Shape (N, T', n_bin, 2), spectrogram with real and imaginary
            parts stacked on the last axis.
        num_frames : Tensor
            Shape (N, ), number of valid frames of each spectrogram.
        """
        num_frames = (num_samples - self.win_length) // self.hop_length
        padding = (0, 0)
        if not self.clip:
            num_frames += 1
            padding = (0, self.hop_length - 1)

        # BUG FIX: x is 2-D (B, T); the original unpacked three values from
        # paddle.shape(x).
        batch_size = paddle.shape(x)[0]
        # (B, T) -> (B, T, 1): NLC layout with a single input channel.
        x = x.unsqueeze(-1)
        # BUG FIX: the original swapped the input and weight arguments of
        # F.conv1d (it passed the kernel as the input).
        D = F.conv1d(x,
                     self.weight,
                     stride=(self.hop_length, ),
                     padding=padding,
                     data_format="NLC")
        # Output channels are ordered [real_0..real_{n_bin-1},
        # imag_0..imag_{n_bin-1}], so split them as (2, n_bin) and move the
        # real/imag axis last. BUG FIX: the original reshaped directly to
        # (B, -1, n_bin, 2), which interleaves unrelated bins.
        D = paddle.reshape(D, [batch_size, -1, 2, self.n_bin])
        D = paddle.transpose(D, [0, 1, 3, 2])
        return D, num_frames