From 831cadacc74b99a425d8b0d151863fca21f188a4 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Thu, 17 Mar 2022 21:17:41 +0800 Subject: [PATCH] Add paddleaudio doc. --- paddleaudio/paddleaudio/compliance/librosa.py | 6 +- paddleaudio/paddleaudio/features/layers.py | 8 +- .../paddleaudio/functional/functional.py | 139 ++++++------- paddleaudio/paddleaudio/functional/window.py | 186 +++++------------- paddleaudio/paddleaudio/metric/dtw.py | 4 +- paddlespeech/cli/executor.py | 3 +- 6 files changed, 127 insertions(+), 219 deletions(-) diff --git a/paddleaudio/paddleaudio/compliance/librosa.py b/paddleaudio/paddleaudio/compliance/librosa.py index 1342b251..740584ca 100644 --- a/paddleaudio/paddleaudio/compliance/librosa.py +++ b/paddleaudio/paddleaudio/compliance/librosa.py @@ -403,11 +403,11 @@ def power_to_db(spect: np.ndarray, ref: float=1.0, amin: float=1e-10, top_db: Optional[float]=80.0) -> np.ndarray: - """Convert a power spectrogram (amplitude squared) to decibel (dB) units. This computes the scaling `10 * log10(spect / ref)` in a numerically stable way. + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way. Args: spect (np.ndarray): STFT power spectrogram of an input waveform. - ref (float, optional): Scaling factor of spectrogram. Defaults to 1.0. + ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. amin (float, optional): Minimum threshold. Defaults to 1e-10. top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to 80.0. @@ -513,7 +513,7 @@ def melspectrogram(x: np.ndarray, pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. to_db (bool, optional): Enable db scale. Defaults to True. - ref (float, optional): Scaling factor of spectrogram. Defaults to 1.0. + ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. amin (float, optional): Minimum threshold. Defaults to 1e-10. top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None. diff --git a/paddleaudio/paddleaudio/features/layers.py b/paddleaudio/paddleaudio/features/layers.py index ad990b78..09037255 100644 --- a/paddleaudio/paddleaudio/features/layers.py +++ b/paddleaudio/paddleaudio/features/layers.py @@ -40,7 +40,7 @@ class Spectrogram(nn.Layer): n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. - window (str, optional): The window function applied to the single before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. @@ -97,7 +97,7 @@ class MelSpectrogram(nn.Layer): n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. - window (str, optional): The window function applied to the single before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. @@ -174,7 +174,7 @@ class LogMelSpectrogram(nn.Layer): n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. - window (str, optional): The window function applied to the single before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. @@ -255,7 +255,7 @@ class MFCC(nn.Layer): n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. - window (str, optional): The window function applied to the single before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. diff --git a/paddleaudio/paddleaudio/functional/functional.py b/paddleaudio/paddleaudio/functional/functional.py index c5ab3045..19c63a9a 100644 --- a/paddleaudio/paddleaudio/functional/functional.py +++ b/paddleaudio/paddleaudio/functional/functional.py @@ -17,6 +17,7 @@ from typing import Optional from typing import Union import paddle +from paddle import Tensor __all__ = [ 'hz_to_mel', @@ -29,19 +30,20 @@ __all__ = [ ] -def hz_to_mel(freq: Union[paddle.Tensor, float], - htk: bool=False) -> Union[paddle.Tensor, float]: +def hz_to_mel(freq: Union[Tensor, float], + htk: bool=False) -> Union[Tensor, float]: """Convert Hz to Mels. - Parameters: - freq: the input tensor of arbitrary shape, or a single floating point number. - htk: use HTK formula to do the conversion. - The default value is False. + + Args: + freq (Union[Tensor, float]): The input tensor with arbitrary shape. + htk (bool, optional): Use htk scaling. Defaults to False. + Returns: - The frequencies represented in Mel-scale. + Union[Tensor, float]: Frequency in mels. """ if htk: - if isinstance(freq, paddle.Tensor): + if isinstance(freq, Tensor): return 2595.0 * paddle.log10(1.0 + freq / 700.0) else: return 2595.0 * math.log10(1.0 + freq / 700.0) @@ -58,7 +60,7 @@ def hz_to_mel(freq: Union[paddle.Tensor, float], min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(freq, paddle.Tensor): + if isinstance(freq, Tensor): target = min_log_mel + paddle.log( freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10 mask = (freq > min_log_hz).astype(freq.dtype) @@ -71,14 +73,16 @@ def hz_to_mel(freq: Union[paddle.Tensor, float], return mels -def mel_to_hz(mel: Union[float, paddle.Tensor], - htk: bool=False) -> Union[float, paddle.Tensor]: +def mel_to_hz(mel: Union[float, Tensor], + htk: bool=False) -> Union[float, Tensor]: """Convert mel bin numbers to frequencies. - Parameters: - mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number. - htk: use HTK formula to do the conversion. + + Args: + mel (Union[float, Tensor]): The mel frequency represented as a tensor with arbitrary shape. + htk (bool, optional): Use htk scaling. Defaults to False. + Returns: - The frequencies represented in hz. + Union[float, Tensor]: Frequencies in Hz. """ if htk: return 700.0 * (10.0**(mel / 2595.0) - 1.0) @@ -90,7 +94,7 @@ def mel_to_hz(mel: Union[float, paddle.Tensor], min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(mel, paddle.Tensor): + if isinstance(mel, Tensor): target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel)) mask = (mel > min_log_mel).astype(mel.dtype) freqs = target * mask + freqs * ( @@ -106,16 +110,18 @@ def mel_frequencies(n_mels: int=64, f_min: float=0.0, f_max: float=11025.0, htk: bool=False, - dtype: str=paddle.float32): + dtype: str='float32') -> Tensor: """Compute mel frequencies. - Parameters: - n_mels(int): number of Mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zero. - htk(bool): whether to use htk formula. - dtype(str): the datatype of the return frequencies. + + Args: + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0. + htk (bool, optional): Use htk scaling. Defaults to False. + dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'. + Returns: - The frequencies represented in Mel-scale + Tensor: Tensor of n_mels frequencies in Hz with shape `(n_mels,)`. """ # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = hz_to_mel(f_min, htk=htk) @@ -125,14 +131,16 @@ def mel_frequencies(n_mels: int=64, return freqs -def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32): +def fft_frequencies(sr: int, n_fft: int, dtype: str='float32') -> Tensor: """Compute fourier frequencies. - Parameters: - sr(int): the audio sample rate. - n_fft(float): the number of fft bins. - dtype(str): the datatype of the return frequencies. + + Args: + sr (int): Sample rate. + n_fft (int): Number of fft bins. + dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'. + Returns: - The frequencies represented in hz. + Tensor: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`. """ return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype) @@ -144,23 +152,21 @@ def compute_fbank_matrix(sr: int, f_max: Optional[float]=None, htk: bool=False, norm: Union[str, float]='slaney', - dtype: str=paddle.float32): + dtype: str='float32') -> Tensor: """Compute fbank matrix. - Parameters: - sr(int): the audio sample rate. - n_fft(int): the number of fft bins. - n_mels(int): the number of Mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zero. - htk: whether to use htk formula. - return_complex(bool): whether to return complex matrix. If True, the matrix will - be complex type. Otherwise, the real and image part will be stored in the last - axis of returned tensor. - dtype(str): the datatype of the returned fbank matrix. + + Args: + sr (int): Sample rate. + n_fft (int): Number of fft bins. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use htk scaling. Defaults to False. + norm (Union[str, float], optional): Type of normalization. Defaults to 'slaney'. + dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. + Returns: - The fbank matrix of shape (n_mels, int(1+n_fft//2)). - Shape: - output: (n_mels, int(1+n_fft//2)) + Tensor: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`. """ if f_max is None: @@ -199,27 +205,20 @@ def compute_fbank_matrix(sr: int, return weights -def power_to_db(magnitude: paddle.Tensor, +def power_to_db(spect: Tensor, ref_value: float=1.0, amin: float=1e-10, - top_db: Optional[float]=None) -> paddle.Tensor: - """Convert a power spectrogram (amplitude squared) to decibel (dB) units. - The function computes the scaling ``10 * log10(x / ref)`` in a numerically - stable way. - Parameters: - magnitude(Tensor): the input magnitude tensor of any shape. - ref_value(float): the reference value. If smaller than 1.0, the db level - of the signal will be pulled up accordingly. Otherwise, the db level - is pushed down. - amin(float): the minimum value of input magnitude, below which the input - magnitude is clipped(to amin). - top_db(float): the maximum db value of resulting spectrum, above which the - spectrum is clipped(to top_db). + top_db: Optional[float]=None) -> Tensor: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way. + + Args: + spect (Tensor): STFT power spectrogram. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None. + Returns: - The spectrogram in log-scale. - shape: - input: any shape - output: same as input + Tensor: Power spectrogram in db scale. """ if amin <= 0: raise Exception("amin must be strictly positive") @@ -227,8 +226,8 @@ def power_to_db(magnitude: paddle.Tensor, if ref_value <= 0: raise Exception("ref_value must be strictly positive") - ones = paddle.ones_like(magnitude) - log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude)) + ones = paddle.ones_like(spect) + log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, spect)) log_spec -= 10.0 * math.log10(max(ref_value, amin)) if top_db is not None: @@ -242,15 +241,17 @@ def power_to_db(magnitude: paddle.Tensor, def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]='ortho', - dtype: Optional[str]=paddle.float32) -> paddle.Tensor: + dtype: str='float32') -> Tensor: """Create a discrete cosine transform(DCT) matrix. - Parameters: + Args: n_mfcc (int): Number of mel frequency cepstral coefficients. n_mels (int): Number of mel filterbanks. - norm (str, optional): Normalizaiton type. Defaults to 'ortho'. + norm (Optional[str], optional): Normalizaiton type. Defaults to 'ortho'. + dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. + Returns: - Tensor: The DCT matrix with shape (n_mels, n_mfcc). + Tensor: The DCT matrix with shape `(n_mels, n_mfcc)`. """ n = paddle.arange(n_mels, dtype=dtype) k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1) diff --git a/paddleaudio/paddleaudio/functional/window.py b/paddleaudio/paddleaudio/functional/window.py index f321b38e..c99d5046 100644 --- a/paddleaudio/paddleaudio/functional/window.py +++ b/paddleaudio/paddleaudio/functional/window.py @@ -20,24 +20,11 @@ from paddle import Tensor __all__ = [ 'get_window', - - # windows - 'taylor', - 'hamming', - 'hann', - 'tukey', - 'kaiser', - 'gaussian', - 'exponential', - 'triang', - 'bohman', - 'blackman', - 'cosine', ] -def _cat(a: List[Tensor], data_type: str) -> Tensor: - l = [paddle.to_tensor(_a, data_type) for _a in a] +def _cat(x: List[Tensor], data_type: str) -> Tensor: + l = [paddle.to_tensor(_, data_type) for _ in x] return paddle.concat(l) @@ -48,7 +35,7 @@ def _acosh(x: Union[Tensor, float]) -> Tensor: def _extend(M: int, sym: bool) -> bool: - """Extend window by 1 sample if needed for DFT-even symmetry""" + """Extend window by 1 sample if needed for DFT-even symmetry. """ if not sym: return M + 1, True else: @@ -56,7 +43,7 @@ def _extend(M: int, sym: bool) -> bool: def _len_guards(M: int) -> bool: - """Handle small or incorrect window lengths""" + """Handle small or incorrect window lengths. """ if int(M) != M or M < 0: raise ValueError('Window length M must be a non-negative integer') @@ -64,15 +51,15 @@ def _len_guards(M: int) -> bool: def _truncate(w: Tensor, needed: bool) -> Tensor: - """Truncate window by 1 sample if needed for DFT-even symmetry""" + """Truncate window by 1 sample if needed for DFT-even symmetry. """ if needed: return w[:-1] else: return w -def general_gaussian(M: int, p, sig, sym: bool=True, - dtype: str='float64') -> Tensor: +def _general_gaussian(M: int, p, sig, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a window with a generalized Gaussian shape. This function is consistent with scipy.signal.windows.general_gaussian(). """ @@ -86,8 +73,8 @@ def general_gaussian(M: int, p, sig, sym: bool=True, return _truncate(w, needs_trunc) -def general_cosine(M: int, a: float, sym: bool=True, - dtype: str='float64') -> Tensor: +def _general_cosine(M: int, a: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a generic weighted sum of cosine terms window. This function is consistent with scipy.signal.windows.general_cosine(). """ @@ -101,31 +88,23 @@ def general_cosine(M: int, a: float, sym: bool=True, return _truncate(w, needs_trunc) -def general_hamming(M: int, alpha: float, sym: bool=True, - dtype: str='float64') -> Tensor: +def _general_hamming(M: int, alpha: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a generalized Hamming window. This function is consistent with scipy.signal.windows.general_hamming() """ - return general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype) + return _general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype) -def taylor(M: int, - nbar=4, - sll=30, - norm=True, - sym: bool=True, - dtype: str='float64') -> Tensor: +def _taylor(M: int, + nbar=4, + sll=30, + norm=True, + sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a Taylor window. The Taylor window taper function approximates the Dolph-Chebyshev window's constant sidelobe level for a parameterized number of near-in sidelobes. - Parameters: - M(int): window size - nbar, sil, norm: the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -171,46 +150,25 @@ def taylor(M: int, return _truncate(w, needs_trunc) -def hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Hamming window. The Hamming window is a taper formed by using a raised cosine with non-zero endpoints, optimized to minimize the nearest side lobe. - Parameters: - M(int): window size - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ - return general_hamming(M, 0.54, sym, dtype=dtype) + return _general_hamming(M, 0.54, sym, dtype=dtype) -def hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Hann window. The Hann window is a taper formed by using a raised cosine or sine-squared with ends that touch zero. - Parameters: - M(int): window size - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ - return general_hamming(M, 0.5, sym, dtype=dtype) + return _general_hamming(M, 0.5, sym, dtype=dtype) -def tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor: +def _tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Tukey window. The Tukey window is also known as a tapered cosine window. - Parameters: - M(int): window size - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -237,32 +195,18 @@ def tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -def kaiser(M: int, beta: float, sym: bool=True, dtype: str='float64') -> Tensor: +def _kaiser(M: int, beta: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a Kaiser window. The Kaiser window is a taper formed by using a Bessel function. - Parameters: - M(int): window size. - beta(float): the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - Returns: - Tensor: the window tensor """ raise NotImplementedError() -def gaussian(M: int, std: float, sym: bool=True, - dtype: str='float64') -> Tensor: +def _gaussian(M: int, std: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a Gaussian window. The Gaussian widows has a Gaussian shape defined by the standard deviation(std). - Parameters: - M(int): window size. - std(float): the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -275,21 +219,12 @@ def gaussian(M: int, std: float, sym: bool=True, return _truncate(w, needs_trunc) -def exponential(M: int, - center=None, - tau=1., - sym: bool=True, - dtype: str='float64') -> Tensor: - """Compute an exponential (or Poisson) window. - Parameters: - M(int): window size. - tau(float): the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor - """ +def _exponential(M: int, + center=None, + tau=1., + sym: bool=True, + dtype: str='float64') -> Tensor: + """Compute an exponential (or Poisson) window. """ if sym and center is not None: raise ValueError("If sym==True, center must be None.") if _len_guards(M): @@ -305,15 +240,8 @@ def exponential(M: int, return _truncate(w, needs_trunc) -def triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a triangular window. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -330,16 +258,9 @@ def triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -def bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Bohman window. The Bohman window is the autocorrelation of a cosine window. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -353,32 +274,18 @@ def bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -def blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Blackman window. The Blackman window is a taper formed by using the first three terms of a summation of cosines. It was designed to have close to the minimal leakage possible. It is close to optimal, only slightly worse than a Kaiser window. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ - return general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) + return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) -def cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a window with a simple cosine shape. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -388,19 +295,20 @@ def cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -## factory function def get_window(window: Union[str, Tuple[str, float]], win_length: int, fftbins: bool=True, dtype: str='float64') -> Tensor: """Return a window of a given length and type. - Parameters: - window(str|(str,float)): the type of window to create. - win_length(int): the number of samples in the window. - fftbins(bool): If True, create a "periodic" window. Otherwise, - create a "symmetric" window, for use in filter design. + + Args: + window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. + win_length (int): Number of samples. + fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. + dtype (str, optional): The data type of the return window. Defaults to 'float64'. + Returns: - The window represented as a tensor. + Tensor: The window represented as a tensor. """ sym = not fftbins @@ -420,7 +328,7 @@ def get_window(window: Union[str, Tuple[str, float]], str(type(window))) try: - winfunc = eval(winstr) + winfunc = eval('_' + winstr) except KeyError as e: raise ValueError("Unknown window type.") from e diff --git a/paddleaudio/paddleaudio/metric/dtw.py b/paddleaudio/paddleaudio/metric/dtw.py index d27f56e2..c4dc7a28 100644 --- a/paddleaudio/paddleaudio/metric/dtw.py +++ b/paddleaudio/paddleaudio/metric/dtw.py @@ -20,9 +20,7 @@ __all__ = [ def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float: - """dtw distance - - Dynamic Time Warping. + """Dynamic Time Warping. This function keeps a compact matrix, not the full warping paths matrix. Uses dynamic programming to compute: diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index d77d27b0..064939a8 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -178,7 +178,8 @@ class BaseExecutor(ABC): Returns: bool: return `True` for job input, `False` otherwise. """ - return input_ and os.path.isfile(input_) and input_.endswith('.job') + return input_ and os.path.isfile(input_) and (input_.endswith('.job') or + input_.endswith('.txt')) def _get_job_contents( self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: