Add paddleaudio doc.

pull/1582/head
KP 2 years ago
parent 0b4270575f
commit 8dcaef9ae9

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import Optional
from typing import Tuple
@ -19,7 +20,6 @@ from typing import Union
import numpy as np
import resampy
import soundfile as sf
from numpy import ndarray as array
from scipy.io import wavfile
from ..utils import ParameterError
@ -38,13 +38,21 @@ RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
EPS = 1e-8
def resample(y: array, src_sr: int, target_sr: int,
mode: str='kaiser_fast') -> array:
""" Audio resampling
This function is the same as using resampy.resample().
Notes:
The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast'
"""
def resample(y: np.ndarray,
src_sr: int,
target_sr: int,
mode: str='kaiser_fast') -> np.ndarray:
"""Audio resampling.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
src_sr (int): Source sample rate.
target_sr (int): Target sample rate.
mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
Returns:
np.ndarray: `y` resampled to `target_sr`
"""
if mode == 'kaiser_best':
warnings.warn(
@ -53,7 +61,7 @@ def resample(y: array, src_sr: int, target_sr: int,
if not isinstance(y, np.ndarray):
raise ParameterError(
'Only support numpy array, but received y in {type(y)}')
'Only support numpy np.ndarray, but received y in {type(y)}')
if mode not in RESAMPLE_MODES:
raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
@ -61,9 +69,17 @@ def resample(y: array, src_sr: int, target_sr: int,
return resampy.resample(y, src_sr, target_sr, filter=mode)
def to_mono(y: array, merge_type: str='average') -> array:
""" convert sterior audio to mono
def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
"""Convert sterior audio to mono.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'.
Returns:
np.ndarray: `y` with mono channel.
"""
if merge_type not in MERGE_TYPES:
raise ParameterError(
f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
@ -101,18 +117,34 @@ def to_mono(y: array, merge_type: str='average') -> array:
return y_out
def _safe_cast(y: array, dtype: Union[type, str]) -> array:
""" data type casting in a safe way, i.e., prevent overflow or underflow
This function is used internally.
def _safe_cast(y: np.ndarray, dtype: Union[type, str]) -> np.ndarray:
"""Data type casting in a safe way, i.e., prevent overflow or underflow.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
dtype (Union[type, str]): Data type of waveform.
Returns:
np.ndarray: `y` after safe casting.
"""
return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
if 'float' in str(y.dtype):
return np.clip(y, np.finfo(dtype).min,
np.finfo(dtype).max).astype(dtype)
else:
return np.clip(y, np.iinfo(dtype).min,
np.iinfo(dtype).max).astype(dtype)
def depth_convert(y: array, dtype: Union[type, str],
dithering: bool=True) -> array:
"""Convert audio array to target dtype safely
This function convert audio waveform to a target dtype, with addition steps of
def depth_convert(y: np.ndarray, dtype: Union[type, str]) -> np.ndarray:
"""Convert audio array to target dtype safely. This function convert audio waveform to a target dtype, with addition steps of
preventing overflow/underflow and preserving audio range.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
dtype (Union[type, str]): Data type of waveform.
Returns:
np.ndarray: `y` after safe casting.
"""
SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64']
@ -157,14 +189,20 @@ def depth_convert(y: array, dtype: Union[type, str],
return y
def sound_file_load(file: str,
def sound_file_load(file: os.PathLike,
offset: Optional[float]=None,
dtype: str='int16',
duration: Optional[int]=None) -> Tuple[array, int]:
"""Load audio using soundfile library
This function load audio file using libsndfile.
Reference:
http://www.mega-nerd.com/libsndfile/#Features
duration: Optional[int]=None) -> Tuple[np.ndarray, int]:
"""Load audio using soundfile library. This function load audio file using libsndfile.
Args:
file (os.PathLike): File of waveform.
offset (Optional[float], optional): Offset to the start of waveform. Defaults to None.
dtype (str, optional): Data type of waveform. Defaults to 'int16'.
duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
Returns:
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
"""
with sf.SoundFile(file) as sf_desc:
sr_native = sf_desc.samplerate
@ -179,9 +217,17 @@ def sound_file_load(file: str,
return y, sf_desc.samplerate
def normalize(y: array, norm_type: str='linear',
mul_factor: float=1.0) -> array:
""" normalize an input audio with additional multiplier.
def normalize(y: np.ndarray, norm_type: str='linear',
mul_factor: float=1.0) -> np.ndarray:
"""Normalize an input audio with additional multiplier.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
norm_type (str, optional): Type of normalization. Defaults to 'linear'.
mul_factor (float, optional): Scaling factor. Defaults to 1.0.
Returns:
np.ndarray: `y` after normalization.
"""
if norm_type == 'linear':
@ -199,12 +245,13 @@ def normalize(y: array, norm_type: str='linear',
return y
def save(y: array, sr: int, file: str) -> None:
"""Save audio file to disk.
This function saves audio to disk using scipy.io.wavfile, with additional step
to convert input waveform to int16 unless it already is int16
Notes:
It only support raw wav format.
def save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
"""Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
sr (int): Sample rate.
file (os.PathLike): Path of auido file to save.
"""
if not file.endswith('.wav'):
raise ParameterError(
@ -226,7 +273,7 @@ def save(y: array, sr: int, file: str) -> None:
def load(
file: str,
file: os.PathLike,
sr: Optional[int]=None,
mono: bool=True,
merge_type: str='average', # ch0,ch1,random,average
@ -236,11 +283,24 @@ def load(
offset: float=0.0,
duration: Optional[int]=None,
dtype: str='float32',
resample_mode: str='kaiser_fast') -> Tuple[array, int]:
"""Load audio file from disk.
This function loads audio from disk using using audio beackend.
Parameters:
Notes:
resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]:
"""Load audio file from disk. This function loads audio from disk using using audio beackend.
Args:
file (os.PathLike): Path of auido file to load.
sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None.
mono (bool, optional): Return waveform with mono channel. Defaults to True.
merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'.
normal (bool, optional): Waveform normalization. Defaults to True.
norm_type (str, optional): Type of normalization. Defaults to 'linear'.
norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0.
offset (float, optional): Offset to the start of waveform. Defaults to 0.0.
duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
dtype (str, optional): Data type of waveform. Defaults to 'float32'.
resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
Returns:
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
"""
y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration)

@ -220,7 +220,7 @@ def spectrogram(waveform: Tensor,
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
Args:
waveform (Tensor): A waveform tensor with shape [C, T].
waveform (Tensor): A waveform tensor with shape `(C, T)`.
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0.
@ -239,7 +239,7 @@ def spectrogram(waveform: Tensor,
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
Returns:
Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1) where m is the number of frames
Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames
depends on frame_length and frame_shift.
"""
dtype = waveform.dtype
@ -422,7 +422,7 @@ def fbank(waveform: Tensor,
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's.
Args:
waveform (Tensor): A waveform tensor with shape [C, T].
waveform (Tensor): A waveform tensor with shape `(C, T)`.
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0.
@ -451,7 +451,7 @@ def fbank(waveform: Tensor,
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
Returns:
Tensor: A filter banks tensor with shape (m, n_mels).
Tensor: A filter banks tensor with shape `(m, n_mels)`.
"""
dtype = waveform.dtype
@ -542,7 +542,7 @@ def mfcc(waveform: Tensor,
identical to Kaldi's.
Args:
waveform (Tensor): A waveform tensor with shape [C, T].
waveform (Tensor): A waveform tensor with shape `(C, T)`.
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
channel (int, optional): Select the channel of waveform. Defaults to -1.
@ -571,7 +571,7 @@ def mfcc(waveform: Tensor,
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
Returns:
Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc).
Tensor: A mel frequency cepstral coefficients tensor with shape `(m, n_mfcc)`.
"""
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
n_mfcc, n_mels)

@ -19,7 +19,6 @@ from typing import Union
import numpy as np
import scipy
from numpy import ndarray as array
from numpy.lib.stride_tricks import as_strided
from scipy import signal
@ -32,7 +31,6 @@ __all__ = [
'mfcc',
'hz_to_mel',
'mel_to_hz',
'split_frames',
'mel_frequencies',
'power_to_db',
'compute_fbank_matrix',
@ -49,7 +47,8 @@ __all__ = [
]
def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array:
def _pad_center(data: np.ndarray, size: int, axis: int=-1,
**kwargs) -> np.ndarray:
"""Pad an array to a target length along a target axis.
This differs from `np.pad` by centering the data prior to padding,
@ -69,8 +68,10 @@ def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array:
return np.pad(data, lengths, **kwargs)
def split_frames(x: array, frame_length: int, hop_length: int,
axis: int=-1) -> array:
def _split_frames(x: np.ndarray,
frame_length: int,
hop_length: int,
axis: int=-1) -> np.ndarray:
"""Slice a data array into (overlapping) frames.
This function is aligned with librosa.frame
@ -142,11 +143,16 @@ def _check_audio(y, mono=True) -> bool:
return True
def hz_to_mel(frequencies: Union[float, List[float], array],
htk: bool=False) -> array:
"""Convert Hz to Mels
def hz_to_mel(frequencies: Union[float, List[float], np.ndarray],
htk: bool=False) -> np.ndarray:
"""Convert Hz to Mels.
This function is aligned with librosa.
Args:
frequencies (Union[float, List[float], np.ndarray]): Frequencies in Hz.
htk (bool, optional): Use htk scaling. Defaults to False.
Returns:
np.ndarray: Frequency in mels.
"""
freq = np.asanyarray(frequencies)
@ -177,10 +183,16 @@ def hz_to_mel(frequencies: Union[float, List[float], array],
return mels
def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array:
def mel_to_hz(mels: Union[float, List[float], np.ndarray],
htk: int=False) -> np.ndarray:
"""Convert mel bin numbers to frequencies.
This function is aligned with librosa.
Args:
mels (Union[float, List[float], np.ndarray]): Frequency in mels.
htk (bool, optional): Use htk scaling. Defaults to False.
Returns:
np.ndarray: Frequencies in Hz.
"""
mel_array = np.asanyarray(mels)
@ -212,10 +224,17 @@ def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array:
def mel_frequencies(n_mels: int=128,
fmin: float=0.0,
fmax: float=11025.0,
htk: bool=False) -> array:
"""Compute mel frequencies
htk: bool=False) -> np.ndarray:
"""Compute mel frequencies.
Args:
n_mels (int, optional): Number of mel bins. Defaults to 128.
fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0.
fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0.
htk (bool, optional): Use htk scaling. Defaults to False.
This function is aligned with librosa.
Returns:
np.ndarray: Vector of n_mels frequencies in Hz with shape `(n_mels,)`.
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(fmin, htk=htk)
@ -226,10 +245,15 @@ def mel_frequencies(n_mels: int=128,
return mel_to_hz(mels, htk=htk)
def fft_frequencies(sr: int, n_fft: int) -> array:
def fft_frequencies(sr: int, n_fft: int) -> np.ndarray:
"""Compute fourier frequencies.
This function is aligned with librosa.
Args:
sr (int): Sample rate.
n_fft (int): FFT size.
Returns:
np.ndarray: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`.
"""
return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)
@ -241,10 +265,22 @@ def compute_fbank_matrix(sr: int,
fmax: Optional[float]=None,
htk: bool=False,
norm: str="slaney",
dtype: type=np.float32):
dtype: type=np.float32) -> np.ndarray:
"""Compute fbank matrix.
This funciton is aligned with librosa.
Args:
sr (int): Sample rate.
n_fft (int): FFT size.
n_mels (int, optional): Number of mel bins. Defaults to 128.
fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0.
fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
htk (bool, optional): Use htk scaling. Defaults to False.
norm (str, optional): Type of normalization. Defaults to "slaney".
dtype (type, optional): Data type. Defaults to np.float32.
Returns:
np.ndarray: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`.
"""
if norm != "slaney":
raise ParameterError('norm must set to slaney')
@ -289,17 +325,28 @@ def compute_fbank_matrix(sr: int,
return weights
def stft(x: array,
def stft(x: np.ndarray,
n_fft: int=2048,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str="hann",
center: bool=True,
dtype: type=np.complex64,
pad_mode: str="reflect") -> array:
pad_mode: str="reflect") -> np.ndarray:
"""Short-time Fourier transform (STFT).
This function is aligned with librosa.
Args:
x (np.ndarray): Input waveform in one dimension.
n_fft (int, optional): FFT size. Defaults to 2048.
hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None.
win_length (Optional[int], optional): The size of window. Defaults to None.
window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
dtype (type, optional): Data type of STFT results. Defaults to np.complex64.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
Returns:
np.ndarray: The complex STFT output with shape `(n_fft//2 + 1, num_frames)`
"""
_check_audio(x)
@ -314,7 +361,7 @@ def stft(x: array,
fft_window = signal.get_window(window, win_length, fftbins=True)
# Pad the window out to n_fft size
fft_window = pad_center(fft_window, n_fft)
fft_window = _pad_center(fft_window, n_fft)
# Reshape so that the window can be broadcast
fft_window = fft_window.reshape((-1, 1))
@ -333,7 +380,7 @@ def stft(x: array,
)
# Window the time series.
x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length)
x_frames = _split_frames(x, frame_length=n_fft, hop_length=hop_length)
# Pre-allocate the STFT matrix
stft_matrix = np.empty(
(int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F")
@ -352,16 +399,20 @@ def stft(x: array,
return stft_matrix
def power_to_db(spect: array,
def power_to_db(spect: np.ndarray,
ref: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=80.0) -> array:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units
top_db: Optional[float]=80.0) -> np.ndarray:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units. This computes the scaling `10 * log10(spect / ref)` in a numerically stable way.
This computes the scaling ``10 * log10(spect / ref)`` in a numerically
stable way.
Args:
spect (np.ndarray): STFT power spectrogram of an input waveform.
ref (float, optional): Scaling factor of spectrogram. Defaults to 1.0.
amin (float, optional): Minimum threshold. Defaults to 1e-10.
top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to 80.0.
This function is aligned with librosa.
Returns:
np.ndarray: Power spectrogram in db scale.
"""
spect = np.asarray(spect)
@ -394,49 +445,27 @@ def power_to_db(spect: array,
return log_spec
def mfcc(x,
def mfcc(x: np.ndarray,
sr: int=16000,
spect: Optional[array]=None,
spect: Optional[np.ndarray]=None,
n_mfcc: int=20,
dct_type: int=2,
norm: str="ortho",
lifter: int=0,
**kwargs) -> array:
**kwargs) -> np.ndarray:
"""Mel-frequency cepstral coefficients (MFCCs)
This function is NOT strictly aligned with librosa. The following example shows how to get the
same result with librosa:
# mfcc:
kwargs = {
'window_size':512,
'hop_length':320,
'mel_bins':64,
'fmin':50,
'to_db':False}
a = mfcc(x,
spect=None,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0,
**kwargs)
# librosa mfcc:
spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512,
win_length=512,
hop_length=320,
n_mels=64, fmin=50)
b = librosa.feature.mfcc(y=x,
sr=16000,
S=spect,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0)
assert np.mean( (a-b)**2) < 1e-8
Args:
x (np.ndarray): Input waveform in one dimension.
sr (int, optional): Sample rate. Defaults to 16000.
spect (Optional[np.ndarray], optional): Input log-power Mel spectrogram. Defaults to None.
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 20.
dct_type (int, optional): Discrete cosine transform (DCT) type. Defaults to 2.
norm (str, optional): Type of normalization. Defaults to "ortho".
lifter (int, optional): Cepstral filtering. Defaults to 0.
Returns:
np.ndarray: A mel frequency cepstral coefficients tensor with shape `(n_mfcc, num_frames)`.
"""
if spect is None:
spect = melspectrogram(x, sr=sr, **kwargs)
@ -454,12 +483,12 @@ def mfcc(x,
f"MFCC lifter={lifter} must be a non-negative number")
def melspectrogram(x: array,
def melspectrogram(x: np.ndarray,
sr: int=16000,
window_size: int=512,
hop_length: int=320,
n_mels: int=64,
fmin: int=50,
fmin: float=50.0,
fmax: Optional[float]=None,
window: str='hann',
center: bool=True,
@ -468,27 +497,28 @@ def melspectrogram(x: array,
to_db: bool=True,
ref: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None) -> array:
top_db: Optional[float]=None) -> np.ndarray:
"""Compute mel-spectrogram.
Parameters:
x: numpy.ndarray
The input wavform is a numpy array [shape=(n,)]
window_size: int, typically 512, 1024, 2048, etc.
The window size for framing, also used as n_fft for stft
Args:
x (np.ndarray): Input waveform in one dimension.
sr (int, optional): Sample rate. Defaults to 16000.
window_size (int, optional): Size of FFT and window length. Defaults to 512.
hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
n_mels (int, optional): Number of mel bins. Defaults to 64.
fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0.
fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
to_db (bool, optional): Enable db scale. Defaults to True.
ref (float, optional): Scaling factor of spectrogram. Defaults to 1.0.
amin (float, optional): Minimum threshold. Defaults to 1e-10.
top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None.
Returns:
The mel-spectrogram in power scale or db scale(default)
Notes:
1. sr is default to 16000, which is commonly used in speech/speaker processing.
2. when fmax is None, it is set to sr//2.
3. this function will convert mel spectgrum to db scale by default. This is different
that of librosa.
np.ndarray: The mel-spectrogram in power scale or db scale with shape `(n_mels, num_frames)`.
"""
_check_audio(x, mono=True)
if len(x) <= 0:
@ -518,18 +548,28 @@ def melspectrogram(x: array,
return mel_spect
def spectrogram(x: array,
def spectrogram(x: np.ndarray,
sr: int=16000,
window_size: int=512,
hop_length: int=320,
window: str='hann',
center: bool=True,
pad_mode: str='reflect',
power: float=2.0) -> array:
"""Compute spectrogram from an input waveform.
power: float=2.0) -> np.ndarray:
"""Compute spectrogram.
Args:
x (np.ndarray): Input waveform in one dimension.
sr (int, optional): Sample rate. Defaults to 16000.
window_size (int, optional): Size of FFT and window length. Defaults to 512.
hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
This function is a wrapper for librosa.feature.stft, with addition step to
compute the magnitude of the complex spectrogram.
Returns:
np.ndarray: The STFT spectrogram in power scale `(n_fft//2 + 1, num_frames)`.
"""
s = stft(
@ -544,18 +584,16 @@ def spectrogram(x: array,
return np.abs(s)**power
def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array:
"""Mu-law encoding.
Compute the mu-law decoding given an input code.
When quantized is True, the result will be converted to
integer in range [0,mu-1]. Otherwise, the resulting signal
is in range [-1,1]
def mu_encode(x: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
"""Mu-law encoding. Encode waveform based on mu-law companding. When quantized is True, the result will be converted to integer in range `[0,mu-1]`. Otherwise, the resulting waveform is in range `[-1,1]`.
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
Args:
x (np.ndarray): The input waveform to encode.
mu (int, optional): The endoceding parameter. Defaults to 255.
quantized (bool, optional): If `True`, quantize the encoded values into `1 + mu` distinct integer values. Defaults to True.
Returns:
np.ndarray: The mu-law encoded waveform.
"""
mu = 255
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
@ -564,17 +602,16 @@ def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array:
return y
def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array:
"""Mu-law decoding.
Compute the mu-law decoding given an input code.
def mu_decode(y: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
"""Mu-law decoding. Compute the mu-law decoding given an input code. It assumes that the input `y` is in range `[0,mu-1]` when quantize is True and `[-1,1]` otherwise.
it assumes that the input y is in
range [0,mu-1] when quantize is True and [-1,1] otherwise
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
Args:
y (np.ndarray): The encoded waveform.
mu (int, optional): The endoceding parameter. Defaults to 255.
quantized (bool, optional): If `True`, the input is assumed to be quantized to `1 + mu` distinct integer values. Defaults to True.
Returns:
np.ndarray: The mu-law decoded waveform.
"""
if mu < 1:
raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...')
@ -586,7 +623,7 @@ def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array:
return x
def randint(high: int) -> int:
def _randint(high: int) -> int:
"""Generate one random integer in range [0 high)
This is a helper function for random data augmentaiton
@ -594,20 +631,18 @@ def randint(high: int) -> int:
return int(np.random.randint(0, high=high))
def rand() -> float:
"""Generate one floating-point number in range [0 1)
This is a helper function for random data augmentaiton
"""
return float(np.random.rand(1))
def depth_augment(y: array,
def depth_augment(y: np.ndarray,
choices: List=['int8', 'int16'],
probs: List[float]=[0.5, 0.5]) -> array:
""" Audio depth augmentation
probs: List[float]=[0.5, 0.5]) -> np.ndarray:
""" Audio depth augmentation. Do audio depth augmentation to simulate the distortion brought by quantization.
Args:
y (np.ndarray): Input waveform array in 1D or 2D.
choices (List, optional): A list of data type to depth conversion. Defaults to ['int8', 'int16'].
probs (List[float], optional): Probabilities to depth conversion. Defaults to [0.5, 0.5].
Do audio depth augmentation to simulate the distortion brought by quantization.
Returns:
np.ndarray: The augmented waveform.
"""
assert len(probs) == len(
choices
@ -621,13 +656,18 @@ def depth_augment(y: array,
return y2
def adaptive_spect_augment(spect: array, tempo_axis: int=0,
level: float=0.1) -> array:
"""Do adpative spectrogram augmentation
def adaptive_spect_augment(spect: np.ndarray,
tempo_axis: int=0,
level: float=0.1) -> np.ndarray:
"""Do adpative spectrogram augmentation. The level of the augmentation is gowern by the paramter level, ranging from 0 to 1, with 0 represents no augmentation.
The level of the augmentation is gowern by the paramter level,
ranging from 0 to 1, with 0 represents no augmentation
Args:
spect (np.ndarray): Input spectrogram.
tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
level (float, optional): The level factor of masking. Defaults to 0.1.
Returns:
np.ndarray: The augmented spectrogram.
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
@ -643,32 +683,40 @@ def adaptive_spect_augment(spect: array, tempo_axis: int=0,
if tempo_axis == 0:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
start = _randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
start = _randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
start = _randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
start = _randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def spect_augment(spect: array,
def spect_augment(spect: np.ndarray,
tempo_axis: int=0,
max_time_mask: int=3,
max_freq_mask: int=3,
max_time_mask_width: int=30,
max_freq_mask_width: int=20) -> array:
"""Do spectrogram augmentation in both time and freq axis
max_freq_mask_width: int=20) -> np.ndarray:
"""Do spectrogram augmentation in both time and freq axis.
Reference:
Args:
spect (np.ndarray): Input spectrogram.
tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
max_time_mask (int, optional): Maximum number of time masking. Defaults to 3.
max_freq_mask (int, optional): Maximum number of frenquence masking. Defaults to 3.
max_time_mask_width (int, optional): Maximum width of time masking. Defaults to 30.
max_freq_mask_width (int, optional): Maximum width of frenquence masking. Defaults to 20.
Returns:
np.ndarray: The augmented spectrogram.
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
@ -676,52 +724,64 @@ def spect_augment(spect: array,
else:
nf, nt = spect.shape
num_time_mask = randint(max_time_mask)
num_freq_mask = randint(max_freq_mask)
num_time_mask = _randint(max_time_mask)
num_freq_mask = _randint(max_freq_mask)
time_mask_width = randint(max_time_mask_width)
freq_mask_width = randint(max_freq_mask_width)
time_mask_width = _randint(max_time_mask_width)
freq_mask_width = _randint(max_freq_mask_width)
if tempo_axis == 0:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
start = _randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
start = _randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
start = _randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
start = _randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def random_crop1d(y: array, crop_len: int) -> array:
""" Do random cropping on 1d input signal
def random_crop1d(y: np.ndarray, crop_len: int) -> np.ndarray:
""" Random cropping on a input waveform.
The input is a 1d signal, typically a sound waveform
Args:
y (np.ndarray): Input waveform array in 1D.
crop_len (int): Length of waveform to crop.
Returns:
np.ndarray: The cropped waveform.
"""
if y.ndim != 1:
'only accept 1d tensor or numpy array'
n = len(y)
idx = randint(n - crop_len)
idx = _randint(n - crop_len)
return y[idx:idx + crop_len]
def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array:
""" Do random cropping for 2D array, typically a spectrogram.
def random_crop2d(s: np.ndarray, crop_len: int,
tempo_axis: int=0) -> np.ndarray:
""" Random cropping on a spectrogram.
The cropping is done in temporal direction on the time-freq input signal.
Args:
s (np.ndarray): Input spectrogram in 2D.
crop_len (int): Length of spectrogram to crop.
tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
Returns:
np.ndarray: The cropped spectrogram.
"""
if tempo_axis >= s.ndim:
raise ParameterError('axis out of range')
n = s.shape[tempo_axis]
idx = randint(high=n - crop_len)
idx = _randint(high=n - crop_len)
sli = [slice(None) for i in range(s.ndim)]
sli[tempo_axis] = slice(idx, idx + crop_len)
out = s[tuple(sli)]

@ -44,29 +44,16 @@ class Spectrogram(nn.Layer):
"""Compute spectrogram of a given signal, typically an audio waveform.
The spectorgram is defined as the complex norm of the short-time
Fourier transformation.
Parameters:
n_fft (int): the number of frequency components of the discrete Fourier transform.
The default value is 2048,
hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
The default value is None.
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
The default value is None.
window (str): the name of the window function applied to the single before the Fourier transform.
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
The default value is 'hann'
power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
If False, frame t begins at x[t * hop_length]
The default value is True
pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
and 'constant'. The default value is 'reflect'.
dtype (str): the data type of input and window.
Notes:
The Spectrogram transform relies on STFT transform to compute the spectrogram.
By default, the weights are not learnable. To fine-tune the Fourier coefficients,
set stop_gradient=False before training.
For more information, see STFT().
Args:
n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512.
hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None.
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the single before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
dtype (str, optional): Data type of input and window. Defaults to paddle.float32.
"""
super(Spectrogram, self).__init__()

Loading…
Cancel
Save