diff --git a/paddleaudio/paddleaudio/backends/soundfile_backend.py b/paddleaudio/paddleaudio/backends/soundfile_backend.py index 2b920284..c1155654 100644 --- a/paddleaudio/paddleaudio/backends/soundfile_backend.py +++ b/paddleaudio/paddleaudio/backends/soundfile_backend.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import warnings from typing import Optional from typing import Tuple @@ -19,7 +20,6 @@ from typing import Union import numpy as np import resampy import soundfile as sf -from numpy import ndarray as array from scipy.io import wavfile from ..utils import ParameterError @@ -38,13 +38,21 @@ RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast'] EPS = 1e-8 -def resample(y: array, src_sr: int, target_sr: int, - mode: str='kaiser_fast') -> array: - """ Audio resampling - This function is the same as using resampy.resample(). - Notes: - The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast' - """ +def resample(y: np.ndarray, + src_sr: int, + target_sr: int, + mode: str='kaiser_fast') -> np.ndarray: + """Audio resampling. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + src_sr (int): Source sample rate. + target_sr (int): Target sample rate. + mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'. + + Returns: + np.ndarray: `y` resampled to `target_sr` + """ if mode == 'kaiser_best': warnings.warn( @@ -53,7 +61,7 @@ def resample(y: array, src_sr: int, target_sr: int, if not isinstance(y, np.ndarray): raise ParameterError( - 'Only support numpy array, but received y in {type(y)}') + 'Only support numpy np.ndarray, but received y in {type(y)}') if mode not in RESAMPLE_MODES: raise ParameterError(f'resample mode must in {RESAMPLE_MODES}') @@ -61,9 +69,17 @@ def resample(y: array, src_sr: int, target_sr: int, return resampy.resample(y, src_sr, target_sr, filter=mode) -def to_mono(y: array, merge_type: str='average') -> array: - """ convert sterior audio to mono +def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray: + """Convert sterior audio to mono. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'. + + Returns: + np.ndarray: `y` with mono channel. """ + if merge_type not in MERGE_TYPES: raise ParameterError( f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}' @@ -101,18 +117,34 @@ def to_mono(y: array, merge_type: str='average') -> array: return y_out -def _safe_cast(y: array, dtype: Union[type, str]) -> array: - """ data type casting in a safe way, i.e., prevent overflow or underflow - This function is used internally. +def _safe_cast(y: np.ndarray, dtype: Union[type, str]) -> np.ndarray: + """Data type casting in a safe way, i.e., prevent overflow or underflow. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + dtype (Union[type, str]): Data type of waveform. + + Returns: + np.ndarray: `y` after safe casting. """ - return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype) + if 'float' in str(y.dtype): + return np.clip(y, np.finfo(dtype).min, + np.finfo(dtype).max).astype(dtype) + else: + return np.clip(y, np.iinfo(dtype).min, + np.iinfo(dtype).max).astype(dtype) -def depth_convert(y: array, dtype: Union[type, str], - dithering: bool=True) -> array: - """Convert audio array to target dtype safely - This function convert audio waveform to a target dtype, with addition steps of +def depth_convert(y: np.ndarray, dtype: Union[type, str]) -> np.ndarray: + """Convert audio array to target dtype safely. This function convert audio waveform to a target dtype, with addition steps of preventing overflow/underflow and preserving audio range. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + dtype (Union[type, str]): Data type of waveform. + + Returns: + np.ndarray: `y` after safe casting. """ SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64'] @@ -157,14 +189,20 @@ def depth_convert(y: array, dtype: Union[type, str], return y -def sound_file_load(file: str, +def sound_file_load(file: os.PathLike, offset: Optional[float]=None, dtype: str='int16', - duration: Optional[int]=None) -> Tuple[array, int]: - """Load audio using soundfile library - This function load audio file using libsndfile. - Reference: - http://www.mega-nerd.com/libsndfile/#Features + duration: Optional[int]=None) -> Tuple[np.ndarray, int]: + """Load audio using soundfile library. This function load audio file using libsndfile. + + Args: + file (os.PathLike): File of waveform. + offset (Optional[float], optional): Offset to the start of waveform. Defaults to None. + dtype (str, optional): Data type of waveform. Defaults to 'int16'. + duration (Optional[int], optional): Duration of waveform to read. Defaults to None. + + Returns: + Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate. """ with sf.SoundFile(file) as sf_desc: sr_native = sf_desc.samplerate @@ -179,9 +217,17 @@ def sound_file_load(file: str, return y, sf_desc.samplerate -def normalize(y: array, norm_type: str='linear', - mul_factor: float=1.0) -> array: - """ normalize an input audio with additional multiplier. +def normalize(y: np.ndarray, norm_type: str='linear', + mul_factor: float=1.0) -> np.ndarray: + """Normalize an input audio with additional multiplier. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + norm_type (str, optional): Type of normalization. Defaults to 'linear'. + mul_factor (float, optional): Scaling factor. Defaults to 1.0. + + Returns: + np.ndarray: `y` after normalization. """ if norm_type == 'linear': @@ -199,12 +245,13 @@ def normalize(y: array, norm_type: str='linear', return y -def save(y: array, sr: int, file: str) -> None: - """Save audio file to disk. - This function saves audio to disk using scipy.io.wavfile, with additional step - to convert input waveform to int16 unless it already is int16 - Notes: - It only support raw wav format. +def save(y: np.ndarray, sr: int, file: os.PathLike) -> None: + """Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + sr (int): Sample rate. + file (os.PathLike): Path of auido file to save. """ if not file.endswith('.wav'): raise ParameterError( @@ -226,7 +273,7 @@ def save(y: array, sr: int, file: str) -> None: def load( - file: str, + file: os.PathLike, sr: Optional[int]=None, mono: bool=True, merge_type: str='average', # ch0,ch1,random,average @@ -236,11 +283,24 @@ def load( offset: float=0.0, duration: Optional[int]=None, dtype: str='float32', - resample_mode: str='kaiser_fast') -> Tuple[array, int]: - """Load audio file from disk. - This function loads audio from disk using using audio beackend. - Parameters: - Notes: + resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]: + """Load audio file from disk. This function loads audio from disk using using audio beackend. + + Args: + file (os.PathLike): Path of auido file to load. + sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None. + mono (bool, optional): Return waveform with mono channel. Defaults to True. + merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'. + normal (bool, optional): Waveform normalization. Defaults to True. + norm_type (str, optional): Type of normalization. Defaults to 'linear'. + norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0. + offset (float, optional): Offset to the start of waveform. Defaults to 0.0. + duration (Optional[int], optional): Duration of waveform to read. Defaults to None. + dtype (str, optional): Data type of waveform. Defaults to 'float32'. + resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'. + + Returns: + Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate. """ y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration) diff --git a/paddleaudio/paddleaudio/compliance/kaldi.py b/paddleaudio/paddleaudio/compliance/kaldi.py index 8cb9b666..538be019 100644 --- a/paddleaudio/paddleaudio/compliance/kaldi.py +++ b/paddleaudio/paddleaudio/compliance/kaldi.py @@ -220,7 +220,7 @@ def spectrogram(waveform: Tensor, """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's. Args: - waveform (Tensor): A waveform tensor with shape [C, T]. + waveform (Tensor): A waveform tensor with shape `(C, T)`. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. channel (int, optional): Select the channel of waveform. Defaults to -1. dither (float, optional): Dithering constant . Defaults to 0.0. @@ -239,7 +239,7 @@ def spectrogram(waveform: Tensor, window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. Returns: - Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1) where m is the number of frames + Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames depends on frame_length and frame_shift. """ dtype = waveform.dtype @@ -422,7 +422,7 @@ def fbank(waveform: Tensor, """Compute and return filter banks from a waveform. The output is identical to Kaldi's. Args: - waveform (Tensor): A waveform tensor with shape [C, T]. + waveform (Tensor): A waveform tensor with shape `(C, T)`. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. channel (int, optional): Select the channel of waveform. Defaults to -1. dither (float, optional): Dithering constant . Defaults to 0.0. @@ -451,7 +451,7 @@ def fbank(waveform: Tensor, window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. Returns: - Tensor: A filter banks tensor with shape (m, n_mels). + Tensor: A filter banks tensor with shape `(m, n_mels)`. """ dtype = waveform.dtype @@ -542,7 +542,7 @@ def mfcc(waveform: Tensor, identical to Kaldi's. Args: - waveform (Tensor): A waveform tensor with shape [C, T]. + waveform (Tensor): A waveform tensor with shape `(C, T)`. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0. channel (int, optional): Select the channel of waveform. Defaults to -1. @@ -571,7 +571,7 @@ def mfcc(waveform: Tensor, window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. Returns: - Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc). + Tensor: A mel frequency cepstral coefficients tensor with shape `(m, n_mfcc)`. """ assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % ( n_mfcc, n_mels) diff --git a/paddleaudio/paddleaudio/compliance/librosa.py b/paddleaudio/paddleaudio/compliance/librosa.py index 167795c3..d7ceb2b4 100644 --- a/paddleaudio/paddleaudio/compliance/librosa.py +++ b/paddleaudio/paddleaudio/compliance/librosa.py @@ -19,7 +19,6 @@ from typing import Union import numpy as np import scipy -from numpy import ndarray as array from numpy.lib.stride_tricks import as_strided from scipy import signal @@ -32,7 +31,6 @@ __all__ = [ 'mfcc', 'hz_to_mel', 'mel_to_hz', - 'split_frames', 'mel_frequencies', 'power_to_db', 'compute_fbank_matrix', @@ -49,7 +47,8 @@ __all__ = [ ] -def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array: +def _pad_center(data: np.ndarray, size: int, axis: int=-1, + **kwargs) -> np.ndarray: """Pad an array to a target length along a target axis. This differs from `np.pad` by centering the data prior to padding, @@ -69,8 +68,10 @@ def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array: return np.pad(data, lengths, **kwargs) -def split_frames(x: array, frame_length: int, hop_length: int, - axis: int=-1) -> array: +def _split_frames(x: np.ndarray, + frame_length: int, + hop_length: int, + axis: int=-1) -> np.ndarray: """Slice a data array into (overlapping) frames. This function is aligned with librosa.frame @@ -142,11 +143,16 @@ def _check_audio(y, mono=True) -> bool: return True -def hz_to_mel(frequencies: Union[float, List[float], array], - htk: bool=False) -> array: - """Convert Hz to Mels +def hz_to_mel(frequencies: Union[float, List[float], np.ndarray], + htk: bool=False) -> np.ndarray: + """Convert Hz to Mels. - This function is aligned with librosa. + Args: + frequencies (Union[float, List[float], np.ndarray]): Frequencies in Hz. + htk (bool, optional): Use htk scaling. Defaults to False. + + Returns: + np.ndarray: Frequency in mels. """ freq = np.asanyarray(frequencies) @@ -177,10 +183,16 @@ def hz_to_mel(frequencies: Union[float, List[float], array], return mels -def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array: +def mel_to_hz(mels: Union[float, List[float], np.ndarray], + htk: int=False) -> np.ndarray: """Convert mel bin numbers to frequencies. - This function is aligned with librosa. + Args: + mels (Union[float, List[float], np.ndarray]): Frequency in mels. + htk (bool, optional): Use htk scaling. Defaults to False. + + Returns: + np.ndarray: Frequencies in Hz. """ mel_array = np.asanyarray(mels) @@ -212,10 +224,17 @@ def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array: def mel_frequencies(n_mels: int=128, fmin: float=0.0, fmax: float=11025.0, - htk: bool=False) -> array: - """Compute mel frequencies + htk: bool=False) -> np.ndarray: + """Compute mel frequencies. + + Args: + n_mels (int, optional): Number of mel bins. Defaults to 128. + fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0. + htk (bool, optional): Use htk scaling. Defaults to False. - This function is aligned with librosa. + Returns: + np.ndarray: Vector of n_mels frequencies in Hz with shape `(n_mels,)`. """ # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = hz_to_mel(fmin, htk=htk) @@ -226,10 +245,15 @@ def mel_frequencies(n_mels: int=128, return mel_to_hz(mels, htk=htk) -def fft_frequencies(sr: int, n_fft: int) -> array: +def fft_frequencies(sr: int, n_fft: int) -> np.ndarray: """Compute fourier frequencies. - This function is aligned with librosa. + Args: + sr (int): Sample rate. + n_fft (int): FFT size. + + Returns: + np.ndarray: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`. """ return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True) @@ -241,10 +265,22 @@ def compute_fbank_matrix(sr: int, fmax: Optional[float]=None, htk: bool=False, norm: str="slaney", - dtype: type=np.float32): + dtype: type=np.float32) -> np.ndarray: """Compute fbank matrix. - This funciton is aligned with librosa. + Args: + sr (int): Sample rate. + n_fft (int): FFT size. + n_mels (int, optional): Number of mel bins. Defaults to 128. + fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use htk scaling. Defaults to False. + norm (str, optional): Type of normalization. Defaults to "slaney". + dtype (type, optional): Data type. Defaults to np.float32. + + + Returns: + np.ndarray: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`. """ if norm != "slaney": raise ParameterError('norm must set to slaney') @@ -289,17 +325,28 @@ def compute_fbank_matrix(sr: int, return weights -def stft(x: array, +def stft(x: np.ndarray, n_fft: int=2048, hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str="hann", center: bool=True, dtype: type=np.complex64, - pad_mode: str="reflect") -> array: + pad_mode: str="reflect") -> np.ndarray: """Short-time Fourier transform (STFT). - This function is aligned with librosa. + Args: + x (np.ndarray): Input waveform in one dimension. + n_fft (int, optional): FFT size. Defaults to 2048. + hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None. + win_length (Optional[int], optional): The size of window. Defaults to None. + window (str, optional): A string of window specification. Defaults to "hann". + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + dtype (type, optional): Data type of STFT results. Defaults to np.complex64. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". + + Returns: + np.ndarray: The complex STFT output with shape `(n_fft//2 + 1, num_frames)` """ _check_audio(x) @@ -314,7 +361,7 @@ def stft(x: array, fft_window = signal.get_window(window, win_length, fftbins=True) # Pad the window out to n_fft size - fft_window = pad_center(fft_window, n_fft) + fft_window = _pad_center(fft_window, n_fft) # Reshape so that the window can be broadcast fft_window = fft_window.reshape((-1, 1)) @@ -333,7 +380,7 @@ def stft(x: array, ) # Window the time series. - x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length) + x_frames = _split_frames(x, frame_length=n_fft, hop_length=hop_length) # Pre-allocate the STFT matrix stft_matrix = np.empty( (int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F") @@ -352,16 +399,20 @@ def stft(x: array, return stft_matrix -def power_to_db(spect: array, +def power_to_db(spect: np.ndarray, ref: float=1.0, amin: float=1e-10, - top_db: Optional[float]=80.0) -> array: - """Convert a power spectrogram (amplitude squared) to decibel (dB) units + top_db: Optional[float]=80.0) -> np.ndarray: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. This computes the scaling `10 * log10(spect / ref)` in a numerically stable way. - This computes the scaling ``10 * log10(spect / ref)`` in a numerically - stable way. + Args: + spect (np.ndarray): STFT power spectrogram of an input waveform. + ref (float, optional): Scaling factor of spectrogram. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to 80.0. - This function is aligned with librosa. + Returns: + np.ndarray: Power spectrogram in db scale. """ spect = np.asarray(spect) @@ -394,49 +445,27 @@ def power_to_db(spect: array, return log_spec -def mfcc(x, +def mfcc(x: np.ndarray, sr: int=16000, - spect: Optional[array]=None, + spect: Optional[np.ndarray]=None, n_mfcc: int=20, dct_type: int=2, norm: str="ortho", lifter: int=0, - **kwargs) -> array: + **kwargs) -> np.ndarray: """Mel-frequency cepstral coefficients (MFCCs) - This function is NOT strictly aligned with librosa. The following example shows how to get the - same result with librosa: - - # mfcc: - kwargs = { - 'window_size':512, - 'hop_length':320, - 'mel_bins':64, - 'fmin':50, - 'to_db':False} - a = mfcc(x, - spect=None, - n_mfcc=20, - dct_type=2, - norm='ortho', - lifter=0, - **kwargs) - - # librosa mfcc: - spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512, - win_length=512, - hop_length=320, - n_mels=64, fmin=50) - b = librosa.feature.mfcc(y=x, - sr=16000, - S=spect, - n_mfcc=20, - dct_type=2, - norm='ortho', - lifter=0) - - assert np.mean( (a-b)**2) < 1e-8 + Args: + x (np.ndarray): Input waveform in one dimension. + sr (int, optional): Sample rate. Defaults to 16000. + spect (Optional[np.ndarray], optional): Input log-power Mel spectrogram. Defaults to None. + n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 20. + dct_type (int, optional): Discrete cosine transform (DCT) type. Defaults to 2. + norm (str, optional): Type of normalization. Defaults to "ortho". + lifter (int, optional): Cepstral filtering. Defaults to 0. + Returns: + np.ndarray: A mel frequency cepstral coefficients tensor with shape `(n_mfcc, num_frames)`. """ if spect is None: spect = melspectrogram(x, sr=sr, **kwargs) @@ -454,12 +483,12 @@ def mfcc(x, f"MFCC lifter={lifter} must be a non-negative number") -def melspectrogram(x: array, +def melspectrogram(x: np.ndarray, sr: int=16000, window_size: int=512, hop_length: int=320, n_mels: int=64, - fmin: int=50, + fmin: float=50.0, fmax: Optional[float]=None, window: str='hann', center: bool=True, @@ -468,27 +497,28 @@ def melspectrogram(x: array, to_db: bool=True, ref: float=1.0, amin: float=1e-10, - top_db: Optional[float]=None) -> array: + top_db: Optional[float]=None) -> np.ndarray: """Compute mel-spectrogram. - Parameters: - x: numpy.ndarray - The input wavform is a numpy array [shape=(n,)] - - window_size: int, typically 512, 1024, 2048, etc. - The window size for framing, also used as n_fft for stft - + Args: + x (np.ndarray): Input waveform in one dimension. + sr (int, optional): Sample rate. Defaults to 16000. + window_size (int, optional): Size of FFT and window length. Defaults to 512. + hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320. + n_mels (int, optional): Number of mel bins. Defaults to 64. + fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0. + fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + window (str, optional): A string of window specification. Defaults to "hann". + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". + power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. + to_db (bool, optional): Enable db scale. Defaults to True. + ref (float, optional): Scaling factor of spectrogram. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None. Returns: - The mel-spectrogram in power scale or db scale(default) - - - Notes: - 1. sr is default to 16000, which is commonly used in speech/speaker processing. - 2. when fmax is None, it is set to sr//2. - 3. this function will convert mel spectgrum to db scale by default. This is different - that of librosa. - + np.ndarray: The mel-spectrogram in power scale or db scale with shape `(n_mels, num_frames)`. """ _check_audio(x, mono=True) if len(x) <= 0: @@ -518,18 +548,28 @@ def melspectrogram(x: array, return mel_spect -def spectrogram(x: array, +def spectrogram(x: np.ndarray, sr: int=16000, window_size: int=512, hop_length: int=320, window: str='hann', center: bool=True, pad_mode: str='reflect', - power: float=2.0) -> array: - """Compute spectrogram from an input waveform. + power: float=2.0) -> np.ndarray: + """Compute spectrogram. + + Args: + x (np.ndarray): Input waveform in one dimension. + sr (int, optional): Sample rate. Defaults to 16000. + window_size (int, optional): Size of FFT and window length. Defaults to 512. + hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320. + window (str, optional): A string of window specification. Defaults to "hann". + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". + power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. - This function is a wrapper for librosa.feature.stft, with addition step to - compute the magnitude of the complex spectrogram. + Returns: + np.ndarray: The STFT spectrogram in power scale `(n_fft//2 + 1, num_frames)`. """ s = stft( @@ -544,18 +584,16 @@ def spectrogram(x: array, return np.abs(s)**power -def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array: - """Mu-law encoding. - - Compute the mu-law decoding given an input code. - When quantized is True, the result will be converted to - integer in range [0,mu-1]. Otherwise, the resulting signal - is in range [-1,1] - +def mu_encode(x: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray: + """Mu-law encoding. Encode waveform based on mu-law companding. When quantized is True, the result will be converted to integer in range `[0,mu-1]`. Otherwise, the resulting waveform is in range `[-1,1]`. - Reference: - https://en.wikipedia.org/wiki/%CE%9C-law_algorithm + Args: + x (np.ndarray): The input waveform to encode. + mu (int, optional): The endoceding parameter. Defaults to 255. + quantized (bool, optional): If `True`, quantize the encoded values into `1 + mu` distinct integer values. Defaults to True. + Returns: + np.ndarray: The mu-law encoded waveform. """ mu = 255 y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) @@ -564,17 +602,16 @@ def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array: return y -def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array: - """Mu-law decoding. - - Compute the mu-law decoding given an input code. +def mu_decode(y: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray: + """Mu-law decoding. Compute the mu-law decoding given an input code. It assumes that the input `y` is in range `[0,mu-1]` when quantize is True and `[-1,1]` otherwise. - it assumes that the input y is in - range [0,mu-1] when quantize is True and [-1,1] otherwise - - Reference: - https://en.wikipedia.org/wiki/%CE%9C-law_algorithm + Args: + y (np.ndarray): The encoded waveform. + mu (int, optional): The endoceding parameter. Defaults to 255. + quantized (bool, optional): If `True`, the input is assumed to be quantized to `1 + mu` distinct integer values. Defaults to True. + Returns: + np.ndarray: The mu-law decoded waveform. """ if mu < 1: raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...') @@ -586,7 +623,7 @@ def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array: return x -def randint(high: int) -> int: +def _randint(high: int) -> int: """Generate one random integer in range [0 high) This is a helper function for random data augmentaiton @@ -594,20 +631,18 @@ def randint(high: int) -> int: return int(np.random.randint(0, high=high)) -def rand() -> float: - """Generate one floating-point number in range [0 1) - - This is a helper function for random data augmentaiton - """ - return float(np.random.rand(1)) - - -def depth_augment(y: array, +def depth_augment(y: np.ndarray, choices: List=['int8', 'int16'], - probs: List[float]=[0.5, 0.5]) -> array: - """ Audio depth augmentation + probs: List[float]=[0.5, 0.5]) -> np.ndarray: + """ Audio depth augmentation. Do audio depth augmentation to simulate the distortion brought by quantization. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + choices (List, optional): A list of data type to depth conversion. Defaults to ['int8', 'int16']. + probs (List[float], optional): Probabilities to depth conversion. Defaults to [0.5, 0.5]. - Do audio depth augmentation to simulate the distortion brought by quantization. + Returns: + np.ndarray: The augmented waveform. """ assert len(probs) == len( choices @@ -621,13 +656,18 @@ def depth_augment(y: array, return y2 -def adaptive_spect_augment(spect: array, tempo_axis: int=0, - level: float=0.1) -> array: - """Do adpative spectrogram augmentation +def adaptive_spect_augment(spect: np.ndarray, + tempo_axis: int=0, + level: float=0.1) -> np.ndarray: + """Do adpative spectrogram augmentation. The level of the augmentation is gowern by the paramter level, ranging from 0 to 1, with 0 represents no augmentation. - The level of the augmentation is gowern by the paramter level, - ranging from 0 to 1, with 0 represents no augmentation。 + Args: + spect (np.ndarray): Input spectrogram. + tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0. + level (float, optional): The level factor of masking. Defaults to 0.1. + Returns: + np.ndarray: The augmented spectrogram. """ assert spect.ndim == 2., 'only supports 2d tensor or numpy array' if tempo_axis == 0: @@ -643,32 +683,40 @@ def adaptive_spect_augment(spect: array, tempo_axis: int=0, if tempo_axis == 0: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[start:start + time_mask_width, :] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[:, start:start + freq_mask_width] = 0 else: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[:, start:start + time_mask_width] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[start:start + freq_mask_width, :] = 0 return spect -def spect_augment(spect: array, +def spect_augment(spect: np.ndarray, tempo_axis: int=0, max_time_mask: int=3, max_freq_mask: int=3, max_time_mask_width: int=30, - max_freq_mask_width: int=20) -> array: - """Do spectrogram augmentation in both time and freq axis + max_freq_mask_width: int=20) -> np.ndarray: + """Do spectrogram augmentation in both time and freq axis. - Reference: + Args: + spect (np.ndarray): Input spectrogram. + tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0. + max_time_mask (int, optional): Maximum number of time masking. Defaults to 3. + max_freq_mask (int, optional): Maximum number of frenquence masking. Defaults to 3. + max_time_mask_width (int, optional): Maximum width of time masking. Defaults to 30. + max_freq_mask_width (int, optional): Maximum width of frenquence masking. Defaults to 20. + Returns: + np.ndarray: The augmented spectrogram. """ assert spect.ndim == 2., 'only supports 2d tensor or numpy array' if tempo_axis == 0: @@ -676,52 +724,64 @@ def spect_augment(spect: array, else: nf, nt = spect.shape - num_time_mask = randint(max_time_mask) - num_freq_mask = randint(max_freq_mask) + num_time_mask = _randint(max_time_mask) + num_freq_mask = _randint(max_freq_mask) - time_mask_width = randint(max_time_mask_width) - freq_mask_width = randint(max_freq_mask_width) + time_mask_width = _randint(max_time_mask_width) + freq_mask_width = _randint(max_freq_mask_width) if tempo_axis == 0: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[start:start + time_mask_width, :] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[:, start:start + freq_mask_width] = 0 else: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[:, start:start + time_mask_width] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[start:start + freq_mask_width, :] = 0 return spect -def random_crop1d(y: array, crop_len: int) -> array: - """ Do random cropping on 1d input signal +def random_crop1d(y: np.ndarray, crop_len: int) -> np.ndarray: + """ Random cropping on a input waveform. - The input is a 1d signal, typically a sound waveform + Args: + y (np.ndarray): Input waveform array in 1D. + crop_len (int): Length of waveform to crop. + + Returns: + np.ndarray: The cropped waveform. """ if y.ndim != 1: 'only accept 1d tensor or numpy array' n = len(y) - idx = randint(n - crop_len) + idx = _randint(n - crop_len) return y[idx:idx + crop_len] -def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array: - """ Do random cropping for 2D array, typically a spectrogram. +def random_crop2d(s: np.ndarray, crop_len: int, + tempo_axis: int=0) -> np.ndarray: + """ Random cropping on a spectrogram. - The cropping is done in temporal direction on the time-freq input signal. + Args: + s (np.ndarray): Input spectrogram in 2D. + crop_len (int): Length of spectrogram to crop. + tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0. + + Returns: + np.ndarray: The cropped spectrogram. """ if tempo_axis >= s.ndim: raise ParameterError('axis out of range') n = s.shape[tempo_axis] - idx = randint(high=n - crop_len) + idx = _randint(high=n - crop_len) sli = [slice(None) for i in range(s.ndim)] sli[tempo_axis] = slice(idx, idx + crop_len) out = s[tuple(sli)] diff --git a/paddleaudio/paddleaudio/features/layers.py b/paddleaudio/paddleaudio/features/layers.py index 6afd234a..877a5ae8 100644 --- a/paddleaudio/paddleaudio/features/layers.py +++ b/paddleaudio/paddleaudio/features/layers.py @@ -44,29 +44,16 @@ class Spectrogram(nn.Layer): """Compute spectrogram of a given signal, typically an audio waveform. The spectorgram is defined as the complex norm of the short-time Fourier transformation. - Parameters: - n_fft (int): the number of frequency components of the discrete Fourier transform. - The default value is 2048, - hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4. - The default value is None. - win_length: the window length of the short time FFt. If None, it is set to same as n_fft. - The default value is None. - window (str): the name of the window function applied to the single before the Fourier transform. - The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', - 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. - The default value is 'hann' - power (float): Exponent for the magnitude spectrogram. The default value is 2.0. - center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. - If False, frame t begins at x[t * hop_length] - The default value is True - pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect' - and 'constant'. The default value is 'reflect'. - dtype (str): the data type of input and window. - Notes: - The Spectrogram transform relies on STFT transform to compute the spectrogram. - By default, the weights are not learnable. To fine-tune the Fourier coefficients, - set stop_gradient=False before training. - For more information, see STFT(). + + Args: + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the single before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + dtype (str, optional): Data type of input and window. Defaults to paddle.float32. """ super(Spectrogram, self).__init__()