From c52f0f805bc92800b61b9594d873778f79304a9a Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Wed, 2 Mar 2022 12:09:56 +0800 Subject: [PATCH] refactor --- paddleaudio/paddleaudio/__init__.py | 2 + paddleaudio/paddleaudio/backends/__init__.py | 6 + .../paddleaudio/backends/soundfile_backend.py | 252 ++++++ .../{kaldi => compliance}/__init__.py | 0 paddleaudio/paddleaudio/compliance/kaldi.py | 688 ++++++++++++++++ paddleaudio/paddleaudio/compliance/librosa.py | 728 ++++++++++++++++ .../features/{librosa.py => layers.py} | 241 +----- .../paddleaudio/functional/__init__.py | 7 + .../paddleaudio/functional/functional.py | 776 ++++-------------- paddleaudio/paddleaudio/io/__init__.py | 8 +- paddleaudio/paddleaudio/io/audio.py | 303 ------- 11 files changed, 1870 insertions(+), 1141 deletions(-) rename paddleaudio/paddleaudio/{kaldi => compliance}/__init__.py (100%) create mode 100644 paddleaudio/paddleaudio/compliance/kaldi.py create mode 100644 paddleaudio/paddleaudio/compliance/librosa.py rename paddleaudio/paddleaudio/features/{librosa.py => layers.py} (59%) delete mode 100644 paddleaudio/paddleaudio/io/audio.py diff --git a/paddleaudio/paddleaudio/__init__.py b/paddleaudio/paddleaudio/__init__.py index 185a92b8..2dab610c 100644 --- a/paddleaudio/paddleaudio/__init__.py +++ b/paddleaudio/paddleaudio/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .backends import load +from .backends import save diff --git a/paddleaudio/paddleaudio/backends/__init__.py b/paddleaudio/paddleaudio/backends/__init__.py index 185a92b8..8eae07e8 100644 --- a/paddleaudio/paddleaudio/backends/__init__.py +++ b/paddleaudio/paddleaudio/backends/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .soundfile_backend import depth_convert +from .soundfile_backend import load +from .soundfile_backend import normalize +from .soundfile_backend import resample +from .soundfile_backend import save +from .soundfile_backend import to_mono diff --git a/paddleaudio/paddleaudio/backends/soundfile_backend.py b/paddleaudio/paddleaudio/backends/soundfile_backend.py index 97043fd7..2b920284 100644 --- a/paddleaudio/paddleaudio/backends/soundfile_backend.py +++ b/paddleaudio/paddleaudio/backends/soundfile_backend.py @@ -11,3 +11,255 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import resampy +import soundfile as sf +from numpy import ndarray as array +from scipy.io import wavfile + +from ..utils import ParameterError + +__all__ = [ + 'resample', + 'to_mono', + 'depth_convert', + 'normalize', + 'save', + 'load', +] +NORMALMIZE_TYPES = ['linear', 'gaussian'] +MERGE_TYPES = ['ch0', 'ch1', 'random', 'average'] +RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast'] +EPS = 1e-8 + + +def resample(y: array, src_sr: int, target_sr: int, + mode: str='kaiser_fast') -> array: + """ Audio resampling + This function is the same as using resampy.resample(). + Notes: + The default mode is kaiser_fast. 
For better audio quality, use mode = 'kaiser_fast' + """ + + if mode == 'kaiser_best': + warnings.warn( + f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \ + we recommend the mode kaiser_fast in large scale audio trainning') + + if not isinstance(y, np.ndarray): + raise ParameterError( + 'Only support numpy array, but received y in {type(y)}') + + if mode not in RESAMPLE_MODES: + raise ParameterError(f'resample mode must in {RESAMPLE_MODES}') + + return resampy.resample(y, src_sr, target_sr, filter=mode) + + +def to_mono(y: array, merge_type: str='average') -> array: + """ convert sterior audio to mono + """ + if merge_type not in MERGE_TYPES: + raise ParameterError( + f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}' + ) + if y.ndim > 2: + raise ParameterError( + f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}') + if y.ndim == 1: # nothing to merge + return y + + if merge_type == 'ch0': + return y[0] + if merge_type == 'ch1': + return y[1] + if merge_type == 'random': + return y[np.random.randint(0, 2)] + + # need to do averaging according to dtype + + if y.dtype == 'float32': + y_out = (y[0] + y[1]) * 0.5 + elif y.dtype == 'int16': + y_out = y.astype('int32') + y_out = (y_out[0] + y_out[1]) // 2 + y_out = np.clip(y_out, np.iinfo(y.dtype).min, + np.iinfo(y.dtype).max).astype(y.dtype) + + elif y.dtype == 'int8': + y_out = y.astype('int16') + y_out = (y_out[0] + y_out[1]) // 2 + y_out = np.clip(y_out, np.iinfo(y.dtype).min, + np.iinfo(y.dtype).max).astype(y.dtype) + else: + raise ParameterError(f'Unsupported dtype: {y.dtype}') + return y_out + + +def _safe_cast(y: array, dtype: Union[type, str]) -> array: + """ data type casting in a safe way, i.e., prevent overflow or underflow + This function is used internally. + """ + return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype) + + +def depth_convert(y: array, dtype: Union[type, str], + dithering: bool=True) -> array: + """Convert audio array to target dtype safely + This function convert audio waveform to a target dtype, with addition steps of + preventing overflow/underflow and preserving audio range. 
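+    Notes:
+        float -> int: samples are scaled by the target integer maximum and clipped.
+        int -> float: samples are divided by the source integer maximum.
+        int16 <-> int8: samples are rescaled by the ratio of the two integer maxima.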
+ """ + + SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64'] + if y.dtype not in SUPPORT_DTYPE: + raise ParameterError( + 'Unsupported audio dtype, ' + f'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}') + + if dtype not in SUPPORT_DTYPE: + raise ParameterError( + 'Unsupported audio dtype, ' + f'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}') + + if dtype == y.dtype: + return y + + if dtype == 'float64' and y.dtype == 'float32': + return _safe_cast(y, dtype) + if dtype == 'float32' and y.dtype == 'float64': + return _safe_cast(y, dtype) + + if dtype == 'int16' or dtype == 'int8': + if y.dtype in ['float64', 'float32']: + factor = np.iinfo(dtype).max + y = np.clip(y * factor, np.iinfo(dtype).min, + np.iinfo(dtype).max).astype(dtype) + y = y.astype(dtype) + else: + if dtype == 'int16' and y.dtype == 'int8': + factor = np.iinfo('int16').max / np.iinfo('int8').max - EPS + y = y.astype('float32') * factor + y = y.astype('int16') + + else: # dtype == 'int8' and y.dtype=='int16': + y = y.astype('int32') * np.iinfo('int8').max / \ + np.iinfo('int16').max + y = y.astype('int8') + + if dtype in ['float32', 'float64']: + org_dtype = y.dtype + y = y.astype(dtype) / np.iinfo(org_dtype).max + return y + + +def sound_file_load(file: str, + offset: Optional[float]=None, + dtype: str='int16', + duration: Optional[int]=None) -> Tuple[array, int]: + """Load audio using soundfile library + This function load audio file using libsndfile. + Reference: + http://www.mega-nerd.com/libsndfile/#Features + """ + with sf.SoundFile(file) as sf_desc: + sr_native = sf_desc.samplerate + if offset: + sf_desc.seek(int(offset * sr_native)) + if duration is not None: + frame_duration = int(duration * sr_native) + else: + frame_duration = -1 + y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T + + return y, sf_desc.samplerate + + +def normalize(y: array, norm_type: str='linear', + mul_factor: float=1.0) -> array: + """ normalize an input audio with additional multiplier. + """ + + if norm_type == 'linear': + amax = np.max(np.abs(y)) + factor = 1.0 / (amax + EPS) + y = y * factor * mul_factor + elif norm_type == 'gaussian': + amean = np.mean(y) + astd = np.std(y) + astd = max(astd, EPS) + y = mul_factor * (y - amean) / astd + else: + raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}') + + return y + + +def save(y: array, sr: int, file: str) -> None: + """Save audio file to disk. + This function saves audio to disk using scipy.io.wavfile, with additional step + to convert input waveform to int16 unless it already is int16 + Notes: + It only support raw wav format. + """ + if not file.endswith('.wav'): + raise ParameterError( + f'only .wav file supported, but dst file name is: {file}') + + if sr <= 0: + raise ParameterError( + f'Sample rate should be larger than 0, recieved sr = {sr}') + + if y.dtype not in ['int16', 'int8']: + warnings.warn( + f'input data type is {y.dtype}, will convert data to int16 format before saving' + ) + y_out = depth_convert(y, 'int16') + else: + y_out = y + + wavfile.write(file, sr, y_out) + + +def load( + file: str, + sr: Optional[int]=None, + mono: bool=True, + merge_type: str='average', # ch0,ch1,random,average + normal: bool=True, + norm_type: str='linear', + norm_mul_factor: float=1.0, + offset: float=0.0, + duration: Optional[int]=None, + dtype: str='float32', + resample_mode: str='kaiser_fast') -> Tuple[array, int]: + """Load audio file from disk. + This function loads audio from disk using using audio beackend. 
+ Parameters: + Notes: + """ + + y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration) + + if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)): + raise ParameterError(f'audio file {file} looks empty') + + if mono: + y = to_mono(y, merge_type) + + if sr is not None and sr != r: + y = resample(y, r, sr, mode=resample_mode) + r = sr + + if normal: + y = normalize(y, norm_type, norm_mul_factor) + elif dtype in ['int8', 'int16']: + # still need to do normalization, before depth convertion + y = normalize(y, 'linear', 1.0) + + y = depth_convert(y, dtype) + return y, r diff --git a/paddleaudio/paddleaudio/kaldi/__init__.py b/paddleaudio/paddleaudio/compliance/__init__.py similarity index 100% rename from paddleaudio/paddleaudio/kaldi/__init__.py rename to paddleaudio/paddleaudio/compliance/__init__.py diff --git a/paddleaudio/paddleaudio/compliance/kaldi.py b/paddleaudio/paddleaudio/compliance/kaldi.py new file mode 100644 index 00000000..61ca4e3d --- /dev/null +++ b/paddleaudio/paddleaudio/compliance/kaldi.py @@ -0,0 +1,688 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Tuple + +import paddle +from paddle import Tensor + +from ..functional.window import get_window +from .spectrum import create_dct + +__all__ = [ + 'spectrogram', + 'fbank', + 'mfcc', +] + +# window types +HANNING = 'hann' +HAMMING = 'hamming' +POVEY = 'povey' +RECTANGULAR = 'rect' +BLACKMAN = 'blackman' + + +def _get_epsilon(dtype): + return paddle.to_tensor(1e-07, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + return 1 if x == 0 else 2**(x - 1).bit_length() + + +def _get_strided(waveform: Tensor, + window_size: int, + window_shift: int, + snip_edges: bool) -> Tensor: + assert waveform.dim() == 1 + num_samples = waveform.shape[0] + + if snip_edges: + if num_samples < window_size: + return paddle.empty((0, 0), dtype=waveform.dtype) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = paddle.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + pad_left = reversed_waveform[-pad:] + waveform = paddle.concat((pad_left, waveform, pad_right), axis=0) + else: + waveform = paddle.concat((waveform[-pad:], pad_right), axis=0) + + return paddle.signal.frame(waveform, window_size, window_shift)[:, :m].T + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + dtype: int, ) -> Tensor: + if window_type == HANNING: + return get_window('hann', window_size, fftbins=False, dtype=dtype) + elif window_type == HAMMING: + return get_window('hamming', window_size, fftbins=False, dtype=dtype) + elif window_type == POVEY: + return get_window( + 'hann', window_size, fftbins=False, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return paddle.ones([window_size], dtype=dtype) + elif window_type == 
BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = paddle.arange(window_size, dtype=dtype) + return (blackman_coeff - 0.5 * paddle.cos(a * window_function) + + (0.5 - blackman_coeff) * paddle.cos(2 * a * window_function) + ).astype(dtype) + else: + raise Exception('Invalid window type ' + window_type) + + +def _get_log_energy(strided_input: Tensor, epsilon: Tensor, + energy_floor: float) -> Tensor: + log_energy = paddle.maximum(strided_input.pow(2).sum(1), epsilon).log() + if energy_floor == 0.0: + return log_energy + return paddle.maximum( + log_energy, + paddle.to_tensor(math.log(energy_floor), dtype=strided_input.dtype)) + + +def _get_waveform_and_window_properties( + waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]: + channel = max(channel, 0) + assert channel < waveform.shape[0], ( + 'Invalid channel {} for size {}'.format(channel, waveform.shape[0])) + waveform = waveform[channel, :] # size (n) + window_shift = int( + sample_frequency * frame_shift * + 0.001) # pass frame_shift and frame_length in milliseconds + window_size = int(sample_frequency * frame_length * 0.001) + padded_window_size = _next_power_of_2( + window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len(waveform), ( + 'choose a window size {} that is [2, {}]'.format(window_size, + len(waveform))) + assert 0 < window_shift, '`window_shift` must be greater than 0' + assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \ + ' use `round_to_power_of_two` or change `frame_length`' + assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]' + assert sample_frequency > 0, '`sample_frequency` must be greater than zero' + return waveform, window_shift, window_size, padded_window_size + + +def _get_window(waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]: + dtype = waveform.dtype + epsilon = _get_epsilon(dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, + snip_edges) + + if dither != 0.0: + # Returns a random number strictly between 0 and 1 + x = paddle.maximum(epsilon, + paddle.rand(strided_input.shape, dtype=dtype)) + rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = paddle.mean( + strided_input, axis=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, + energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = paddle.nn.functional.pad( + strided_input.unsqueeze(0), (1, 0), + data_format='NCL', + mode='replicate').squeeze(0) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, : + -1] + + # Apply window_function to each row/frame + 
window_function = _feature_window_function( + window_type, window_size, blackman_coeff, + dtype).unsqueeze(0) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = paddle.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), + data_format='NCL', + mode='constant', + value=0).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, + energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = paddle.mean(tensor, axis=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram(waveform: Tensor, + blackman_coeff: float=0.42, + channel: int=-1, + dither: float=0.0, + energy_floor: float=1.0, + frame_length: float=25.0, + frame_shift: float=10.0, + min_duration: float=0.0, + preemphasis_coefficient: float=0.97, + raw_energy: bool=True, + remove_dc_offset: bool=True, + round_to_power_of_two: bool=True, + sample_frequency: float=16000.0, + snip_edges: bool=True, + subtract_mean: bool=False, + window_type: str=POVEY) -> Tensor: + """[summary] + + Args: + waveform (Tensor): [description] + blackman_coeff (float, optional): [description]. Defaults to 0.42. + channel (int, optional): [description]. Defaults to -1. + dither (float, optional): [description]. Defaults to 0.0. + energy_floor (float, optional): [description]. Defaults to 1.0. + frame_length (float, optional): [description]. Defaults to 25.0. + frame_shift (float, optional): [description]. Defaults to 10.0. + min_duration (float, optional): [description]. Defaults to 0.0. + preemphasis_coefficient (float, optional): [description]. Defaults to 0.97. + raw_energy (bool, optional): [description]. Defaults to True. + remove_dc_offset (bool, optional): [description]. Defaults to True. + round_to_power_of_two (bool, optional): [description]. Defaults to True. + sample_frequency (float, optional): [description]. Defaults to 16000.0. + snip_edges (bool, optional): [description]. Defaults to True. + subtract_mean (bool, optional): [description]. Defaults to False. + window_type (str, optional): [description]. Defaults to POVEY. 
+ + Returns: + Tensor: [description] + """ + dtype = waveform.dtype + epsilon = _get_epsilon(dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, + round_to_power_of_two, preemphasis_coefficient) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return paddle.empty([0]) + + strided_input, signal_log_energy = _get_window( + waveform, padded_window_size, window_size, window_shift, window_type, + blackman_coeff, snip_edges, raw_energy, energy_floor, dither, + remove_dc_offset, preemphasis_coefficient) + + # size (m, padded_window_size // 2 + 1, 2) + fft = paddle.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = paddle.maximum( + fft.abs().pow(2.), + epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def _inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def _inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def _mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def _mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def _vtln_warp_freq(vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor) -> Tensor: + assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq' + assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]' + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = paddle.empty_like(freq) + + outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \ + | paddle.greater_than(freq, paddle.to_tensor(high_freq)) # freq < low_freq || freq > high_freq + before_l = paddle.less_than(freq, paddle.to_tensor(l)) # freq < l + before_h = paddle.less_than(freq, paddle.to_tensor(h)) # freq < h + after_h = paddle.greater_equal(freq, paddle.to_tensor(h)) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def _vtln_warp_mel_freq(vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor) -> Tensor: + return _mel_scale( + _vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, + vtln_warp_factor, _inverse_mel_scale(mel_freq))) + + +def _get_mel_banks(num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + 
vtln_warp_factor: float) -> Tuple[Tensor, Tensor]: + assert num_bins > 3, 'Must have at least 3 mel bins' + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \ + ('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist)) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = _mel_scale_scalar(low_freq) + mel_high_freq = _mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and + (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \ + ('Bad values in options: vtln-low {} and vtln-high {}, versus ' + 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) + + bin = paddle.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0 + ) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, + vtln_warp_factor, left_mel) + center_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, + high_freq, vtln_warp_factor, + center_mel) + right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, + high_freq, vtln_warp_factor, right_mel) + + center_freqs = _inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = paddle.maximum( + paddle.zeros([1]), paddle.minimum(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = paddle.zeros_like(up_slope) + up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than( + mel, center_mel) # left_mel < mel <= center_mel + down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than( + mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins, center_freqs + + +def fbank(waveform: Tensor, + blackman_coeff: float=0.42, + channel: int=-1, + dither: float=0.0, + energy_floor: float=1.0, + frame_length: float=25.0, + frame_shift: float=10.0, + high_freq: float=0.0, + htk_compat: bool=False, + low_freq: float=20.0, + min_duration: float=0.0, + num_mel_bins: int=23, + preemphasis_coefficient: float=0.97, + raw_energy: bool=True, + remove_dc_offset: bool=True, + round_to_power_of_two: bool=True, + sample_frequency: float=16000.0, + snip_edges: bool=True, + subtract_mean: bool=False, + use_energy: bool=False, + use_log_fbank: bool=True, + use_power: bool=True, + vtln_high: float=-500.0, + vtln_low: float=100.0, + vtln_warp: float=1.0, + window_type: str=POVEY) -> Tensor: + 
"""[summary] + + Args: + waveform (Tensor): [description] + blackman_coeff (float, optional): [description]. Defaults to 0.42. + channel (int, optional): [description]. Defaults to -1. + dither (float, optional): [description]. Defaults to 0.0. + energy_floor (float, optional): [description]. Defaults to 1.0. + frame_length (float, optional): [description]. Defaults to 25.0. + frame_shift (float, optional): [description]. Defaults to 10.0. + high_freq (float, optional): [description]. Defaults to 0.0. + htk_compat (bool, optional): [description]. Defaults to False. + low_freq (float, optional): [description]. Defaults to 20.0. + min_duration (float, optional): [description]. Defaults to 0.0. + num_mel_bins (int, optional): [description]. Defaults to 23. + preemphasis_coefficient (float, optional): [description]. Defaults to 0.97. + raw_energy (bool, optional): [description]. Defaults to True. + remove_dc_offset (bool, optional): [description]. Defaults to True. + round_to_power_of_two (bool, optional): [description]. Defaults to True. + sample_frequency (float, optional): [description]. Defaults to 16000.0. + snip_edges (bool, optional): [description]. Defaults to True. + subtract_mean (bool, optional): [description]. Defaults to False. + use_energy (bool, optional): [description]. Defaults to False. + use_log_fbank (bool, optional): [description]. Defaults to True. + use_power (bool, optional): [description]. Defaults to True. + vtln_high (float, optional): [description]. Defaults to -500.0. + vtln_low (float, optional): [description]. Defaults to 100.0. + vtln_warp (float, optional): [description]. Defaults to 1.0. + window_type (str, optional): [description]. Defaults to POVEY. + + Returns: + Tensor: [description] + """ + dtype = waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, + round_to_power_of_two, preemphasis_coefficient) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return paddle.empty([0], dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, padded_window_size, window_size, window_shift, window_type, + blackman_coeff, snip_edges, raw_energy, energy_floor, dither, + remove_dc_offset, preemphasis_coefficient) + + # size (m, padded_window_size // 2 + 1) + spectrum = paddle.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.) 
+ + # size (num_mel_bins, padded_window_size // 2) + mel_energies, _ = _get_mel_banks(num_mel_bins, padded_window_size, + sample_frequency, low_freq, high_freq, + vtln_low, vtln_high, vtln_warp) + mel_energies = mel_energies.astype(dtype) + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = paddle.nn.functional.pad( + mel_energies.unsqueeze(0), (0, 1), + data_format='NCL', + mode='constant', + value=0).squeeze(0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = paddle.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = paddle.concat( + (mel_energies, signal_log_energy), axis=1) + else: + mel_energies = paddle.concat( + (signal_log_energy, mel_energies), axis=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = create_dct(num_mel_bins, num_mel_bins, 'ortho') + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. + i = paddle.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i / + cepstral_lifter) + + +def mfcc(waveform: Tensor, + blackman_coeff: float=0.42, + cepstral_lifter: float=22.0, + channel: int=-1, + dither: float=0.0, + energy_floor: float=1.0, + frame_length: float=25.0, + frame_shift: float=10.0, + high_freq: float=0.0, + htk_compat: bool=False, + low_freq: float=20.0, + num_ceps: int=13, + min_duration: float=0.0, + num_mel_bins: int=23, + preemphasis_coefficient: float=0.97, + raw_energy: bool=True, + remove_dc_offset: bool=True, + round_to_power_of_two: bool=True, + sample_frequency: float=16000.0, + snip_edges: bool=True, + subtract_mean: bool=False, + use_energy: bool=False, + vtln_high: float=-500.0, + vtln_low: float=100.0, + vtln_warp: float=1.0, + window_type: str=POVEY) -> Tensor: + """[summary] + + Args: + waveform (Tensor): [description] + blackman_coeff (float, optional): [description]. Defaults to 0.42. + cepstral_lifter (float, optional): [description]. Defaults to 22.0. + channel (int, optional): [description]. Defaults to -1. + dither (float, optional): [description]. Defaults to 0.0. + energy_floor (float, optional): [description]. Defaults to 1.0. + frame_length (float, optional): [description]. Defaults to 25.0. + frame_shift (float, optional): [description]. Defaults to 10.0. 
+ high_freq (float, optional): [description]. Defaults to 0.0. + htk_compat (bool, optional): [description]. Defaults to False. + low_freq (float, optional): [description]. Defaults to 20.0. + num_ceps (int, optional): [description]. Defaults to 13. + min_duration (float, optional): [description]. Defaults to 0.0. + num_mel_bins (int, optional): [description]. Defaults to 23. + preemphasis_coefficient (float, optional): [description]. Defaults to 0.97. + raw_energy (bool, optional): [description]. Defaults to True. + remove_dc_offset (bool, optional): [description]. Defaults to True. + round_to_power_of_two (bool, optional): [description]. Defaults to True. + sample_frequency (float, optional): [description]. Defaults to 16000.0. + snip_edges (bool, optional): [description]. Defaults to True. + subtract_mean (bool, optional): [description]. Defaults to False. + use_energy (bool, optional): [description]. Defaults to False. + vtln_high (float, optional): [description]. Defaults to -500.0. + vtln_low (float, optional): [description]. Defaults to 100.0. + vtln_warp (float, optional): [description]. Defaults to 1.0. + window_type (str, optional): [description]. Defaults to POVEY. + + Returns: + Tensor: [description] + """ + assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % ( + num_ceps, num_mel_bins) + + dtype = waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset:(num_mel_bins + mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).astype(dtype=dtype) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, + cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.astype(dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) 
+ energy *= math.sqrt(2) + + feature = paddle.concat((feature, energy), axis=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/paddleaudio/paddleaudio/compliance/librosa.py b/paddleaudio/paddleaudio/compliance/librosa.py new file mode 100644 index 00000000..167795c3 --- /dev/null +++ b/paddleaudio/paddleaudio/compliance/librosa.py @@ -0,0 +1,728 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from librosa(https://github.com/librosa/librosa) +import warnings +from typing import List +from typing import Optional +from typing import Union + +import numpy as np +import scipy +from numpy import ndarray as array +from numpy.lib.stride_tricks import as_strided +from scipy import signal + +from ..backends import depth_convert +from ..utils import ParameterError + +__all__ = [ + # dsp + 'stft', + 'mfcc', + 'hz_to_mel', + 'mel_to_hz', + 'split_frames', + 'mel_frequencies', + 'power_to_db', + 'compute_fbank_matrix', + 'melspectrogram', + 'spectrogram', + 'mu_encode', + 'mu_decode', + # augmentation + 'depth_augment', + 'spect_augment', + 'random_crop1d', + 'random_crop2d', + 'adaptive_spect_augment', +] + + +def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array: + """Pad an array to a target length along a target axis. + + This differs from `np.pad` by centering the data prior to padding, + analogous to `str.center` + """ + + kwargs.setdefault("mode", "constant") + n = data.shape[axis] + lpad = int((size - n) // 2) + lengths = [(0, 0)] * data.ndim + lengths[axis] = (lpad, int(size - n - lpad)) + + if lpad < 0: + raise ParameterError(("Target size ({size:d}) must be " + "at least input size ({n:d})")) + + return np.pad(data, lengths, **kwargs) + + +def split_frames(x: array, frame_length: int, hop_length: int, + axis: int=-1) -> array: + """Slice a data array into (overlapping) frames. + + This function is aligned with librosa.frame + """ + + if not isinstance(x, np.ndarray): + raise ParameterError( + f"Input must be of type numpy.ndarray, given type(x)={type(x)}") + + if x.shape[axis] < frame_length: + raise ParameterError(f"Input is too short (n={x.shape[axis]:d})" + f" for frame_length={frame_length:d}") + + if hop_length < 1: + raise ParameterError(f"Invalid hop_length: {hop_length:d}") + + if axis == -1 and not x.flags["F_CONTIGUOUS"]: + warnings.warn(f"librosa.util.frame called with axis={axis} " + "on a non-contiguous input. This will result in a copy.") + x = np.asfortranarray(x) + elif axis == 0 and not x.flags["C_CONTIGUOUS"]: + warnings.warn(f"librosa.util.frame called with axis={axis} " + "on a non-contiguous input. 
This will result in a copy.") + x = np.ascontiguousarray(x) + + n_frames = 1 + (x.shape[axis] - frame_length) // hop_length + strides = np.asarray(x.strides) + + new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize + + if axis == -1: + shape = list(x.shape)[:-1] + [frame_length, n_frames] + strides = list(strides) + [hop_length * new_stride] + + elif axis == 0: + shape = [n_frames, frame_length] + list(x.shape)[1:] + strides = [hop_length * new_stride] + list(strides) + + else: + raise ParameterError(f"Frame axis={axis} must be either 0 or -1") + + return as_strided(x, shape=shape, strides=strides) + + +def _check_audio(y, mono=True) -> bool: + """Determine whether a variable contains valid audio data. + + The audio y must be a np.ndarray, ether 1-channel or two channel + """ + if not isinstance(y, np.ndarray): + raise ParameterError("Audio data must be of type numpy.ndarray") + if y.ndim > 2: + raise ParameterError( + f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}") + + if mono and y.ndim == 2: + raise ParameterError( + f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}") + + if (mono and len(y) == 0) or (not mono and y.shape[1] < 0): + raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}") + + if not np.issubdtype(y.dtype, np.floating): + raise ParameterError("Audio data must be floating-point") + + if not np.isfinite(y).all(): + raise ParameterError("Audio buffer is not finite everywhere") + + return True + + +def hz_to_mel(frequencies: Union[float, List[float], array], + htk: bool=False) -> array: + """Convert Hz to Mels + + This function is aligned with librosa. + """ + freq = np.asanyarray(frequencies) + + if htk: + return 2595.0 * np.log10(1.0 + freq / 700.0) + + # Fill in the linear part + f_min = 0.0 + f_sp = 200.0 / 3 + + mels = (freq - f_min) / f_sp + + # Fill in the log-scale part + + min_log_hz = 1000.0 # beginning of log region (Hz) + min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) + logstep = np.log(6.4) / 27.0 # step size for log region + + if freq.ndim: + # If we have array data, vectorize + log_t = freq >= min_log_hz + mels[log_t] = min_log_mel + \ + np.log(freq[log_t] / min_log_hz) / logstep + elif freq >= min_log_hz: + # If we have scalar data, heck directly + mels = min_log_mel + np.log(freq / min_log_hz) / logstep + + return mels + + +def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array: + """Convert mel bin numbers to frequencies. + + This function is aligned with librosa. + """ + mel_array = np.asanyarray(mels) + + if htk: + return 700.0 * (10.0**(mel_array / 2595.0) - 1.0) + + # Fill in the linear scale + f_min = 0.0 + f_sp = 200.0 / 3 + freqs = f_min + f_sp * mel_array + + # And now the nonlinear scale + min_log_hz = 1000.0 # beginning of log region (Hz) + min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) + logstep = np.log(6.4) / 27.0 # step size for log region + + if mel_array.ndim: + # If we have vector data, vectorize + log_t = mel_array >= min_log_mel + freqs[log_t] = min_log_hz * \ + np.exp(logstep * (mel_array[log_t] - min_log_mel)) + elif mel_array >= min_log_mel: + # If we have scalar data, check directly + freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel)) + + return freqs + + +def mel_frequencies(n_mels: int=128, + fmin: float=0.0, + fmax: float=11025.0, + htk: bool=False) -> array: + """Compute mel frequencies + + This function is aligned with librosa. 
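+    Example (illustrative):
+        mel_frequencies(n_mels=64, fmin=50.0, fmax=8000.0) returns 64 mel-band
+        center frequencies in Hz, monotonically increasing from fmin to fmax.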
+ """ + # 'Center freqs' of mel bands - uniformly spaced between limits + min_mel = hz_to_mel(fmin, htk=htk) + max_mel = hz_to_mel(fmax, htk=htk) + + mels = np.linspace(min_mel, max_mel, n_mels) + + return mel_to_hz(mels, htk=htk) + + +def fft_frequencies(sr: int, n_fft: int) -> array: + """Compute fourier frequencies. + + This function is aligned with librosa. + """ + return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True) + + +def compute_fbank_matrix(sr: int, + n_fft: int, + n_mels: int=128, + fmin: float=0.0, + fmax: Optional[float]=None, + htk: bool=False, + norm: str="slaney", + dtype: type=np.float32): + """Compute fbank matrix. + + This funciton is aligned with librosa. + """ + if norm != "slaney": + raise ParameterError('norm must set to slaney') + + if fmax is None: + fmax = float(sr) / 2 + + # Initialize the weights + n_mels = int(n_mels) + weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) + + # Center freqs of each FFT bin + fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) + + # 'Center freqs' of mel bands - uniformly spaced between limits + mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk) + + fdiff = np.diff(mel_f) + ramps = np.subtract.outer(mel_f, fftfreqs) + + for i in range(n_mels): + # lower and upper slopes for all bins + lower = -ramps[i] / fdiff[i] + upper = ramps[i + 2] / fdiff[i + 1] + + # .. then intersect them with each other and zero + weights[i] = np.maximum(0, np.minimum(lower, upper)) + + if norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels]) + weights *= enorm[:, np.newaxis] + + # Only check weights if f_mel[0] is positive + if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)): + # This means we have an empty channel somewhere + warnings.warn("Empty filters detected in mel frequency basis. " + "Some channels will produce empty responses. " + "Try increasing your sampling rate (and fmax) or " + "reducing n_mels.") + + return weights + + +def stft(x: array, + n_fft: int=2048, + hop_length: Optional[int]=None, + win_length: Optional[int]=None, + window: str="hann", + center: bool=True, + dtype: type=np.complex64, + pad_mode: str="reflect") -> array: + """Short-time Fourier transform (STFT). + + This function is aligned with librosa. + """ + _check_audio(x) + + # By default, use the entire frame + if win_length is None: + win_length = n_fft + + # Set the default hop, if it's not already specified + if hop_length is None: + hop_length = int(win_length // 4) + + fft_window = signal.get_window(window, win_length, fftbins=True) + + # Pad the window out to n_fft size + fft_window = pad_center(fft_window, n_fft) + + # Reshape so that the window can be broadcast + fft_window = fft_window.reshape((-1, 1)) + + # Pad the time series so that frames are centered + if center: + if n_fft > x.shape[-1]: + warnings.warn( + f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}" + ) + x = np.pad(x, int(n_fft // 2), mode=pad_mode) + + elif n_fft > x.shape[-1]: + raise ParameterError( + f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}" + ) + + # Window the time series. 
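+    # split_frames returns a strided view of shape (n_fft, n_frames): each column
+    # is one frame of n_fft samples, and consecutive frames start hop_length
+    # samples apart.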
+ x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length) + # Pre-allocate the STFT matrix + stft_matrix = np.empty( + (int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F") + fft = np.fft # use numpy fft as default + # Constrain STFT block sizes to 256 KB + MAX_MEM_BLOCK = 2**8 * 2**10 + # how many columns can we fit within MAX_MEM_BLOCK? + n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize) + n_columns = max(n_columns, 1) + + for bl_s in range(0, stft_matrix.shape[1], n_columns): + bl_t = min(bl_s + n_columns, stft_matrix.shape[1]) + stft_matrix[:, bl_s:bl_t] = fft.rfft( + fft_window * x_frames[:, bl_s:bl_t], axis=0) + + return stft_matrix + + +def power_to_db(spect: array, + ref: float=1.0, + amin: float=1e-10, + top_db: Optional[float]=80.0) -> array: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units + + This computes the scaling ``10 * log10(spect / ref)`` in a numerically + stable way. + + This function is aligned with librosa. + """ + spect = np.asarray(spect) + + if amin <= 0: + raise ParameterError("amin must be strictly positive") + + if np.issubdtype(spect.dtype, np.complexfloating): + warnings.warn( + "power_to_db was called on complex input so phase " + "information will be discarded. To suppress this warning, " + "call power_to_db(np.abs(D)**2) instead.") + magnitude = np.abs(spect) + else: + magnitude = spect + + if callable(ref): + # User supplied a function to calculate reference power + ref_value = ref(magnitude) + else: + ref_value = np.abs(ref) + + log_spec = 10.0 * np.log10(np.maximum(amin, magnitude)) + log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) + + if top_db is not None: + if top_db < 0: + raise ParameterError("top_db must be non-negative") + log_spec = np.maximum(log_spec, log_spec.max() - top_db) + + return log_spec + + +def mfcc(x, + sr: int=16000, + spect: Optional[array]=None, + n_mfcc: int=20, + dct_type: int=2, + norm: str="ortho", + lifter: int=0, + **kwargs) -> array: + """Mel-frequency cepstral coefficients (MFCCs) + + This function is NOT strictly aligned with librosa. The following example shows how to get the + same result with librosa: + + # mfcc: + kwargs = { + 'window_size':512, + 'hop_length':320, + 'mel_bins':64, + 'fmin':50, + 'to_db':False} + a = mfcc(x, + spect=None, + n_mfcc=20, + dct_type=2, + norm='ortho', + lifter=0, + **kwargs) + + # librosa mfcc: + spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512, + win_length=512, + hop_length=320, + n_mels=64, fmin=50) + b = librosa.feature.mfcc(y=x, + sr=16000, + S=spect, + n_mfcc=20, + dct_type=2, + norm='ortho', + lifter=0) + + assert np.mean( (a-b)**2) < 1e-8 + + """ + if spect is None: + spect = melspectrogram(x, sr=sr, **kwargs) + + M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc] + + if lifter > 0: + factor = np.sin(np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) / + lifter) + return M * factor[:, np.newaxis] + elif lifter == 0: + return M + else: + raise ParameterError( + f"MFCC lifter={lifter} must be a non-negative number") + + +def melspectrogram(x: array, + sr: int=16000, + window_size: int=512, + hop_length: int=320, + n_mels: int=64, + fmin: int=50, + fmax: Optional[float]=None, + window: str='hann', + center: bool=True, + pad_mode: str='reflect', + power: float=2.0, + to_db: bool=True, + ref: float=1.0, + amin: float=1e-10, + top_db: Optional[float]=None) -> array: + """Compute mel-spectrogram. 
+ + Parameters: + x: numpy.ndarray + The input wavform is a numpy array [shape=(n,)] + + window_size: int, typically 512, 1024, 2048, etc. + The window size for framing, also used as n_fft for stft + + + Returns: + The mel-spectrogram in power scale or db scale(default) + + + Notes: + 1. sr is default to 16000, which is commonly used in speech/speaker processing. + 2. when fmax is None, it is set to sr//2. + 3. this function will convert mel spectgrum to db scale by default. This is different + that of librosa. + + """ + _check_audio(x, mono=True) + if len(x) <= 0: + raise ParameterError('The input waveform is empty') + + if fmax is None: + fmax = sr // 2 + if fmin < 0 or fmin >= fmax: + raise ParameterError('fmin and fmax must statisfy 0 array: + """Compute spectrogram from an input waveform. + + This function is a wrapper for librosa.feature.stft, with addition step to + compute the magnitude of the complex spectrogram. + """ + + s = stft( + x, + n_fft=window_size, + hop_length=hop_length, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode) + + return np.abs(s)**power + + +def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array: + """Mu-law encoding. + + Compute the mu-law decoding given an input code. + When quantized is True, the result will be converted to + integer in range [0,mu-1]. Otherwise, the resulting signal + is in range [-1,1] + + + Reference: + https://en.wikipedia.org/wiki/%CE%9C-law_algorithm + + """ + mu = 255 + y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) + if quantized: + y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1] + return y + + +def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array: + """Mu-law decoding. + + Compute the mu-law decoding given an input code. + + it assumes that the input y is in + range [0,mu-1] when quantize is True and [-1,1] otherwise + + Reference: + https://en.wikipedia.org/wiki/%CE%9C-law_algorithm + + """ + if mu < 1: + raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...') + + mu = mu - 1 + if quantized: # undo the quantization + y = y * 2 / mu - 1 + x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1) + return x + + +def randint(high: int) -> int: + """Generate one random integer in range [0 high) + + This is a helper function for random data augmentaiton + """ + return int(np.random.randint(0, high=high)) + + +def rand() -> float: + """Generate one floating-point number in range [0 1) + + This is a helper function for random data augmentaiton + """ + return float(np.random.rand(1)) + + +def depth_augment(y: array, + choices: List=['int8', 'int16'], + probs: List[float]=[0.5, 0.5]) -> array: + """ Audio depth augmentation + + Do audio depth augmentation to simulate the distortion brought by quantization. 
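+    Example (illustrative):
+        depth_augment(y, choices=['int8', 'int16'], probs=[0.5, 0.5]) randomly
+        casts y to int8 or int16 and back to its original dtype.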
+ """ + assert len(probs) == len( + choices + ), 'number of choices {} must be equal to size of probs {}'.format( + len(choices), len(probs)) + depth = np.random.choice(choices, p=probs) + src_depth = y.dtype + y1 = depth_convert(y, depth) + y2 = depth_convert(y1, src_depth) + + return y2 + + +def adaptive_spect_augment(spect: array, tempo_axis: int=0, + level: float=0.1) -> array: + """Do adpative spectrogram augmentation + + The level of the augmentation is gowern by the paramter level, + ranging from 0 to 1, with 0 represents no augmentation。 + + """ + assert spect.ndim == 2., 'only supports 2d tensor or numpy array' + if tempo_axis == 0: + nt, nf = spect.shape + else: + nf, nt = spect.shape + + time_mask_width = int(nt * level * 0.5) + freq_mask_width = int(nf * level * 0.5) + + num_time_mask = int(10 * level) + num_freq_mask = int(10 * level) + + if tempo_axis == 0: + for _ in range(num_time_mask): + start = randint(nt - time_mask_width) + spect[start:start + time_mask_width, :] = 0 + for _ in range(num_freq_mask): + start = randint(nf - freq_mask_width) + spect[:, start:start + freq_mask_width] = 0 + else: + for _ in range(num_time_mask): + start = randint(nt - time_mask_width) + spect[:, start:start + time_mask_width] = 0 + for _ in range(num_freq_mask): + start = randint(nf - freq_mask_width) + spect[start:start + freq_mask_width, :] = 0 + + return spect + + +def spect_augment(spect: array, + tempo_axis: int=0, + max_time_mask: int=3, + max_freq_mask: int=3, + max_time_mask_width: int=30, + max_freq_mask_width: int=20) -> array: + """Do spectrogram augmentation in both time and freq axis + + Reference: + + """ + assert spect.ndim == 2., 'only supports 2d tensor or numpy array' + if tempo_axis == 0: + nt, nf = spect.shape + else: + nf, nt = spect.shape + + num_time_mask = randint(max_time_mask) + num_freq_mask = randint(max_freq_mask) + + time_mask_width = randint(max_time_mask_width) + freq_mask_width = randint(max_freq_mask_width) + + if tempo_axis == 0: + for _ in range(num_time_mask): + start = randint(nt - time_mask_width) + spect[start:start + time_mask_width, :] = 0 + for _ in range(num_freq_mask): + start = randint(nf - freq_mask_width) + spect[:, start:start + freq_mask_width] = 0 + else: + for _ in range(num_time_mask): + start = randint(nt - time_mask_width) + spect[:, start:start + time_mask_width] = 0 + for _ in range(num_freq_mask): + start = randint(nf - freq_mask_width) + spect[start:start + freq_mask_width, :] = 0 + + return spect + + +def random_crop1d(y: array, crop_len: int) -> array: + """ Do random cropping on 1d input signal + + The input is a 1d signal, typically a sound waveform + """ + if y.ndim != 1: + 'only accept 1d tensor or numpy array' + n = len(y) + idx = randint(n - crop_len) + return y[idx:idx + crop_len] + + +def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array: + """ Do random cropping for 2D array, typically a spectrogram. + + The cropping is done in temporal direction on the time-freq input signal. 
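+    Example (illustrative):
+        random_crop2d(spect, crop_len=100, tempo_axis=0) returns a slice of 100
+        consecutive frames taken at a random offset along axis 0.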
+ """ + if tempo_axis >= s.ndim: + raise ParameterError('axis out of range') + + n = s.shape[tempo_axis] + idx = randint(high=n - crop_len) + sli = [slice(None) for i in range(s.ndim)] + sli[tempo_axis] = slice(idx, idx + crop_len) + out = s[tuple(sli)] + return out diff --git a/paddleaudio/paddleaudio/features/librosa.py b/paddleaudio/paddleaudio/features/layers.py similarity index 59% rename from paddleaudio/paddleaudio/features/librosa.py rename to paddleaudio/paddleaudio/features/layers.py index 1cbd2d1a..69f814d6 100644 --- a/paddleaudio/paddleaudio/features/librosa.py +++ b/paddleaudio/paddleaudio/features/layers.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from functools import partial from typing import Optional from typing import Union @@ -19,225 +18,19 @@ from typing import Union import paddle import paddle.nn as nn +from ..functional import compute_fbank_matrix +from ..functional import create_dct +from ..functional import power_to_db from ..functional.window import get_window __all__ = [ 'Spectrogram', 'MelSpectrogram', 'LogMelSpectrogram', + 'MFCC', ] -def hz_to_mel(freq: Union[paddle.Tensor, float], - htk: bool=False) -> Union[paddle.Tensor, float]: - """Convert Hz to Mels. - Parameters: - freq: the input tensor of arbitrary shape, or a single floating point number. - htk: use HTK formula to do the conversion. - The default value is False. - Returns: - The frequencies represented in Mel-scale. - """ - - if htk: - if isinstance(freq, paddle.Tensor): - return 2595.0 * paddle.log10(1.0 + freq / 700.0) - else: - return 2595.0 * math.log10(1.0 + freq / 700.0) - - # Fill in the linear part - f_min = 0.0 - f_sp = 200.0 / 3 - - mels = (freq - f_min) / f_sp - - # Fill in the log-scale part - - min_log_hz = 1000.0 # beginning of log region (Hz) - min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = math.log(6.4) / 27.0 # step size for log region - - if isinstance(freq, paddle.Tensor): - target = min_log_mel + paddle.log( - freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10 - mask = (freq > min_log_hz).astype(freq.dtype) - mels = target * mask + mels * ( - 1 - mask) # will replace by masked_fill OP in future - else: - if freq >= min_log_hz: - mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep - - return mels - - -def mel_to_hz(mel: Union[float, paddle.Tensor], - htk: bool=False) -> Union[float, paddle.Tensor]: - """Convert mel bin numbers to frequencies. - Parameters: - mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number. - htk: use HTK formula to do the conversion. - Returns: - The frequencies represented in hz. 
- """ - if htk: - return 700.0 * (10.0**(mel / 2595.0) - 1.0) - - f_min = 0.0 - f_sp = 200.0 / 3 - freqs = f_min + f_sp * mel - # And now the nonlinear scale - min_log_hz = 1000.0 # beginning of log region (Hz) - min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(mel, paddle.Tensor): - target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel)) - mask = (mel > min_log_mel).astype(mel.dtype) - freqs = target * mask + freqs * ( - 1 - mask) # will replace by masked_fill OP in future - else: - if mel >= min_log_mel: - freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel)) - - return freqs - - -def mel_frequencies(n_mels: int=64, - f_min: float=0.0, - f_max: float=11025.0, - htk: bool=False, - dtype: str=paddle.float32): - """Compute mel frequencies. - Parameters: - n_mels(int): number of Mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zero. - htk(bool): whether to use htk formula. - dtype(str): the datatype of the return frequencies. - Returns: - The frequencies represented in Mel-scale - """ - # 'Center freqs' of mel bands - uniformly spaced between limits - min_mel = hz_to_mel(f_min, htk=htk) - max_mel = hz_to_mel(f_max, htk=htk) - mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype) - freqs = mel_to_hz(mels, htk=htk) - return freqs - - -def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32): - """Compute fourier frequencies. - Parameters: - sr(int): the audio sample rate. - n_fft(float): the number of fft bins. - dtype(str): the datatype of the return frequencies. - Returns: - The frequencies represented in hz. - """ - return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype) - - -def compute_fbank_matrix(sr: int, - n_fft: int, - n_mels: int=64, - f_min: float=0.0, - f_max: Optional[float]=None, - htk: bool=False, - norm: Union[str, float]='slaney', - dtype: str=paddle.float32): - """Compute fbank matrix. - Parameters: - sr(int): the audio sample rate. - n_fft(int): the number of fft bins. - n_mels(int): the number of Mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zero. - htk: whether to use htk formula. - return_complex(bool): whether to return complex matrix. If True, the matrix will - be complex type. Otherwise, the real and image part will be stored in the last - axis of returned tensor. - dtype(str): the datatype of the returned fbank matrix. - Returns: - The fbank matrix of shape (n_mels, int(1+n_fft//2)). - Shape: - output: (n_mels, int(1+n_fft//2)) - """ - - if f_max is None: - f_max = float(sr) / 2 - - # Initialize the weights - weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) - - # Center freqs of each FFT bin - fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype) - - # 'Center freqs' of mel bands - uniformly spaced between limits - mel_f = mel_frequencies( - n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype) - - fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f) - ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0) - #ramps = np.subtract.outer(mel_f, fftfreqs) - - for i in range(n_mels): - # lower and upper slopes for all bins - lower = -ramps[i] / fdiff[i] - upper = ramps[i + 2] / fdiff[i + 1] - - # .. 
then intersect them with each other and zero
-        weights[i] = paddle.maximum(
-            paddle.zeros_like(lower), paddle.minimum(lower, upper))
-
-    # Slaney-style mel is scaled to be approx constant energy per channel
-    if norm == 'slaney':
-        enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm.unsqueeze(1)
-    elif isinstance(norm, int) or isinstance(norm, float):
-        weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
-
-    return weights
-
-
-def power_to_db(magnitude: paddle.Tensor,
-                ref_value: float=1.0,
-                amin: float=1e-10,
-                top_db: Optional[float]=None) -> paddle.Tensor:
-    """Convert a power spectrogram (amplitude squared) to decibel (dB) units.
-    The function computes the scaling ``10 * log10(x / ref)`` in a numerically
-    stable way.
-    Parameters:
-        magnitude(Tensor): the input magnitude tensor of any shape.
-        ref_value(float): the reference value. If smaller than 1.0, the db level
-            of the signal will be pulled up accordingly. Otherwise, the db level
-            is pushed down.
-        amin(float): the minimum value of input magnitude, below which the input
-            magnitude is clipped(to amin).
-        top_db(float): the maximum db value of resulting spectrum, above which the
-            spectrum is clipped(to top_db).
-    Returns:
-        The spectrogram in log-scale.
-    shape:
-        input: any shape
-        output: same as input
-    """
-    if amin <= 0:
-        raise Exception("amin must be strictly positive")
-
-    if ref_value <= 0:
-        raise Exception("ref_value must be strictly positive")
-
-    ones = paddle.ones_like(magnitude)
-    log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
-    log_spec -= 10.0 * math.log10(max(ref_value, amin))
-
-    if top_db is not None:
-        if top_db < 0:
-            raise Exception("top_db must be non-negative")
-        log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
-
-    return log_spec
-
-
 class Spectrogram(nn.Layer):
     def __init__(self,
                  n_fft: int=512,
@@ -459,3 +252,29 @@ class LogMelSpectrogram(nn.Layer):
             amin=self.amin,
             top_db=self.top_db)
         return log_mel_feature
+
+
+class MFCC(nn.Layer):
+    def __init__(self,
+                 sr: int=22050,
+                 n_mfcc: int=40,
+                 norm: str='ortho',
+                 **kwargs):
+        """Compute mel-frequency cepstral coefficients (MFCCs); extra keyword arguments are passed to LogMelSpectrogram.
+        Parameters:
+            sr (int, optional): sample rate of the input audio. Defaults to 22050.
+            n_mfcc (int, optional): number of cepstral coefficients to return. Defaults to 40.
+            norm (str, optional): normalization of the DCT matrix, 'ortho' for an orthonormal DCT-II. Defaults to 'ortho'.
+        """
+        super(MFCC, self).__init__()
+        self._log_melspectrogram = LogMelSpectrogram(sr=sr, **kwargs)
+        dct_matrix = create_dct(
+            n_mfcc=n_mfcc, n_mels=self._log_melspectrogram.n_mels, norm=norm)
+        self.register_buffer('dct_matrix', dct_matrix)
+
+    def forward(self, x):
+        log_mel_feature = self._log_melspectrogram(x)
+        mfcc = paddle.matmul(
+            log_mel_feature.transpose((0, 2, 1)), self.dct_matrix).transpose(
+                (0, 2, 1))  # (B, n_mfcc, L)
+        return mfcc
diff --git a/paddleaudio/paddleaudio/functional/__init__.py b/paddleaudio/paddleaudio/functional/__init__.py
index 97043fd7..c85232df 100644
--- a/paddleaudio/paddleaudio/functional/__init__.py
+++ b/paddleaudio/paddleaudio/functional/__init__.py
@@ -11,3 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .functional import compute_fbank_matrix +from .functional import create_dct +from .functional import fft_frequencies +from .functional import hz_to_mel +from .functional import mel_frequencies +from .functional import mel_to_hz +from .functional import power_to_db diff --git a/paddleaudio/paddleaudio/functional/functional.py b/paddleaudio/paddleaudio/functional/functional.py index 167795c3..c07f14fd 100644 --- a/paddleaudio/paddleaudio/functional/functional.py +++ b/paddleaudio/paddleaudio/functional/functional.py @@ -12,146 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from librosa(https://github.com/librosa/librosa) -import warnings -from typing import List +import math from typing import Optional from typing import Union -import numpy as np -import scipy -from numpy import ndarray as array -from numpy.lib.stride_tricks import as_strided -from scipy import signal - -from ..backends import depth_convert -from ..utils import ParameterError +import paddle __all__ = [ - # dsp - 'stft', - 'mfcc', 'hz_to_mel', 'mel_to_hz', - 'split_frames', 'mel_frequencies', - 'power_to_db', + 'fft_frequencies', 'compute_fbank_matrix', - 'melspectrogram', - 'spectrogram', - 'mu_encode', - 'mu_decode', - # augmentation - 'depth_augment', - 'spect_augment', - 'random_crop1d', - 'random_crop2d', - 'adaptive_spect_augment', + 'power_to_db', + 'create_dct', ] -def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array: - """Pad an array to a target length along a target axis. - - This differs from `np.pad` by centering the data prior to padding, - analogous to `str.center` - """ - - kwargs.setdefault("mode", "constant") - n = data.shape[axis] - lpad = int((size - n) // 2) - lengths = [(0, 0)] * data.ndim - lengths[axis] = (lpad, int(size - n - lpad)) - - if lpad < 0: - raise ParameterError(("Target size ({size:d}) must be " - "at least input size ({n:d})")) - - return np.pad(data, lengths, **kwargs) - - -def split_frames(x: array, frame_length: int, hop_length: int, - axis: int=-1) -> array: - """Slice a data array into (overlapping) frames. - - This function is aligned with librosa.frame - """ - - if not isinstance(x, np.ndarray): - raise ParameterError( - f"Input must be of type numpy.ndarray, given type(x)={type(x)}") - - if x.shape[axis] < frame_length: - raise ParameterError(f"Input is too short (n={x.shape[axis]:d})" - f" for frame_length={frame_length:d}") - - if hop_length < 1: - raise ParameterError(f"Invalid hop_length: {hop_length:d}") - - if axis == -1 and not x.flags["F_CONTIGUOUS"]: - warnings.warn(f"librosa.util.frame called with axis={axis} " - "on a non-contiguous input. This will result in a copy.") - x = np.asfortranarray(x) - elif axis == 0 and not x.flags["C_CONTIGUOUS"]: - warnings.warn(f"librosa.util.frame called with axis={axis} " - "on a non-contiguous input. 
This will result in a copy.") - x = np.ascontiguousarray(x) - - n_frames = 1 + (x.shape[axis] - frame_length) // hop_length - strides = np.asarray(x.strides) - - new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize - - if axis == -1: - shape = list(x.shape)[:-1] + [frame_length, n_frames] - strides = list(strides) + [hop_length * new_stride] - - elif axis == 0: - shape = [n_frames, frame_length] + list(x.shape)[1:] - strides = [hop_length * new_stride] + list(strides) - - else: - raise ParameterError(f"Frame axis={axis} must be either 0 or -1") - - return as_strided(x, shape=shape, strides=strides) - - -def _check_audio(y, mono=True) -> bool: - """Determine whether a variable contains valid audio data. - - The audio y must be a np.ndarray, ether 1-channel or two channel - """ - if not isinstance(y, np.ndarray): - raise ParameterError("Audio data must be of type numpy.ndarray") - if y.ndim > 2: - raise ParameterError( - f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}") - - if mono and y.ndim == 2: - raise ParameterError( - f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}") - - if (mono and len(y) == 0) or (not mono and y.shape[1] < 0): - raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}") - - if not np.issubdtype(y.dtype, np.floating): - raise ParameterError("Audio data must be floating-point") - - if not np.isfinite(y).all(): - raise ParameterError("Audio buffer is not finite everywhere") - - return True - - -def hz_to_mel(frequencies: Union[float, List[float], array], - htk: bool=False) -> array: - """Convert Hz to Mels - - This function is aligned with librosa. +def hz_to_mel(freq: Union[paddle.Tensor, float], + htk: bool=False) -> Union[paddle.Tensor, float]: + """Convert Hz to Mels. + Parameters: + freq: the input tensor of arbitrary shape, or a single floating point number. + htk: use HTK formula to do the conversion. + The default value is False. + Returns: + The frequencies represented in Mel-scale. """ - freq = np.asanyarray(frequencies) if htk: - return 2595.0 * np.log10(1.0 + freq / 700.0) + if isinstance(freq, paddle.Tensor): + return 2595.0 * paddle.log10(1.0 + freq / 700.0) + else: + return 2595.0 * math.log10(1.0 + freq / 700.0) # Fill in the linear part f_min = 0.0 @@ -163,107 +56,129 @@ def hz_to_mel(frequencies: Union[float, List[float], array], min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = np.log(6.4) / 27.0 # step size for log region - - if freq.ndim: - # If we have array data, vectorize - log_t = freq >= min_log_hz - mels[log_t] = min_log_mel + \ - np.log(freq[log_t] / min_log_hz) / logstep - elif freq >= min_log_hz: - # If we have scalar data, heck directly - mels = min_log_mel + np.log(freq / min_log_hz) / logstep + logstep = math.log(6.4) / 27.0 # step size for log region + + if isinstance(freq, paddle.Tensor): + target = min_log_mel + paddle.log( + freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10 + mask = (freq > min_log_hz).astype(freq.dtype) + mels = target * mask + mels * ( + 1 - mask) # will replace by masked_fill OP in future + else: + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep return mels -def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array: +def mel_to_hz(mel: Union[float, paddle.Tensor], + htk: bool=False) -> Union[float, paddle.Tensor]: """Convert mel bin numbers to frequencies. - - This function is aligned with librosa. 
+ Parameters: + mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number. + htk: use HTK formula to do the conversion. + Returns: + The frequencies represented in hz. """ - mel_array = np.asanyarray(mels) - if htk: - return 700.0 * (10.0**(mel_array / 2595.0) - 1.0) + return 700.0 * (10.0**(mel / 2595.0) - 1.0) - # Fill in the linear scale f_min = 0.0 f_sp = 200.0 / 3 - freqs = f_min + f_sp * mel_array - + freqs = f_min + f_sp * mel # And now the nonlinear scale min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = np.log(6.4) / 27.0 # step size for log region - - if mel_array.ndim: - # If we have vector data, vectorize - log_t = mel_array >= min_log_mel - freqs[log_t] = min_log_hz * \ - np.exp(logstep * (mel_array[log_t] - min_log_mel)) - elif mel_array >= min_log_mel: - # If we have scalar data, check directly - freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel)) + logstep = math.log(6.4) / 27.0 # step size for log region + if isinstance(mel, paddle.Tensor): + target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel)) + mask = (mel > min_log_mel).astype(mel.dtype) + freqs = target * mask + freqs * ( + 1 - mask) # will replace by masked_fill OP in future + else: + if mel >= min_log_mel: + freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel)) return freqs -def mel_frequencies(n_mels: int=128, - fmin: float=0.0, - fmax: float=11025.0, - htk: bool=False) -> array: - """Compute mel frequencies - - This function is aligned with librosa. +def mel_frequencies(n_mels: int=64, + f_min: float=0.0, + f_max: float=11025.0, + htk: bool=False, + dtype: str=paddle.float32): + """Compute mel frequencies. + Parameters: + n_mels(int): number of Mel bins. + f_min(float): the lower cut-off frequency, below which the filter response is zero. + f_max(float): the upper cut-off frequency, above which the filter response is zero. + htk(bool): whether to use htk formula. + dtype(str): the datatype of the return frequencies. + Returns: + The frequencies represented in Mel-scale """ # 'Center freqs' of mel bands - uniformly spaced between limits - min_mel = hz_to_mel(fmin, htk=htk) - max_mel = hz_to_mel(fmax, htk=htk) - - mels = np.linspace(min_mel, max_mel, n_mels) - - return mel_to_hz(mels, htk=htk) + min_mel = hz_to_mel(f_min, htk=htk) + max_mel = hz_to_mel(f_max, htk=htk) + mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype) + freqs = mel_to_hz(mels, htk=htk) + return freqs -def fft_frequencies(sr: int, n_fft: int) -> array: +def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32): """Compute fourier frequencies. - - This function is aligned with librosa. + Parameters: + sr(int): the audio sample rate. + n_fft(float): the number of fft bins. + dtype(str): the datatype of the return frequencies. + Returns: + The frequencies represented in hz. """ - return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True) + return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype) def compute_fbank_matrix(sr: int, n_fft: int, - n_mels: int=128, - fmin: float=0.0, - fmax: Optional[float]=None, + n_mels: int=64, + f_min: float=0.0, + f_max: Optional[float]=None, htk: bool=False, - norm: str="slaney", - dtype: type=np.float32): + norm: Union[str, float]='slaney', + dtype: str=paddle.float32): """Compute fbank matrix. - - This funciton is aligned with librosa. + Parameters: + sr(int): the audio sample rate. + n_fft(int): the number of fft bins. 
+        n_mels(int): the number of Mel bins.
+        f_min(float): the lower cut-off frequency, below which the filter response is zero.
+        f_max(float): the upper cut-off frequency, above which the filter response is zero.
+        htk(bool): whether to use htk formula.
+        norm(str|float): the normalization type. 'slaney' scales each filter for
+            approximately constant energy per channel; a numeric value applies
+            p-norm normalization along the frequency axis.
+        dtype(str): the datatype of the returned fbank matrix.
+    Returns:
+        The fbank matrix of shape (n_mels, int(1+n_fft//2)).
+    Shape:
+        output: (n_mels, int(1+n_fft//2))
     """
-    if norm != "slaney":
-        raise ParameterError('norm must set to slaney')
-    if fmax is None:
-        fmax = float(sr) / 2
+
+    if f_max is None:
+        f_max = float(sr) / 2
 
     # Initialize the weights
-    n_mels = int(n_mels)
-    weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+    weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
 
     # Center freqs of each FFT bin
-    fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
+    fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
 
     # 'Center freqs' of mel bands - uniformly spaced between limits
-    mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
+    mel_f = mel_frequencies(
+        n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
 
-    fdiff = np.diff(mel_f)
-    ramps = np.subtract.outer(mel_f, fftfreqs)
+    fdiff = mel_f[1:] - mel_f[:-1]  #np.diff(mel_f)
+    ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
+    #ramps = np.subtract.outer(mel_f, fftfreqs)
 
     for i in range(n_mels):
         # lower and upper slopes for all bins
@@ -271,458 +186,79 @@ def compute_fbank_matrix(sr: int,
         upper = ramps[i + 2] / fdiff[i + 1]
 
         # .. then intersect them with each other and zero
-        weights[i] = np.maximum(0, np.minimum(lower, upper))
+        weights[i] = paddle.maximum(
+            paddle.zeros_like(lower), paddle.minimum(lower, upper))
 
-    if norm == "slaney":
-        # Slaney-style mel is scaled to be approx constant energy per channel
+    # Slaney-style mel is scaled to be approx constant energy per channel
+    if norm == 'slaney':
         enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm[:, np.newaxis]
-
-    # Only check weights if f_mel[0] is positive
-    if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
-        # This means we have an empty channel somewhere
-        warnings.warn("Empty filters detected in mel frequency basis. "
-                      "Some channels will produce empty responses. "
-                      "Try increasing your sampling rate (and fmax) or "
-                      "reducing n_mels.")
+        weights *= enorm.unsqueeze(1)
+    elif isinstance(norm, int) or isinstance(norm, float):
+        weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
 
     return weights
 
 
-def stft(x: array,
-         n_fft: int=2048,
-         hop_length: Optional[int]=None,
-         win_length: Optional[int]=None,
-         window: str="hann",
-         center: bool=True,
-         dtype: type=np.complex64,
-         pad_mode: str="reflect") -> array:
-    """Short-time Fourier transform (STFT).
-
-    This function is aligned with librosa.
- """ - _check_audio(x) - - # By default, use the entire frame - if win_length is None: - win_length = n_fft - - # Set the default hop, if it's not already specified - if hop_length is None: - hop_length = int(win_length // 4) - - fft_window = signal.get_window(window, win_length, fftbins=True) - - # Pad the window out to n_fft size - fft_window = pad_center(fft_window, n_fft) - - # Reshape so that the window can be broadcast - fft_window = fft_window.reshape((-1, 1)) - - # Pad the time series so that frames are centered - if center: - if n_fft > x.shape[-1]: - warnings.warn( - f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}" - ) - x = np.pad(x, int(n_fft // 2), mode=pad_mode) - - elif n_fft > x.shape[-1]: - raise ParameterError( - f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}" - ) - - # Window the time series. - x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length) - # Pre-allocate the STFT matrix - stft_matrix = np.empty( - (int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F") - fft = np.fft # use numpy fft as default - # Constrain STFT block sizes to 256 KB - MAX_MEM_BLOCK = 2**8 * 2**10 - # how many columns can we fit within MAX_MEM_BLOCK? - n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize) - n_columns = max(n_columns, 1) - - for bl_s in range(0, stft_matrix.shape[1], n_columns): - bl_t = min(bl_s + n_columns, stft_matrix.shape[1]) - stft_matrix[:, bl_s:bl_t] = fft.rfft( - fft_window * x_frames[:, bl_s:bl_t], axis=0) - - return stft_matrix - - -def power_to_db(spect: array, - ref: float=1.0, +def power_to_db(magnitude: paddle.Tensor, + ref_value: float=1.0, amin: float=1e-10, - top_db: Optional[float]=80.0) -> array: - """Convert a power spectrogram (amplitude squared) to decibel (dB) units - - This computes the scaling ``10 * log10(spect / ref)`` in a numerically + top_db: Optional[float]=None) -> paddle.Tensor: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. + The function computes the scaling ``10 * log10(x / ref)`` in a numerically stable way. - - This function is aligned with librosa. + Parameters: + magnitude(Tensor): the input magnitude tensor of any shape. + ref_value(float): the reference value. If smaller than 1.0, the db level + of the signal will be pulled up accordingly. Otherwise, the db level + is pushed down. + amin(float): the minimum value of input magnitude, below which the input + magnitude is clipped(to amin). + top_db(float): the maximum db value of resulting spectrum, above which the + spectrum is clipped(to top_db). + Returns: + The spectrogram in log-scale. + shape: + input: any shape + output: same as input """ - spect = np.asarray(spect) - if amin <= 0: - raise ParameterError("amin must be strictly positive") - - if np.issubdtype(spect.dtype, np.complexfloating): - warnings.warn( - "power_to_db was called on complex input so phase " - "information will be discarded. 
To suppress this warning, " - "call power_to_db(np.abs(D)**2) instead.") - magnitude = np.abs(spect) - else: - magnitude = spect + raise Exception("amin must be strictly positive") - if callable(ref): - # User supplied a function to calculate reference power - ref_value = ref(magnitude) - else: - ref_value = np.abs(ref) + if ref_value <= 0: + raise Exception("ref_value must be strictly positive") - log_spec = 10.0 * np.log10(np.maximum(amin, magnitude)) - log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) + ones = paddle.ones_like(magnitude) + log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude)) + log_spec -= 10.0 * math.log10(max(ref_value, amin)) if top_db is not None: if top_db < 0: - raise ParameterError("top_db must be non-negative") - log_spec = np.maximum(log_spec, log_spec.max() - top_db) + raise Exception("top_db must be non-negative") + log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db)) return log_spec -def mfcc(x, - sr: int=16000, - spect: Optional[array]=None, - n_mfcc: int=20, - dct_type: int=2, - norm: str="ortho", - lifter: int=0, - **kwargs) -> array: - """Mel-frequency cepstral coefficients (MFCCs) - - This function is NOT strictly aligned with librosa. The following example shows how to get the - same result with librosa: - - # mfcc: - kwargs = { - 'window_size':512, - 'hop_length':320, - 'mel_bins':64, - 'fmin':50, - 'to_db':False} - a = mfcc(x, - spect=None, - n_mfcc=20, - dct_type=2, - norm='ortho', - lifter=0, - **kwargs) - - # librosa mfcc: - spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512, - win_length=512, - hop_length=320, - n_mels=64, fmin=50) - b = librosa.feature.mfcc(y=x, - sr=16000, - S=spect, - n_mfcc=20, - dct_type=2, - norm='ortho', - lifter=0) - - assert np.mean( (a-b)**2) < 1e-8 - - """ - if spect is None: - spect = melspectrogram(x, sr=sr, **kwargs) - - M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc] - - if lifter > 0: - factor = np.sin(np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) / - lifter) - return M * factor[:, np.newaxis] - elif lifter == 0: - return M - else: - raise ParameterError( - f"MFCC lifter={lifter} must be a non-negative number") - - -def melspectrogram(x: array, - sr: int=16000, - window_size: int=512, - hop_length: int=320, - n_mels: int=64, - fmin: int=50, - fmax: Optional[float]=None, - window: str='hann', - center: bool=True, - pad_mode: str='reflect', - power: float=2.0, - to_db: bool=True, - ref: float=1.0, - amin: float=1e-10, - top_db: Optional[float]=None) -> array: - """Compute mel-spectrogram. - +def create_dct(n_mfcc: int, + n_mels: int, + norm: Optional[str]='ortho', + dtype: Optional[str]=paddle.float32): + """[summary] Parameters: - x: numpy.ndarray - The input wavform is a numpy array [shape=(n,)] - - window_size: int, typically 512, 1024, 2048, etc. - The window size for framing, also used as n_fft for stft - - + n_mfcc (int): [description] + n_mels (int): [description] + norm (str, optional): [description]. Defaults to 'ortho'. Returns: - The mel-spectrogram in power scale or db scale(default) - - - Notes: - 1. sr is default to 16000, which is commonly used in speech/speaker processing. - 2. when fmax is None, it is set to sr//2. - 3. this function will convert mel spectgrum to db scale by default. This is different - that of librosa. 
- - """ - _check_audio(x, mono=True) - if len(x) <= 0: - raise ParameterError('The input waveform is empty') - - if fmax is None: - fmax = sr // 2 - if fmin < 0 or fmin >= fmax: - raise ParameterError('fmin and fmax must statisfy 0 array: - """Compute spectrogram from an input waveform. - - This function is a wrapper for librosa.feature.stft, with addition step to - compute the magnitude of the complex spectrogram. - """ - - s = stft( - x, - n_fft=window_size, - hop_length=hop_length, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode) - - return np.abs(s)**power - - -def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array: - """Mu-law encoding. - - Compute the mu-law decoding given an input code. - When quantized is True, the result will be converted to - integer in range [0,mu-1]. Otherwise, the resulting signal - is in range [-1,1] - - - Reference: - https://en.wikipedia.org/wiki/%CE%9C-law_algorithm - - """ - mu = 255 - y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) - if quantized: - y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1] - return y - - -def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array: - """Mu-law decoding. - - Compute the mu-law decoding given an input code. - - it assumes that the input y is in - range [0,mu-1] when quantize is True and [-1,1] otherwise - - Reference: - https://en.wikipedia.org/wiki/%CE%9C-law_algorithm - - """ - if mu < 1: - raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...') - - mu = mu - 1 - if quantized: # undo the quantization - y = y * 2 / mu - 1 - x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1) - return x - - -def randint(high: int) -> int: - """Generate one random integer in range [0 high) - - This is a helper function for random data augmentaiton - """ - return int(np.random.randint(0, high=high)) - - -def rand() -> float: - """Generate one floating-point number in range [0 1) - - This is a helper function for random data augmentaiton - """ - return float(np.random.rand(1)) - - -def depth_augment(y: array, - choices: List=['int8', 'int16'], - probs: List[float]=[0.5, 0.5]) -> array: - """ Audio depth augmentation - - Do audio depth augmentation to simulate the distortion brought by quantization. 
- """ - assert len(probs) == len( - choices - ), 'number of choices {} must be equal to size of probs {}'.format( - len(choices), len(probs)) - depth = np.random.choice(choices, p=probs) - src_depth = y.dtype - y1 = depth_convert(y, depth) - y2 = depth_convert(y1, src_depth) - - return y2 - - -def adaptive_spect_augment(spect: array, tempo_axis: int=0, - level: float=0.1) -> array: - """Do adpative spectrogram augmentation - - The level of the augmentation is gowern by the paramter level, - ranging from 0 to 1, with 0 represents no augmentation。 - - """ - assert spect.ndim == 2., 'only supports 2d tensor or numpy array' - if tempo_axis == 0: - nt, nf = spect.shape - else: - nf, nt = spect.shape - - time_mask_width = int(nt * level * 0.5) - freq_mask_width = int(nf * level * 0.5) - - num_time_mask = int(10 * level) - num_freq_mask = int(10 * level) - - if tempo_axis == 0: - for _ in range(num_time_mask): - start = randint(nt - time_mask_width) - spect[start:start + time_mask_width, :] = 0 - for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) - spect[:, start:start + freq_mask_width] = 0 - else: - for _ in range(num_time_mask): - start = randint(nt - time_mask_width) - spect[:, start:start + time_mask_width] = 0 - for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) - spect[start:start + freq_mask_width, :] = 0 - - return spect - - -def spect_augment(spect: array, - tempo_axis: int=0, - max_time_mask: int=3, - max_freq_mask: int=3, - max_time_mask_width: int=30, - max_freq_mask_width: int=20) -> array: - """Do spectrogram augmentation in both time and freq axis - - Reference: - + [type]: [description] """ - assert spect.ndim == 2., 'only supports 2d tensor or numpy array' - if tempo_axis == 0: - nt, nf = spect.shape - else: - nf, nt = spect.shape - - num_time_mask = randint(max_time_mask) - num_freq_mask = randint(max_freq_mask) - - time_mask_width = randint(max_time_mask_width) - freq_mask_width = randint(max_freq_mask_width) - - if tempo_axis == 0: - for _ in range(num_time_mask): - start = randint(nt - time_mask_width) - spect[start:start + time_mask_width, :] = 0 - for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) - spect[:, start:start + freq_mask_width] = 0 + n = paddle.arange(n_mels, dtype=dtype) + k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1) + dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * + k) # size (n_mfcc, n_mels) + if norm is None: + dct *= 2.0 else: - for _ in range(num_time_mask): - start = randint(nt - time_mask_width) - spect[:, start:start + time_mask_width] = 0 - for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) - spect[start:start + freq_mask_width, :] = 0 - - return spect - - -def random_crop1d(y: array, crop_len: int) -> array: - """ Do random cropping on 1d input signal - - The input is a 1d signal, typically a sound waveform - """ - if y.ndim != 1: - 'only accept 1d tensor or numpy array' - n = len(y) - idx = randint(n - crop_len) - return y[idx:idx + crop_len] - - -def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array: - """ Do random cropping for 2D array, typically a spectrogram. - - The cropping is done in temporal direction on the time-freq input signal. 
- """ - if tempo_axis >= s.ndim: - raise ParameterError('axis out of range') - - n = s.shape[tempo_axis] - idx = randint(high=n - crop_len) - sli = [slice(None) for i in range(s.ndim)] - sli[tempo_axis] = slice(idx, idx + crop_len) - out = s[tuple(sli)] - return out + assert norm == "ortho" + dct[0] *= 1.0 / math.sqrt(2.0) + dct *= math.sqrt(2.0 / float(n_mels)) + return dct.T diff --git a/paddleaudio/paddleaudio/io/__init__.py b/paddleaudio/paddleaudio/io/__init__.py index cc2538f7..185a92b8 100644 --- a/paddleaudio/paddleaudio/io/__init__.py +++ b/paddleaudio/paddleaudio/io/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,9 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .audio import depth_convert -from .audio import load -from .audio import normalize -from .audio import resample -from .audio import save_wav -from .audio import to_mono diff --git a/paddleaudio/paddleaudio/io/audio.py b/paddleaudio/paddleaudio/io/audio.py deleted file mode 100644 index 4127570e..00000000 --- a/paddleaudio/paddleaudio/io/audio.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import warnings -from typing import Optional -from typing import Tuple -from typing import Union - -import numpy as np -import resampy -import soundfile as sf -from numpy import ndarray as array -from scipy.io import wavfile - -from ..utils import ParameterError - -__all__ = [ - 'resample', - 'to_mono', - 'depth_convert', - 'normalize', - 'save_wav', - 'load', -] -NORMALMIZE_TYPES = ['linear', 'gaussian'] -MERGE_TYPES = ['ch0', 'ch1', 'random', 'average'] -RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast'] -EPS = 1e-8 - - -def resample(y: array, src_sr: int, target_sr: int, - mode: str='kaiser_fast') -> array: - """ Audio resampling - - This function is the same as using resampy.resample(). - - Notes: - The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast' - - """ - - if mode == 'kaiser_best': - warnings.warn( - f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. 
This function is pretty slow, \ - we recommend the mode kaiser_fast in large scale audio trainning') - - if not isinstance(y, np.ndarray): - raise ParameterError( - 'Only support numpy array, but received y in {type(y)}') - - if mode not in RESAMPLE_MODES: - raise ParameterError(f'resample mode must in {RESAMPLE_MODES}') - - return resampy.resample(y, src_sr, target_sr, filter=mode) - - -def to_mono(y: array, merge_type: str='average') -> array: - """ convert sterior audio to mono - """ - if merge_type not in MERGE_TYPES: - raise ParameterError( - f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}' - ) - if y.ndim > 2: - raise ParameterError( - f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}') - if y.ndim == 1: # nothing to merge - return y - - if merge_type == 'ch0': - return y[0] - if merge_type == 'ch1': - return y[1] - if merge_type == 'random': - return y[np.random.randint(0, 2)] - - # need to do averaging according to dtype - - if y.dtype == 'float32': - y_out = (y[0] + y[1]) * 0.5 - elif y.dtype == 'int16': - y_out = y.astype('int32') - y_out = (y_out[0] + y_out[1]) // 2 - y_out = np.clip(y_out, np.iinfo(y.dtype).min, - np.iinfo(y.dtype).max).astype(y.dtype) - - elif y.dtype == 'int8': - y_out = y.astype('int16') - y_out = (y_out[0] + y_out[1]) // 2 - y_out = np.clip(y_out, np.iinfo(y.dtype).min, - np.iinfo(y.dtype).max).astype(y.dtype) - else: - raise ParameterError(f'Unsupported dtype: {y.dtype}') - return y_out - - -def _safe_cast(y: array, dtype: Union[type, str]) -> array: - """ data type casting in a safe way, i.e., prevent overflow or underflow - - This function is used internally. - """ - return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype) - - -def depth_convert(y: array, dtype: Union[type, str], - dithering: bool=True) -> array: - """Convert audio array to target dtype safely - - This function convert audio waveform to a target dtype, with addition steps of - preventing overflow/underflow and preserving audio range. - - """ - - SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64'] - if y.dtype not in SUPPORT_DTYPE: - raise ParameterError( - 'Unsupported audio dtype, ' - f'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}') - - if dtype not in SUPPORT_DTYPE: - raise ParameterError( - 'Unsupported audio dtype, ' - f'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}') - - if dtype == y.dtype: - return y - - if dtype == 'float64' and y.dtype == 'float32': - return _safe_cast(y, dtype) - if dtype == 'float32' and y.dtype == 'float64': - return _safe_cast(y, dtype) - - if dtype == 'int16' or dtype == 'int8': - if y.dtype in ['float64', 'float32']: - factor = np.iinfo(dtype).max - y = np.clip(y * factor, np.iinfo(dtype).min, - np.iinfo(dtype).max).astype(dtype) - y = y.astype(dtype) - else: - if dtype == 'int16' and y.dtype == 'int8': - factor = np.iinfo('int16').max / np.iinfo('int8').max - EPS - y = y.astype('float32') * factor - y = y.astype('int16') - - else: # dtype == 'int8' and y.dtype=='int16': - y = y.astype('int32') * np.iinfo('int8').max / \ - np.iinfo('int16').max - y = y.astype('int8') - - if dtype in ['float32', 'float64']: - org_dtype = y.dtype - y = y.astype(dtype) / np.iinfo(org_dtype).max - return y - - -def sound_file_load(file: str, - offset: Optional[float]=None, - dtype: str='int16', - duration: Optional[int]=None) -> Tuple[array, int]: - """Load audio using soundfile library - - This function load audio file using libsndfile. 
- - Reference: - http://www.mega-nerd.com/libsndfile/#Features - - """ - with sf.SoundFile(file) as sf_desc: - sr_native = sf_desc.samplerate - if offset: - sf_desc.seek(int(offset * sr_native)) - if duration is not None: - frame_duration = int(duration * sr_native) - else: - frame_duration = -1 - y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T - - return y, sf_desc.samplerate - - -def audio_file_load(): - """Load audio using audiofile library - - This function load audio file using audiofile. - - Reference: - https://audiofile.68k.org/ - - """ - raise NotImplementedError() - - -def sox_file_load(): - """Load audio using sox library - - This function load audio file using sox. - - Reference: - http://sox.sourceforge.net/ - """ - raise NotImplementedError() - - -def normalize(y: array, norm_type: str='linear', - mul_factor: float=1.0) -> array: - """ normalize an input audio with additional multiplier. - - """ - - if norm_type == 'linear': - amax = np.max(np.abs(y)) - factor = 1.0 / (amax + EPS) - y = y * factor * mul_factor - elif norm_type == 'gaussian': - amean = np.mean(y) - astd = np.std(y) - astd = max(astd, EPS) - y = mul_factor * (y - amean) / astd - else: - raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}') - - return y - - -def save_wav(y: array, sr: int, file: str) -> None: - """Save audio file to disk. - This function saves audio to disk using scipy.io.wavfile, with additional step - to convert input waveform to int16 unless it already is int16 - - Notes: - It only support raw wav format. - - """ - if not file.endswith('.wav'): - raise ParameterError( - f'only .wav file supported, but dst file name is: {file}') - - if sr <= 0: - raise ParameterError( - f'Sample rate should be larger than 0, recieved sr = {sr}') - - if y.dtype not in ['int16', 'int8']: - warnings.warn( - f'input data type is {y.dtype}, will convert data to int16 format before saving' - ) - y_out = depth_convert(y, 'int16') - else: - y_out = y - - wavfile.write(file, sr, y_out) - - -def load( - file: str, - sr: Optional[int]=None, - mono: bool=True, - merge_type: str='average', # ch0,ch1,random,average - normal: bool=True, - norm_type: str='linear', - norm_mul_factor: float=1.0, - offset: float=0.0, - duration: Optional[int]=None, - dtype: str='float32', - resample_mode: str='kaiser_fast') -> Tuple[array, int]: - """Load audio file from disk. - This function loads audio from disk using using audio beackend. - - Parameters: - - Notes: - - """ - - y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration) - - if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)): - raise ParameterError(f'audio file {file} looks empty') - - if mono: - y = to_mono(y, merge_type) - - if sr is not None and sr != r: - y = resample(y, r, sr, mode=resample_mode) - r = sr - - if normal: - y = normalize(y, norm_type, norm_mul_factor) - elif dtype in ['int8', 'int16']: - # still need to do normalization, before depth convertion - y = normalize(y, 'linear', 1.0) - - y = depth_convert(y, dtype) - return y, r
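
A minimal usage sketch of the paddle-based helpers introduced in paddleaudio/paddleaudio/functional/functional.py, assuming paddle and the refactored paddleaudio package are importable; the shapes follow the signatures shown in the patch:

import paddle
from paddleaudio.functional import compute_fbank_matrix
from paddleaudio.functional import create_dct
from paddleaudio.functional import power_to_db

# Mel filter bank: one row per mel band, one column per FFT bin -> (64, 257).
fbank = compute_fbank_matrix(sr=16000, n_fft=512, n_mels=64)

# DCT-II matrix is returned transposed for right-multiplication -> (n_mels, n_mfcc) = (64, 20).
dct = create_dct(n_mfcc=20, n_mels=64, norm='ortho')

# Convert a toy power spectrogram to dB, clipping 80 dB below the peak.
power_spec = paddle.rand([64, 100])
db_spec = power_to_db(power_spec, ref_value=1.0, amin=1e-10, top_db=80.0)
print(fbank.shape, dct.shape, db_spec.shape)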
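
The new MFCC layer in paddleaudio/paddleaudio/features/layers.py chains LogMelSpectrogram with the create_dct matrix. A sketch under assumptions: the layer is imported from its module path directly (whether features/__init__.py re-exports it is not visible here), extra keyword arguments reach LogMelSpectrogram with usable defaults, and the input is a batch of waveforms shaped (batch, num_samples):

import paddle
from paddleaudio.features.layers import MFCC

mfcc_layer = MFCC(sr=16000, n_mfcc=20)  # extra kwargs would be forwarded to LogMelSpectrogram
waveform = paddle.randn([2, 16000])     # assumed input layout: (batch, num_samples)
coeffs = mfcc_layer(waveform)           # (batch, n_mfcc, num_frames) per the forward() comment
print(coeffs.shape)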
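
The numpy augmentation utilities kept by this refactor (spect_augment, adaptive_spect_augment, random_crop1d, random_crop2d) still operate on numpy arrays, and the spectrogram maskers modify their input in place. A rough sketch; the import path is an assumption, since the header of the module that now holds them is outside this excerpt:

import numpy as np
# Assumed post-refactor location; adjust if the module layout differs.
from paddleaudio.compliance.librosa import random_crop1d
from paddleaudio.compliance.librosa import spect_augment

spec = np.random.rand(200, 64).astype('float32')  # (time, freq) when tempo_axis=0
masked = spect_augment(spec, tempo_axis=0, max_time_mask=3, max_freq_mask=3)

wav = np.random.rand(16000).astype('float32')
crop = random_crop1d(wav, crop_len=8000)          # random 8000-sample window
print(masked.shape, crop.shape)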