Merge pull request #1518 from KPatr1ck/audio

[audio]refactor audio arch
pull/1528/head
Hui Zhang 3 years ago committed by GitHub
commit 25cb4bb06a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backends import load
from .backends import save

@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .soundfile_backend import depth_convert
from .soundfile_backend import load
from .soundfile_backend import normalize
from .soundfile_backend import resample
from .soundfile_backend import save
from .soundfile_backend import to_mono

@ -11,3 +11,255 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import resampy
import soundfile as sf
from numpy import ndarray as array
from scipy.io import wavfile
from ..utils import ParameterError
__all__ = [
'resample',
'to_mono',
'depth_convert',
'normalize',
'save',
'load',
]
NORMALMIZE_TYPES = ['linear', 'gaussian']
MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
EPS = 1e-8
def resample(y: array, src_sr: int, target_sr: int,
mode: str='kaiser_fast') -> array:
""" Audio resampling
This function is the same as using resampy.resample().
Notes:
The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast'
"""
if mode == 'kaiser_best':
warnings.warn(
f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \
we recommend the mode kaiser_fast in large scale audio trainning')
if not isinstance(y, np.ndarray):
raise ParameterError(
'Only support numpy array, but received y in {type(y)}')
if mode not in RESAMPLE_MODES:
raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
return resampy.resample(y, src_sr, target_sr, filter=mode)
def to_mono(y: array, merge_type: str='average') -> array:
""" convert sterior audio to mono
"""
if merge_type not in MERGE_TYPES:
raise ParameterError(
f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
)
if y.ndim > 2:
raise ParameterError(
f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
if y.ndim == 1: # nothing to merge
return y
if merge_type == 'ch0':
return y[0]
if merge_type == 'ch1':
return y[1]
if merge_type == 'random':
return y[np.random.randint(0, 2)]
# need to do averaging according to dtype
if y.dtype == 'float32':
y_out = (y[0] + y[1]) * 0.5
elif y.dtype == 'int16':
y_out = y.astype('int32')
y_out = (y_out[0] + y_out[1]) // 2
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
np.iinfo(y.dtype).max).astype(y.dtype)
elif y.dtype == 'int8':
y_out = y.astype('int16')
y_out = (y_out[0] + y_out[1]) // 2
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
np.iinfo(y.dtype).max).astype(y.dtype)
else:
raise ParameterError(f'Unsupported dtype: {y.dtype}')
return y_out
def _safe_cast(y: array, dtype: Union[type, str]) -> array:
""" data type casting in a safe way, i.e., prevent overflow or underflow
This function is used internally.
"""
return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
def depth_convert(y: array, dtype: Union[type, str],
dithering: bool=True) -> array:
"""Convert audio array to target dtype safely
This function convert audio waveform to a target dtype, with addition steps of
preventing overflow/underflow and preserving audio range.
"""
SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64']
if y.dtype not in SUPPORT_DTYPE:
raise ParameterError(
'Unsupported audio dtype, '
f'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}')
if dtype not in SUPPORT_DTYPE:
raise ParameterError(
'Unsupported audio dtype, '
f'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}')
if dtype == y.dtype:
return y
if dtype == 'float64' and y.dtype == 'float32':
return _safe_cast(y, dtype)
if dtype == 'float32' and y.dtype == 'float64':
return _safe_cast(y, dtype)
if dtype == 'int16' or dtype == 'int8':
if y.dtype in ['float64', 'float32']:
factor = np.iinfo(dtype).max
y = np.clip(y * factor, np.iinfo(dtype).min,
np.iinfo(dtype).max).astype(dtype)
y = y.astype(dtype)
else:
if dtype == 'int16' and y.dtype == 'int8':
factor = np.iinfo('int16').max / np.iinfo('int8').max - EPS
y = y.astype('float32') * factor
y = y.astype('int16')
else: # dtype == 'int8' and y.dtype=='int16':
y = y.astype('int32') * np.iinfo('int8').max / \
np.iinfo('int16').max
y = y.astype('int8')
if dtype in ['float32', 'float64']:
org_dtype = y.dtype
y = y.astype(dtype) / np.iinfo(org_dtype).max
return y
def sound_file_load(file: str,
offset: Optional[float]=None,
dtype: str='int16',
duration: Optional[int]=None) -> Tuple[array, int]:
"""Load audio using soundfile library
This function load audio file using libsndfile.
Reference:
http://www.mega-nerd.com/libsndfile/#Features
"""
with sf.SoundFile(file) as sf_desc:
sr_native = sf_desc.samplerate
if offset:
sf_desc.seek(int(offset * sr_native))
if duration is not None:
frame_duration = int(duration * sr_native)
else:
frame_duration = -1
y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T
return y, sf_desc.samplerate
def normalize(y: array, norm_type: str='linear',
mul_factor: float=1.0) -> array:
""" normalize an input audio with additional multiplier.
"""
if norm_type == 'linear':
amax = np.max(np.abs(y))
factor = 1.0 / (amax + EPS)
y = y * factor * mul_factor
elif norm_type == 'gaussian':
amean = np.mean(y)
astd = np.std(y)
astd = max(astd, EPS)
y = mul_factor * (y - amean) / astd
else:
raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}')
return y
def save(y: array, sr: int, file: str) -> None:
"""Save audio file to disk.
This function saves audio to disk using scipy.io.wavfile, with additional step
to convert input waveform to int16 unless it already is int16
Notes:
It only support raw wav format.
"""
if not file.endswith('.wav'):
raise ParameterError(
f'only .wav file supported, but dst file name is: {file}')
if sr <= 0:
raise ParameterError(
f'Sample rate should be larger than 0, recieved sr = {sr}')
if y.dtype not in ['int16', 'int8']:
warnings.warn(
f'input data type is {y.dtype}, will convert data to int16 format before saving'
)
y_out = depth_convert(y, 'int16')
else:
y_out = y
wavfile.write(file, sr, y_out)
def load(
file: str,
sr: Optional[int]=None,
mono: bool=True,
merge_type: str='average', # ch0,ch1,random,average
normal: bool=True,
norm_type: str='linear',
norm_mul_factor: float=1.0,
offset: float=0.0,
duration: Optional[int]=None,
dtype: str='float32',
resample_mode: str='kaiser_fast') -> Tuple[array, int]:
"""Load audio file from disk.
This function loads audio from disk using using audio beackend.
Parameters:
Notes:
"""
y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration)
if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
raise ParameterError(f'audio file {file} looks empty')
if mono:
y = to_mono(y, merge_type)
if sr is not None and sr != r:
y = resample(y, r, sr, mode=resample_mode)
r = sr
if normal:
y = normalize(y, norm_type, norm_mul_factor)
elif dtype in ['int8', 'int16']:
# still need to do normalization, before depth convertion
y = normalize(y, 'linear', 1.0)
y = depth_convert(y, dtype)
return y, r

@ -0,0 +1,688 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
import paddle
from paddle import Tensor
from ..functional import create_dct
from ..functional.window import get_window
__all__ = [
'spectrogram',
'fbank',
'mfcc',
]
# window types
HANNING = 'hann'
HAMMING = 'hamming'
POVEY = 'povey'
RECTANGULAR = 'rect'
BLACKMAN = 'blackman'
def _get_epsilon(dtype):
return paddle.to_tensor(1e-07, dtype=dtype)
def _next_power_of_2(x: int) -> int:
return 1 if x == 0 else 2**(x - 1).bit_length()
def _get_strided(waveform: Tensor,
window_size: int,
window_shift: int,
snip_edges: bool) -> Tensor:
assert waveform.dim() == 1
num_samples = waveform.shape[0]
if snip_edges:
if num_samples < window_size:
return paddle.empty((0, 0), dtype=waveform.dtype)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
reversed_waveform = paddle.flip(waveform, [0])
m = (num_samples + (window_shift // 2)) // window_shift
pad = window_size // 2 - window_shift // 2
pad_right = reversed_waveform
if pad > 0:
pad_left = reversed_waveform[-pad:]
waveform = paddle.concat((pad_left, waveform, pad_right), axis=0)
else:
waveform = paddle.concat((waveform[-pad:], pad_right), axis=0)
return paddle.signal.frame(waveform, window_size, window_shift)[:, :m].T
def _feature_window_function(
window_type: str,
window_size: int,
blackman_coeff: float,
dtype: int, ) -> Tensor:
if window_type == HANNING:
return get_window('hann', window_size, fftbins=False, dtype=dtype)
elif window_type == HAMMING:
return get_window('hamming', window_size, fftbins=False, dtype=dtype)
elif window_type == POVEY:
return get_window(
'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR:
return paddle.ones([window_size], dtype=dtype)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = paddle.arange(window_size, dtype=dtype)
return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
(0.5 - blackman_coeff) * paddle.cos(2 * a * window_function)
).astype(dtype)
else:
raise Exception('Invalid window type ' + window_type)
def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
energy_floor: float) -> Tensor:
log_energy = paddle.maximum(strided_input.pow(2).sum(1), epsilon).log()
if energy_floor == 0.0:
return log_energy
return paddle.maximum(
log_energy,
paddle.to_tensor(math.log(energy_floor), dtype=strided_input.dtype))
def _get_waveform_and_window_properties(
waveform: Tensor,
channel: int,
sample_frequency: float,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
channel = max(channel, 0)
assert channel < waveform.shape[0], (
'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
waveform = waveform[channel, :] # size (n)
window_shift = int(
sample_frequency * frame_shift *
0.001) # pass frame_shift and frame_length in milliseconds
window_size = int(sample_frequency * frame_length * 0.001)
padded_window_size = _next_power_of_2(
window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(waveform), (
'choose a window size {} that is [2, {}]'.format(window_size,
len(waveform)))
assert 0 < window_shift, '`window_shift` must be greater than 0'
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
' use `round_to_power_of_two` or change `frame_length`'
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
return waveform, window_shift, window_size, padded_window_size
def _get_window(waveform: Tensor,
padded_window_size: int,
window_size: int,
window_shift: int,
window_type: str,
blackman_coeff: float,
snip_edges: bool,
raw_energy: bool,
energy_floor: float,
dither: float,
remove_dc_offset: bool,
preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
dtype = waveform.dtype
epsilon = _get_epsilon(dtype)
# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift,
snip_edges)
if dither != 0.0:
# Returns a random number strictly between 0 and 1
x = paddle.maximum(epsilon,
paddle.rand(strided_input.shape, dtype=dtype))
rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
strided_input = strided_input + rand_gauss * dither
if remove_dc_offset:
# Subtract each row/frame by its mean
row_means = paddle.mean(
strided_input, axis=1).unsqueeze(1) # size (m, 1)
strided_input = strided_input - row_means
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, epsilon,
energy_floor) # size (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = paddle.nn.functional.pad(
strided_input.unsqueeze(0), (1, 0),
data_format='NCL',
mode='replicate').squeeze(0) # size (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff,
dtype).unsqueeze(0) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = paddle.nn.functional.pad(
strided_input.unsqueeze(0), (0, padding_right),
data_format='NCL',
mode='constant',
value=0).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, epsilon,
energy_floor) # size (m)
return strided_input, signal_log_energy
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
tensor = tensor - col_means
return tensor
def spectrogram(waveform: Tensor,
blackman_coeff: float=0.42,
channel: int=-1,
dither: float=0.0,
energy_floor: float=1.0,
frame_length: float=25.0,
frame_shift: float=10.0,
min_duration: float=0.0,
preemphasis_coefficient: float=0.97,
raw_energy: bool=True,
remove_dc_offset: bool=True,
round_to_power_of_two: bool=True,
sample_frequency: float=16000.0,
snip_edges: bool=True,
subtract_mean: bool=False,
window_type: str=POVEY) -> Tensor:
"""[summary]
Args:
waveform (Tensor): [description]
blackman_coeff (float, optional): [description]. Defaults to 0.42.
channel (int, optional): [description]. Defaults to -1.
dither (float, optional): [description]. Defaults to 0.0.
energy_floor (float, optional): [description]. Defaults to 1.0.
frame_length (float, optional): [description]. Defaults to 25.0.
frame_shift (float, optional): [description]. Defaults to 10.0.
min_duration (float, optional): [description]. Defaults to 0.0.
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
raw_energy (bool, optional): [description]. Defaults to True.
remove_dc_offset (bool, optional): [description]. Defaults to True.
round_to_power_of_two (bool, optional): [description]. Defaults to True.
sample_frequency (float, optional): [description]. Defaults to 16000.0.
snip_edges (bool, optional): [description]. Defaults to True.
subtract_mean (bool, optional): [description]. Defaults to False.
window_type (str, optional): [description]. Defaults to POVEY.
Returns:
Tensor: [description]
"""
dtype = waveform.dtype
epsilon = _get_epsilon(dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length,
round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return paddle.empty([0])
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type,
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1, 2)
fft = paddle.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = paddle.maximum(
fft.abs().pow(2.),
epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
return power_spectrum
def _inverse_mel_scale_scalar(mel_freq: float) -> float:
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
def _inverse_mel_scale(mel_freq: Tensor) -> Tensor:
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
def _mel_scale_scalar(freq: float) -> float:
return 1127.0 * math.log(1.0 + freq / 700.0)
def _mel_scale(freq: Tensor) -> Tensor:
return 1127.0 * (1.0 + freq / 700.0).log()
def _vtln_warp_freq(vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq: float,
high_freq: float,
vtln_warp_factor: float,
freq: Tensor) -> Tensor:
assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
scale = 1.0 / vtln_warp_factor
Fl = scale * l # F(l)
Fh = scale * h # F(h)
assert l > low_freq and h < high_freq
# slope of left part of the 3-piece linear function
scale_left = (Fl - low_freq) / (l - low_freq)
# [slope of center part is just "scale"]
# slope of right part of the 3-piece linear function
scale_right = (high_freq - Fh) / (high_freq - h)
res = paddle.empty_like(freq)
outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
| paddle.greater_than(freq, paddle.to_tensor(high_freq)) # freq < low_freq || freq > high_freq
before_l = paddle.less_than(freq, paddle.to_tensor(l)) # freq < l
before_h = paddle.less_than(freq, paddle.to_tensor(h)) # freq < h
after_h = paddle.greater_equal(freq, paddle.to_tensor(h)) # freq >= h
# order of operations matter here (since there is overlapping frequency regions)
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
res[before_h] = scale * freq[before_h]
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
res[outside_low_high_freq] = freq[outside_low_high_freq]
return res
def _vtln_warp_mel_freq(vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq,
high_freq: float,
vtln_warp_factor: float,
mel_freq: Tensor) -> Tensor:
return _mel_scale(
_vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
vtln_warp_factor, _inverse_mel_scale(mel_freq)))
def _get_mel_banks(num_bins: int,
window_length_padded: int,
sample_freq: float,
low_freq: float,
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
assert num_bins > 3, 'Must have at least 3 mel bins'
assert window_length_padded % 2 == 0
num_fft_bins = window_length_padded / 2
nyquist = 0.5 * sample_freq
if high_freq <= 0.0:
high_freq += nyquist
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
mel_low_freq = _mel_scale_scalar(low_freq)
mel_high_freq = _mel_scale_scalar(high_freq)
# divide by num_bins+1 in next line because of end-effects where the bins
# spread out to the sides.
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
if vtln_high < 0.0:
vtln_high += nyquist
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = paddle.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0
) * mel_freq_delta # size(num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
if vtln_warp_factor != 1.0:
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
vtln_warp_factor, left_mel)
center_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
high_freq, vtln_warp_factor,
center_mel)
right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
high_freq, vtln_warp_factor, right_mel)
center_freqs = _inverse_mel_scale(center_mel) # size (num_bins)
# size(1, num_fft_bins)
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0)
# size (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)
if vtln_warp_factor == 1.0:
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
bins = paddle.maximum(
paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
else:
# warping can move the order of left_mel, center_mel, right_mel anywhere
bins = paddle.zeros_like(up_slope)
up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
mel, center_mel) # left_mel < mel <= center_mel
down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
mel, right_mel) # center_mel < mel < right_mel
bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx]
return bins, center_freqs
def fbank(waveform: Tensor,
blackman_coeff: float=0.42,
channel: int=-1,
dither: float=0.0,
energy_floor: float=1.0,
frame_length: float=25.0,
frame_shift: float=10.0,
high_freq: float=0.0,
htk_compat: bool=False,
low_freq: float=20.0,
min_duration: float=0.0,
num_mel_bins: int=23,
preemphasis_coefficient: float=0.97,
raw_energy: bool=True,
remove_dc_offset: bool=True,
round_to_power_of_two: bool=True,
sample_frequency: float=16000.0,
snip_edges: bool=True,
subtract_mean: bool=False,
use_energy: bool=False,
use_log_fbank: bool=True,
use_power: bool=True,
vtln_high: float=-500.0,
vtln_low: float=100.0,
vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor:
"""[summary]
Args:
waveform (Tensor): [description]
blackman_coeff (float, optional): [description]. Defaults to 0.42.
channel (int, optional): [description]. Defaults to -1.
dither (float, optional): [description]. Defaults to 0.0.
energy_floor (float, optional): [description]. Defaults to 1.0.
frame_length (float, optional): [description]. Defaults to 25.0.
frame_shift (float, optional): [description]. Defaults to 10.0.
high_freq (float, optional): [description]. Defaults to 0.0.
htk_compat (bool, optional): [description]. Defaults to False.
low_freq (float, optional): [description]. Defaults to 20.0.
min_duration (float, optional): [description]. Defaults to 0.0.
num_mel_bins (int, optional): [description]. Defaults to 23.
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
raw_energy (bool, optional): [description]. Defaults to True.
remove_dc_offset (bool, optional): [description]. Defaults to True.
round_to_power_of_two (bool, optional): [description]. Defaults to True.
sample_frequency (float, optional): [description]. Defaults to 16000.0.
snip_edges (bool, optional): [description]. Defaults to True.
subtract_mean (bool, optional): [description]. Defaults to False.
use_energy (bool, optional): [description]. Defaults to False.
use_log_fbank (bool, optional): [description]. Defaults to True.
use_power (bool, optional): [description]. Defaults to True.
vtln_high (float, optional): [description]. Defaults to -500.0.
vtln_low (float, optional): [description]. Defaults to 100.0.
vtln_warp (float, optional): [description]. Defaults to 1.0.
window_type (str, optional): [description]. Defaults to POVEY.
Returns:
Tensor: [description]
"""
dtype = waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length,
round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return paddle.empty([0], dtype=dtype)
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type,
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1)
spectrum = paddle.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.)
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = _get_mel_banks(num_mel_bins, padded_window_size,
sample_frequency, low_freq, high_freq,
vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.astype(dtype)
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = paddle.nn.functional.pad(
mel_energies.unsqueeze(0), (0, 1),
data_format='NCL',
mode='constant',
value=0).squeeze(0)
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = paddle.mm(spectrum, mel_energies.T)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
if htk_compat:
mel_energies = paddle.concat(
(mel_energies, signal_log_energy), axis=1)
else:
mel_energies = paddle.concat(
(signal_log_energy, mel_energies), axis=1)
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
return mel_energies
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = create_dct(num_mel_bins, num_mel_bins, 'ortho')
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
# expects a left multiply e.g. dct_matrix * vector).
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
dct_matrix = dct_matrix[:, :num_ceps]
return dct_matrix
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = paddle.arange(num_ceps)
return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
cepstral_lifter)
def mfcc(waveform: Tensor,
blackman_coeff: float=0.42,
cepstral_lifter: float=22.0,
channel: int=-1,
dither: float=0.0,
energy_floor: float=1.0,
frame_length: float=25.0,
frame_shift: float=10.0,
high_freq: float=0.0,
htk_compat: bool=False,
low_freq: float=20.0,
num_ceps: int=13,
min_duration: float=0.0,
num_mel_bins: int=23,
preemphasis_coefficient: float=0.97,
raw_energy: bool=True,
remove_dc_offset: bool=True,
round_to_power_of_two: bool=True,
sample_frequency: float=16000.0,
snip_edges: bool=True,
subtract_mean: bool=False,
use_energy: bool=False,
vtln_high: float=-500.0,
vtln_low: float=100.0,
vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor:
"""[summary]
Args:
waveform (Tensor): [description]
blackman_coeff (float, optional): [description]. Defaults to 0.42.
cepstral_lifter (float, optional): [description]. Defaults to 22.0.
channel (int, optional): [description]. Defaults to -1.
dither (float, optional): [description]. Defaults to 0.0.
energy_floor (float, optional): [description]. Defaults to 1.0.
frame_length (float, optional): [description]. Defaults to 25.0.
frame_shift (float, optional): [description]. Defaults to 10.0.
high_freq (float, optional): [description]. Defaults to 0.0.
htk_compat (bool, optional): [description]. Defaults to False.
low_freq (float, optional): [description]. Defaults to 20.0.
num_ceps (int, optional): [description]. Defaults to 13.
min_duration (float, optional): [description]. Defaults to 0.0.
num_mel_bins (int, optional): [description]. Defaults to 23.
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
raw_energy (bool, optional): [description]. Defaults to True.
remove_dc_offset (bool, optional): [description]. Defaults to True.
round_to_power_of_two (bool, optional): [description]. Defaults to True.
sample_frequency (float, optional): [description]. Defaults to 16000.0.
snip_edges (bool, optional): [description]. Defaults to True.
subtract_mean (bool, optional): [description]. Defaults to False.
use_energy (bool, optional): [description]. Defaults to False.
vtln_high (float, optional): [description]. Defaults to -500.0.
vtln_low (float, optional): [description]. Defaults to 100.0.
vtln_warp (float, optional): [description]. Defaults to 1.0.
window_type (str, optional): [description]. Defaults to POVEY.
Returns:
Tensor: [description]
"""
assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (
num_ceps, num_mel_bins)
dtype = waveform.dtype
# The mel_energies should not be squared (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
feature = fbank(
waveform=waveform,
blackman_coeff=blackman_coeff,
channel=channel,
dither=dither,
energy_floor=energy_floor,
frame_length=frame_length,
frame_shift=frame_shift,
high_freq=high_freq,
htk_compat=htk_compat,
low_freq=low_freq,
min_duration=min_duration,
num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient,
raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset,
round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency,
snip_edges=snip_edges,
subtract_mean=False,
use_energy=use_energy,
use_log_fbank=True,
use_power=True,
vtln_high=vtln_high,
vtln_low=vtln_low,
vtln_warp=vtln_warp,
window_type=window_type)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).astype(dtype=dtype)
# size (m, num_ceps)
feature = feature.matmul(dct_matrix)
if cepstral_lifter != 0.0:
# size (1, num_ceps)
lifter_coeffs = _get_lifter_coeffs(num_ceps,
cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs.astype(dtype=dtype)
# if use_energy then replace the last column for htk_compat == true else first column
if use_energy:
feature[:, 0] = signal_log_energy
if htk_compat:
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
feature = feature[:, 1:] # size (m, num_ceps - 1)
if not use_energy:
# scale on C0 (actually removing a scale we previously added that's
# part of one common definition of the cosine transform.)
energy *= math.sqrt(2)
feature = paddle.concat((feature, energy), axis=1)
feature = _subtract_column_mean(feature, subtract_mean)
return feature

@ -0,0 +1,728 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from librosa(https://github.com/librosa/librosa)
import warnings
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import scipy
from numpy import ndarray as array
from numpy.lib.stride_tricks import as_strided
from scipy import signal
from ..backends import depth_convert
from ..utils import ParameterError
__all__ = [
# dsp
'stft',
'mfcc',
'hz_to_mel',
'mel_to_hz',
'split_frames',
'mel_frequencies',
'power_to_db',
'compute_fbank_matrix',
'melspectrogram',
'spectrogram',
'mu_encode',
'mu_decode',
# augmentation
'depth_augment',
'spect_augment',
'random_crop1d',
'random_crop2d',
'adaptive_spect_augment',
]
def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array:
"""Pad an array to a target length along a target axis.
This differs from `np.pad` by centering the data prior to padding,
analogous to `str.center`
"""
kwargs.setdefault("mode", "constant")
n = data.shape[axis]
lpad = int((size - n) // 2)
lengths = [(0, 0)] * data.ndim
lengths[axis] = (lpad, int(size - n - lpad))
if lpad < 0:
raise ParameterError(("Target size ({size:d}) must be "
"at least input size ({n:d})"))
return np.pad(data, lengths, **kwargs)
def split_frames(x: array, frame_length: int, hop_length: int,
axis: int=-1) -> array:
"""Slice a data array into (overlapping) frames.
This function is aligned with librosa.frame
"""
if not isinstance(x, np.ndarray):
raise ParameterError(
f"Input must be of type numpy.ndarray, given type(x)={type(x)}")
if x.shape[axis] < frame_length:
raise ParameterError(f"Input is too short (n={x.shape[axis]:d})"
f" for frame_length={frame_length:d}")
if hop_length < 1:
raise ParameterError(f"Invalid hop_length: {hop_length:d}")
if axis == -1 and not x.flags["F_CONTIGUOUS"]:
warnings.warn(f"librosa.util.frame called with axis={axis} "
"on a non-contiguous input. This will result in a copy.")
x = np.asfortranarray(x)
elif axis == 0 and not x.flags["C_CONTIGUOUS"]:
warnings.warn(f"librosa.util.frame called with axis={axis} "
"on a non-contiguous input. This will result in a copy.")
x = np.ascontiguousarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * new_stride]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * new_stride] + list(strides)
else:
raise ParameterError(f"Frame axis={axis} must be either 0 or -1")
return as_strided(x, shape=shape, strides=strides)
def _check_audio(y, mono=True) -> bool:
"""Determine whether a variable contains valid audio data.
The audio y must be a np.ndarray, ether 1-channel or two channel
"""
if not isinstance(y, np.ndarray):
raise ParameterError("Audio data must be of type numpy.ndarray")
if y.ndim > 2:
raise ParameterError(
f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}")
if mono and y.ndim == 2:
raise ParameterError(
f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}")
if (mono and len(y) == 0) or (not mono and y.shape[1] < 0):
raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}")
if not np.issubdtype(y.dtype, np.floating):
raise ParameterError("Audio data must be floating-point")
if not np.isfinite(y).all():
raise ParameterError("Audio buffer is not finite everywhere")
return True
def hz_to_mel(frequencies: Union[float, List[float], array],
htk: bool=False) -> array:
"""Convert Hz to Mels
This function is aligned with librosa.
"""
freq = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + freq / 700.0)
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if freq.ndim:
# If we have array data, vectorize
log_t = freq >= min_log_hz
mels[log_t] = min_log_mel + \
np.log(freq[log_t] / min_log_hz) / logstep
elif freq >= min_log_hz:
# If we have scalar data, heck directly
mels = min_log_mel + np.log(freq / min_log_hz) / logstep
return mels
def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array:
"""Convert mel bin numbers to frequencies.
This function is aligned with librosa.
"""
mel_array = np.asanyarray(mels)
if htk:
return 700.0 * (10.0**(mel_array / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel_array
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if mel_array.ndim:
# If we have vector data, vectorize
log_t = mel_array >= min_log_mel
freqs[log_t] = min_log_hz * \
np.exp(logstep * (mel_array[log_t] - min_log_mel))
elif mel_array >= min_log_mel:
# If we have scalar data, check directly
freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel))
return freqs
def mel_frequencies(n_mels: int=128,
fmin: float=0.0,
fmax: float=11025.0,
htk: bool=False) -> array:
"""Compute mel frequencies
This function is aligned with librosa.
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(fmin, htk=htk)
max_mel = hz_to_mel(fmax, htk=htk)
mels = np.linspace(min_mel, max_mel, n_mels)
return mel_to_hz(mels, htk=htk)
def fft_frequencies(sr: int, n_fft: int) -> array:
"""Compute fourier frequencies.
This function is aligned with librosa.
"""
return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)
def compute_fbank_matrix(sr: int,
n_fft: int,
n_mels: int=128,
fmin: float=0.0,
fmax: Optional[float]=None,
htk: bool=False,
norm: str="slaney",
dtype: type=np.float32):
"""Compute fbank matrix.
This funciton is aligned with librosa.
"""
if norm != "slaney":
raise ParameterError('norm must set to slaney')
if fmax is None:
fmax = float(sr) / 2
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
if norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
# Only check weights if f_mel[0] is positive
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
# This means we have an empty channel somewhere
warnings.warn("Empty filters detected in mel frequency basis. "
"Some channels will produce empty responses. "
"Try increasing your sampling rate (and fmax) or "
"reducing n_mels.")
return weights
def stft(x: array,
n_fft: int=2048,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str="hann",
center: bool=True,
dtype: type=np.complex64,
pad_mode: str="reflect") -> array:
"""Short-time Fourier transform (STFT).
This function is aligned with librosa.
"""
_check_audio(x)
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
fft_window = signal.get_window(window, win_length, fftbins=True)
# Pad the window out to n_fft size
fft_window = pad_center(fft_window, n_fft)
# Reshape so that the window can be broadcast
fft_window = fft_window.reshape((-1, 1))
# Pad the time series so that frames are centered
if center:
if n_fft > x.shape[-1]:
warnings.warn(
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
)
x = np.pad(x, int(n_fft // 2), mode=pad_mode)
elif n_fft > x.shape[-1]:
raise ParameterError(
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
)
# Window the time series.
x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length)
# Pre-allocate the STFT matrix
stft_matrix = np.empty(
(int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F")
fft = np.fft # use numpy fft as default
# Constrain STFT block sizes to 256 KB
MAX_MEM_BLOCK = 2**8 * 2**10
# how many columns can we fit within MAX_MEM_BLOCK?
n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
n_columns = max(n_columns, 1)
for bl_s in range(0, stft_matrix.shape[1], n_columns):
bl_t = min(bl_s + n_columns, stft_matrix.shape[1])
stft_matrix[:, bl_s:bl_t] = fft.rfft(
fft_window * x_frames[:, bl_s:bl_t], axis=0)
return stft_matrix
def power_to_db(spect: array,
ref: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=80.0) -> array:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units
This computes the scaling ``10 * log10(spect / ref)`` in a numerically
stable way.
This function is aligned with librosa.
"""
spect = np.asarray(spect)
if amin <= 0:
raise ParameterError("amin must be strictly positive")
if np.issubdtype(spect.dtype, np.complexfloating):
warnings.warn(
"power_to_db was called on complex input so phase "
"information will be discarded. To suppress this warning, "
"call power_to_db(np.abs(D)**2) instead.")
magnitude = np.abs(spect)
else:
magnitude = spect
if callable(ref):
# User supplied a function to calculate reference power
ref_value = ref(magnitude)
else:
ref_value = np.abs(ref)
log_spec = 10.0 * np.log10(np.maximum(amin, magnitude))
log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
if top_db is not None:
if top_db < 0:
raise ParameterError("top_db must be non-negative")
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
return log_spec
def mfcc(x,
sr: int=16000,
spect: Optional[array]=None,
n_mfcc: int=20,
dct_type: int=2,
norm: str="ortho",
lifter: int=0,
**kwargs) -> array:
"""Mel-frequency cepstral coefficients (MFCCs)
This function is NOT strictly aligned with librosa. The following example shows how to get the
same result with librosa:
# mfcc:
kwargs = {
'window_size':512,
'hop_length':320,
'mel_bins':64,
'fmin':50,
'to_db':False}
a = mfcc(x,
spect=None,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0,
**kwargs)
# librosa mfcc:
spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512,
win_length=512,
hop_length=320,
n_mels=64, fmin=50)
b = librosa.feature.mfcc(y=x,
sr=16000,
S=spect,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0)
assert np.mean( (a-b)**2) < 1e-8
"""
if spect is None:
spect = melspectrogram(x, sr=sr, **kwargs)
M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc]
if lifter > 0:
factor = np.sin(np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) /
lifter)
return M * factor[:, np.newaxis]
elif lifter == 0:
return M
else:
raise ParameterError(
f"MFCC lifter={lifter} must be a non-negative number")
def melspectrogram(x: array,
sr: int=16000,
window_size: int=512,
hop_length: int=320,
n_mels: int=64,
fmin: int=50,
fmax: Optional[float]=None,
window: str='hann',
center: bool=True,
pad_mode: str='reflect',
power: float=2.0,
to_db: bool=True,
ref: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None) -> array:
"""Compute mel-spectrogram.
Parameters:
x: numpy.ndarray
The input wavform is a numpy array [shape=(n,)]
window_size: int, typically 512, 1024, 2048, etc.
The window size for framing, also used as n_fft for stft
Returns:
The mel-spectrogram in power scale or db scale(default)
Notes:
1. sr is default to 16000, which is commonly used in speech/speaker processing.
2. when fmax is None, it is set to sr//2.
3. this function will convert mel spectgrum to db scale by default. This is different
that of librosa.
"""
_check_audio(x, mono=True)
if len(x) <= 0:
raise ParameterError('The input waveform is empty')
if fmax is None:
fmax = sr // 2
if fmin < 0 or fmin >= fmax:
raise ParameterError('fmin and fmax must statisfy 0<fmin<fmax')
s = stft(
x,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
spect_power = np.abs(s)**power
fb_matrix = compute_fbank_matrix(
sr=sr, n_fft=window_size, n_mels=n_mels, fmin=fmin, fmax=fmax)
mel_spect = np.matmul(fb_matrix, spect_power)
if to_db:
return power_to_db(mel_spect, ref=ref, amin=amin, top_db=top_db)
else:
return mel_spect
def spectrogram(x: array,
sr: int=16000,
window_size: int=512,
hop_length: int=320,
window: str='hann',
center: bool=True,
pad_mode: str='reflect',
power: float=2.0) -> array:
"""Compute spectrogram from an input waveform.
This function is a wrapper for librosa.feature.stft, with addition step to
compute the magnitude of the complex spectrogram.
"""
s = stft(
x,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
return np.abs(s)**power
def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array:
"""Mu-law encoding.
Compute the mu-law decoding given an input code.
When quantized is True, the result will be converted to
integer in range [0,mu-1]. Otherwise, the resulting signal
is in range [-1,1]
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
"""
mu = 255
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
if quantized:
y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1]
return y
def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array:
"""Mu-law decoding.
Compute the mu-law decoding given an input code.
it assumes that the input y is in
range [0,mu-1] when quantize is True and [-1,1] otherwise
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
"""
if mu < 1:
raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...')
mu = mu - 1
if quantized: # undo the quantization
y = y * 2 / mu - 1
x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)
return x
def randint(high: int) -> int:
"""Generate one random integer in range [0 high)
This is a helper function for random data augmentaiton
"""
return int(np.random.randint(0, high=high))
def rand() -> float:
"""Generate one floating-point number in range [0 1)
This is a helper function for random data augmentaiton
"""
return float(np.random.rand(1))
def depth_augment(y: array,
choices: List=['int8', 'int16'],
probs: List[float]=[0.5, 0.5]) -> array:
""" Audio depth augmentation
Do audio depth augmentation to simulate the distortion brought by quantization.
"""
assert len(probs) == len(
choices
), 'number of choices {} must be equal to size of probs {}'.format(
len(choices), len(probs))
depth = np.random.choice(choices, p=probs)
src_depth = y.dtype
y1 = depth_convert(y, depth)
y2 = depth_convert(y1, src_depth)
return y2
def adaptive_spect_augment(spect: array, tempo_axis: int=0,
level: float=0.1) -> array:
"""Do adpative spectrogram augmentation
The level of the augmentation is gowern by the paramter level,
ranging from 0 to 1, with 0 represents no augmentation
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
else:
nf, nt = spect.shape
time_mask_width = int(nt * level * 0.5)
freq_mask_width = int(nf * level * 0.5)
num_time_mask = int(10 * level)
num_freq_mask = int(10 * level)
if tempo_axis == 0:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def spect_augment(spect: array,
tempo_axis: int=0,
max_time_mask: int=3,
max_freq_mask: int=3,
max_time_mask_width: int=30,
max_freq_mask_width: int=20) -> array:
"""Do spectrogram augmentation in both time and freq axis
Reference:
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
else:
nf, nt = spect.shape
num_time_mask = randint(max_time_mask)
num_freq_mask = randint(max_freq_mask)
time_mask_width = randint(max_time_mask_width)
freq_mask_width = randint(max_freq_mask_width)
if tempo_axis == 0:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def random_crop1d(y: array, crop_len: int) -> array:
""" Do random cropping on 1d input signal
The input is a 1d signal, typically a sound waveform
"""
if y.ndim != 1:
'only accept 1d tensor or numpy array'
n = len(y)
idx = randint(n - crop_len)
return y[idx:idx + crop_len]
def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array:
""" Do random cropping for 2D array, typically a spectrogram.
The cropping is done in temporal direction on the time-freq input signal.
"""
if tempo_axis >= s.ndim:
raise ParameterError('axis out of range')
n = s.shape[tempo_axis]
idx = randint(high=n - crop_len)
sli = [slice(None) for i in range(s.ndim)]
sli[tempo_axis] = slice(idx, idx + crop_len)
out = s[tuple(sli)]
return out

@ -15,10 +15,3 @@ from .esc50 import ESC50
from .gtzan import GTZAN
from .tess import TESS
from .urban_sound import UrbanSound8K
__all__ = [
'ESC50',
'UrbanSound8K',
'GTZAN',
'TESS',
]

@ -17,8 +17,8 @@ import numpy as np
import paddle
from ..backends import load as load_audio
from ..features import melspectrogram
from ..features import mfcc
from ..compliance.librosa import melspectrogram
from ..compliance.librosa import mfcc
feat_funcs = {
'raw': None,

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .librosa import LogMelSpectrogram
from .librosa import MelSpectrogram
from .librosa import Spectrogram
from .layers import LogMelSpectrogram
from .layers import MelSpectrogram
from .layers import MFCC
from .layers import Spectrogram

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial
from typing import Optional
from typing import Union
@ -19,225 +18,19 @@ from typing import Union
import paddle
import paddle.nn as nn
from ..functional import compute_fbank_matrix
from ..functional import create_dct
from ..functional import power_to_db
from ..functional.window import get_window
__all__ = [
'Spectrogram',
'MelSpectrogram',
'LogMelSpectrogram',
'MFCC',
]
def hz_to_mel(freq: Union[paddle.Tensor, float],
htk: bool=False) -> Union[paddle.Tensor, float]:
"""Convert Hz to Mels.
Parameters:
freq: the input tensor of arbitrary shape, or a single floating point number.
htk: use HTK formula to do the conversion.
The default value is False.
Returns:
The frequencies represented in Mel-scale.
"""
if htk:
if isinstance(freq, paddle.Tensor):
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
else:
return 2595.0 * math.log10(1.0 + freq / 700.0)
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(freq, paddle.Tensor):
target = min_log_mel + paddle.log(
freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
mask = (freq > min_log_hz).astype(freq.dtype)
mels = target * mask + mels * (
1 - mask) # will replace by masked_fill OP in future
else:
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
return mels
def mel_to_hz(mel: Union[float, paddle.Tensor],
htk: bool=False) -> Union[float, paddle.Tensor]:
"""Convert mel bin numbers to frequencies.
Parameters:
mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number.
htk: use HTK formula to do the conversion.
Returns:
The frequencies represented in hz.
"""
if htk:
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(mel, paddle.Tensor):
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
mask = (mel > min_log_mel).astype(mel.dtype)
freqs = target * mask + freqs * (
1 - mask) # will replace by masked_fill OP in future
else:
if mel >= min_log_mel:
freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
return freqs
def mel_frequencies(n_mels: int=64,
f_min: float=0.0,
f_max: float=11025.0,
htk: bool=False,
dtype: str=paddle.float32):
"""Compute mel frequencies.
Parameters:
n_mels(int): number of Mel bins.
f_min(float): the lower cut-off frequency, below which the filter response is zero.
f_max(float): the upper cut-off frequency, above which the filter response is zero.
htk(bool): whether to use htk formula.
dtype(str): the datatype of the return frequencies.
Returns:
The frequencies represented in Mel-scale
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(f_min, htk=htk)
max_mel = hz_to_mel(f_max, htk=htk)
mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
freqs = mel_to_hz(mels, htk=htk)
return freqs
def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32):
"""Compute fourier frequencies.
Parameters:
sr(int): the audio sample rate.
n_fft(float): the number of fft bins.
dtype(str): the datatype of the return frequencies.
Returns:
The frequencies represented in hz.
"""
return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
def compute_fbank_matrix(sr: int,
n_fft: int,
n_mels: int=64,
f_min: float=0.0,
f_max: Optional[float]=None,
htk: bool=False,
norm: Union[str, float]='slaney',
dtype: str=paddle.float32):
"""Compute fbank matrix.
Parameters:
sr(int): the audio sample rate.
n_fft(int): the number of fft bins.
n_mels(int): the number of Mel bins.
f_min(float): the lower cut-off frequency, below which the filter response is zero.
f_max(float): the upper cut-off frequency, above which the filter response is zero.
htk: whether to use htk formula.
return_complex(bool): whether to return complex matrix. If True, the matrix will
be complex type. Otherwise, the real and image part will be stored in the last
axis of returned tensor.
dtype(str): the datatype of the returned fbank matrix.
Returns:
The fbank matrix of shape (n_mels, int(1+n_fft//2)).
Shape:
output: (n_mels, int(1+n_fft//2))
"""
if f_max is None:
f_max = float(sr) / 2
# Initialize the weights
weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(
n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
#ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = paddle.maximum(
paddle.zeros_like(lower), paddle.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
if norm == 'slaney':
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
weights *= enorm.unsqueeze(1)
elif isinstance(norm, int) or isinstance(norm, float):
weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
return weights
def power_to_db(magnitude: paddle.Tensor,
ref_value: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None) -> paddle.Tensor:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units.
The function computes the scaling ``10 * log10(x / ref)`` in a numerically
stable way.
Parameters:
magnitude(Tensor): the input magnitude tensor of any shape.
ref_value(float): the reference value. If smaller than 1.0, the db level
of the signal will be pulled up accordingly. Otherwise, the db level
is pushed down.
amin(float): the minimum value of input magnitude, below which the input
magnitude is clipped(to amin).
top_db(float): the maximum db value of resulting spectrum, above which the
spectrum is clipped(to top_db).
Returns:
The spectrogram in log-scale.
shape:
input: any shape
output: same as input
"""
if amin <= 0:
raise Exception("amin must be strictly positive")
if ref_value <= 0:
raise Exception("ref_value must be strictly positive")
ones = paddle.ones_like(magnitude)
log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
log_spec -= 10.0 * math.log10(max(ref_value, amin))
if top_db is not None:
if top_db < 0:
raise Exception("top_db must be non-negative")
log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
return log_spec
class Spectrogram(nn.Layer):
def __init__(self,
n_fft: int=512,
@ -459,3 +252,29 @@ class LogMelSpectrogram(nn.Layer):
amin=self.amin,
top_db=self.top_db)
return log_mel_feature
class MFCC(nn.Layer):
def __init__(self,
sr: int=22050,
n_mfcc: int=40,
norm: str='ortho',
**kwargs):
"""[summary]
Parameters:
sr (int, optional): [description]. Defaults to 22050.
n_mfcc (int, optional): [description]. Defaults to 40.
norm (str, optional): [description]. Defaults to 'ortho'.
"""
super(MFCC, self).__init__()
self._log_melspectrogram = LogMelSpectrogram(sr=sr, **kwargs)
self.dct_matrix = create_dct(
n_mfcc=n_mfcc, n_mels=self._log_melspectrogram.n_mels, norm=norm)
self.register_buffer('dct_matrix', self.dct_matrix)
def forward(self, x):
log_mel_feature = self._log_melspectrogram(x)
mfcc = paddle.matmul(
log_mel_feature.transpose((0, 2, 1)), self.dct_matrix).transpose(
(0, 2, 1)) # (B, n_mels, L)
return mfcc

@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .functional import compute_fbank_matrix
from .functional import create_dct
from .functional import fft_frequencies
from .functional import hz_to_mel
from .functional import mel_frequencies
from .functional import mel_to_hz
from .functional import power_to_db

@ -12,146 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from librosa(https://github.com/librosa/librosa)
import warnings
from typing import List
import math
from typing import Optional
from typing import Union
import numpy as np
import scipy
from numpy import ndarray as array
from numpy.lib.stride_tricks import as_strided
from scipy import signal
from ..backends import depth_convert
from ..utils import ParameterError
import paddle
__all__ = [
# dsp
'stft',
'mfcc',
'hz_to_mel',
'mel_to_hz',
'split_frames',
'mel_frequencies',
'power_to_db',
'fft_frequencies',
'compute_fbank_matrix',
'melspectrogram',
'spectrogram',
'mu_encode',
'mu_decode',
# augmentation
'depth_augment',
'spect_augment',
'random_crop1d',
'random_crop2d',
'adaptive_spect_augment',
'power_to_db',
'create_dct',
]
def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array:
"""Pad an array to a target length along a target axis.
This differs from `np.pad` by centering the data prior to padding,
analogous to `str.center`
"""
kwargs.setdefault("mode", "constant")
n = data.shape[axis]
lpad = int((size - n) // 2)
lengths = [(0, 0)] * data.ndim
lengths[axis] = (lpad, int(size - n - lpad))
if lpad < 0:
raise ParameterError(("Target size ({size:d}) must be "
"at least input size ({n:d})"))
return np.pad(data, lengths, **kwargs)
def split_frames(x: array, frame_length: int, hop_length: int,
axis: int=-1) -> array:
"""Slice a data array into (overlapping) frames.
This function is aligned with librosa.frame
"""
if not isinstance(x, np.ndarray):
raise ParameterError(
f"Input must be of type numpy.ndarray, given type(x)={type(x)}")
if x.shape[axis] < frame_length:
raise ParameterError(f"Input is too short (n={x.shape[axis]:d})"
f" for frame_length={frame_length:d}")
if hop_length < 1:
raise ParameterError(f"Invalid hop_length: {hop_length:d}")
if axis == -1 and not x.flags["F_CONTIGUOUS"]:
warnings.warn(f"librosa.util.frame called with axis={axis} "
"on a non-contiguous input. This will result in a copy.")
x = np.asfortranarray(x)
elif axis == 0 and not x.flags["C_CONTIGUOUS"]:
warnings.warn(f"librosa.util.frame called with axis={axis} "
"on a non-contiguous input. This will result in a copy.")
x = np.ascontiguousarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * new_stride]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * new_stride] + list(strides)
else:
raise ParameterError(f"Frame axis={axis} must be either 0 or -1")
return as_strided(x, shape=shape, strides=strides)
def _check_audio(y, mono=True) -> bool:
"""Determine whether a variable contains valid audio data.
The audio y must be a np.ndarray, ether 1-channel or two channel
"""
if not isinstance(y, np.ndarray):
raise ParameterError("Audio data must be of type numpy.ndarray")
if y.ndim > 2:
raise ParameterError(
f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}")
if mono and y.ndim == 2:
raise ParameterError(
f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}")
if (mono and len(y) == 0) or (not mono and y.shape[1] < 0):
raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}")
if not np.issubdtype(y.dtype, np.floating):
raise ParameterError("Audio data must be floating-point")
if not np.isfinite(y).all():
raise ParameterError("Audio buffer is not finite everywhere")
return True
def hz_to_mel(frequencies: Union[float, List[float], array],
htk: bool=False) -> array:
"""Convert Hz to Mels
This function is aligned with librosa.
def hz_to_mel(freq: Union[paddle.Tensor, float],
htk: bool=False) -> Union[paddle.Tensor, float]:
"""Convert Hz to Mels.
Parameters:
freq: the input tensor of arbitrary shape, or a single floating point number.
htk: use HTK formula to do the conversion.
The default value is False.
Returns:
The frequencies represented in Mel-scale.
"""
freq = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + freq / 700.0)
if isinstance(freq, paddle.Tensor):
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
else:
return 2595.0 * math.log10(1.0 + freq / 700.0)
# Fill in the linear part
f_min = 0.0
@ -163,107 +56,129 @@ def hz_to_mel(frequencies: Union[float, List[float], array],
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if freq.ndim:
# If we have array data, vectorize
log_t = freq >= min_log_hz
mels[log_t] = min_log_mel + \
np.log(freq[log_t] / min_log_hz) / logstep
elif freq >= min_log_hz:
# If we have scalar data, heck directly
mels = min_log_mel + np.log(freq / min_log_hz) / logstep
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(freq, paddle.Tensor):
target = min_log_mel + paddle.log(
freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
mask = (freq > min_log_hz).astype(freq.dtype)
mels = target * mask + mels * (
1 - mask) # will replace by masked_fill OP in future
else:
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
return mels
def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array:
def mel_to_hz(mel: Union[float, paddle.Tensor],
htk: bool=False) -> Union[float, paddle.Tensor]:
"""Convert mel bin numbers to frequencies.
This function is aligned with librosa.
Parameters:
mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number.
htk: use HTK formula to do the conversion.
Returns:
The frequencies represented in hz.
"""
mel_array = np.asanyarray(mels)
if htk:
return 700.0 * (10.0**(mel_array / 2595.0) - 1.0)
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel_array
freqs = f_min + f_sp * mel
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if mel_array.ndim:
# If we have vector data, vectorize
log_t = mel_array >= min_log_mel
freqs[log_t] = min_log_hz * \
np.exp(logstep * (mel_array[log_t] - min_log_mel))
elif mel_array >= min_log_mel:
# If we have scalar data, check directly
freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel))
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(mel, paddle.Tensor):
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
mask = (mel > min_log_mel).astype(mel.dtype)
freqs = target * mask + freqs * (
1 - mask) # will replace by masked_fill OP in future
else:
if mel >= min_log_mel:
freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
return freqs
def mel_frequencies(n_mels: int=128,
fmin: float=0.0,
fmax: float=11025.0,
htk: bool=False) -> array:
"""Compute mel frequencies
This function is aligned with librosa.
def mel_frequencies(n_mels: int=64,
f_min: float=0.0,
f_max: float=11025.0,
htk: bool=False,
dtype: str=paddle.float32):
"""Compute mel frequencies.
Parameters:
n_mels(int): number of Mel bins.
f_min(float): the lower cut-off frequency, below which the filter response is zero.
f_max(float): the upper cut-off frequency, above which the filter response is zero.
htk(bool): whether to use htk formula.
dtype(str): the datatype of the return frequencies.
Returns:
The frequencies represented in Mel-scale
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(fmin, htk=htk)
max_mel = hz_to_mel(fmax, htk=htk)
mels = np.linspace(min_mel, max_mel, n_mels)
return mel_to_hz(mels, htk=htk)
min_mel = hz_to_mel(f_min, htk=htk)
max_mel = hz_to_mel(f_max, htk=htk)
mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
freqs = mel_to_hz(mels, htk=htk)
return freqs
def fft_frequencies(sr: int, n_fft: int) -> array:
def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32):
"""Compute fourier frequencies.
This function is aligned with librosa.
Parameters:
sr(int): the audio sample rate.
n_fft(float): the number of fft bins.
dtype(str): the datatype of the return frequencies.
Returns:
The frequencies represented in hz.
"""
return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)
return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
def compute_fbank_matrix(sr: int,
n_fft: int,
n_mels: int=128,
fmin: float=0.0,
fmax: Optional[float]=None,
n_mels: int=64,
f_min: float=0.0,
f_max: Optional[float]=None,
htk: bool=False,
norm: str="slaney",
dtype: type=np.float32):
norm: Union[str, float]='slaney',
dtype: str=paddle.float32):
"""Compute fbank matrix.
This funciton is aligned with librosa.
Parameters:
sr(int): the audio sample rate.
n_fft(int): the number of fft bins.
n_mels(int): the number of Mel bins.
f_min(float): the lower cut-off frequency, below which the filter response is zero.
f_max(float): the upper cut-off frequency, above which the filter response is zero.
htk: whether to use htk formula.
return_complex(bool): whether to return complex matrix. If True, the matrix will
be complex type. Otherwise, the real and image part will be stored in the last
axis of returned tensor.
dtype(str): the datatype of the returned fbank matrix.
Returns:
The fbank matrix of shape (n_mels, int(1+n_fft//2)).
Shape:
output: (n_mels, int(1+n_fft//2))
"""
if norm != "slaney":
raise ParameterError('norm must set to slaney')
if fmax is None:
fmax = float(sr) / 2
if f_max is None:
f_max = float(sr) / 2
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
mel_f = mel_frequencies(
n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
#ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
@ -271,458 +186,79 @@ def compute_fbank_matrix(sr: int,
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
weights[i] = paddle.maximum(
paddle.zeros_like(lower), paddle.minimum(lower, upper))
if norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
# Slaney-style mel is scaled to be approx constant energy per channel
if norm == 'slaney':
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
# Only check weights if f_mel[0] is positive
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
# This means we have an empty channel somewhere
warnings.warn("Empty filters detected in mel frequency basis. "
"Some channels will produce empty responses. "
"Try increasing your sampling rate (and fmax) or "
"reducing n_mels.")
weights *= enorm.unsqueeze(1)
elif isinstance(norm, int) or isinstance(norm, float):
weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
return weights
def stft(x: array,
n_fft: int=2048,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str="hann",
center: bool=True,
dtype: type=np.complex64,
pad_mode: str="reflect") -> array:
"""Short-time Fourier transform (STFT).
This function is aligned with librosa.
"""
_check_audio(x)
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
fft_window = signal.get_window(window, win_length, fftbins=True)
# Pad the window out to n_fft size
fft_window = pad_center(fft_window, n_fft)
# Reshape so that the window can be broadcast
fft_window = fft_window.reshape((-1, 1))
# Pad the time series so that frames are centered
if center:
if n_fft > x.shape[-1]:
warnings.warn(
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
)
x = np.pad(x, int(n_fft // 2), mode=pad_mode)
elif n_fft > x.shape[-1]:
raise ParameterError(
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
)
# Window the time series.
x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length)
# Pre-allocate the STFT matrix
stft_matrix = np.empty(
(int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F")
fft = np.fft # use numpy fft as default
# Constrain STFT block sizes to 256 KB
MAX_MEM_BLOCK = 2**8 * 2**10
# how many columns can we fit within MAX_MEM_BLOCK?
n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
n_columns = max(n_columns, 1)
for bl_s in range(0, stft_matrix.shape[1], n_columns):
bl_t = min(bl_s + n_columns, stft_matrix.shape[1])
stft_matrix[:, bl_s:bl_t] = fft.rfft(
fft_window * x_frames[:, bl_s:bl_t], axis=0)
return stft_matrix
def power_to_db(spect: array,
ref: float=1.0,
def power_to_db(magnitude: paddle.Tensor,
ref_value: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=80.0) -> array:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units
This computes the scaling ``10 * log10(spect / ref)`` in a numerically
top_db: Optional[float]=None) -> paddle.Tensor:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units.
The function computes the scaling ``10 * log10(x / ref)`` in a numerically
stable way.
This function is aligned with librosa.
Parameters:
magnitude(Tensor): the input magnitude tensor of any shape.
ref_value(float): the reference value. If smaller than 1.0, the db level
of the signal will be pulled up accordingly. Otherwise, the db level
is pushed down.
amin(float): the minimum value of input magnitude, below which the input
magnitude is clipped(to amin).
top_db(float): the maximum db value of resulting spectrum, above which the
spectrum is clipped(to top_db).
Returns:
The spectrogram in log-scale.
shape:
input: any shape
output: same as input
"""
spect = np.asarray(spect)
if amin <= 0:
raise ParameterError("amin must be strictly positive")
if np.issubdtype(spect.dtype, np.complexfloating):
warnings.warn(
"power_to_db was called on complex input so phase "
"information will be discarded. To suppress this warning, "
"call power_to_db(np.abs(D)**2) instead.")
magnitude = np.abs(spect)
else:
magnitude = spect
raise Exception("amin must be strictly positive")
if callable(ref):
# User supplied a function to calculate reference power
ref_value = ref(magnitude)
else:
ref_value = np.abs(ref)
if ref_value <= 0:
raise Exception("ref_value must be strictly positive")
log_spec = 10.0 * np.log10(np.maximum(amin, magnitude))
log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
ones = paddle.ones_like(magnitude)
log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
log_spec -= 10.0 * math.log10(max(ref_value, amin))
if top_db is not None:
if top_db < 0:
raise ParameterError("top_db must be non-negative")
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
raise Exception("top_db must be non-negative")
log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
return log_spec
def mfcc(x,
sr: int=16000,
spect: Optional[array]=None,
n_mfcc: int=20,
dct_type: int=2,
norm: str="ortho",
lifter: int=0,
**kwargs) -> array:
"""Mel-frequency cepstral coefficients (MFCCs)
This function is NOT strictly aligned with librosa. The following example shows how to get the
same result with librosa:
# mfcc:
kwargs = {
'window_size':512,
'hop_length':320,
'mel_bins':64,
'fmin':50,
'to_db':False}
a = mfcc(x,
spect=None,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0,
**kwargs)
# librosa mfcc:
spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512,
win_length=512,
hop_length=320,
n_mels=64, fmin=50)
b = librosa.feature.mfcc(y=x,
sr=16000,
S=spect,
n_mfcc=20,
dct_type=2,
norm='ortho',
lifter=0)
assert np.mean( (a-b)**2) < 1e-8
"""
if spect is None:
spect = melspectrogram(x, sr=sr, **kwargs)
M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc]
if lifter > 0:
factor = np.sin(np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) /
lifter)
return M * factor[:, np.newaxis]
elif lifter == 0:
return M
else:
raise ParameterError(
f"MFCC lifter={lifter} must be a non-negative number")
def melspectrogram(x: array,
sr: int=16000,
window_size: int=512,
hop_length: int=320,
n_mels: int=64,
fmin: int=50,
fmax: Optional[float]=None,
window: str='hann',
center: bool=True,
pad_mode: str='reflect',
power: float=2.0,
to_db: bool=True,
ref: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None) -> array:
"""Compute mel-spectrogram.
def create_dct(n_mfcc: int,
n_mels: int,
norm: Optional[str]='ortho',
dtype: Optional[str]=paddle.float32):
"""[summary]
Parameters:
x: numpy.ndarray
The input wavform is a numpy array [shape=(n,)]
window_size: int, typically 512, 1024, 2048, etc.
The window size for framing, also used as n_fft for stft
n_mfcc (int): [description]
n_mels (int): [description]
norm (str, optional): [description]. Defaults to 'ortho'.
Returns:
The mel-spectrogram in power scale or db scale(default)
Notes:
1. sr is default to 16000, which is commonly used in speech/speaker processing.
2. when fmax is None, it is set to sr//2.
3. this function will convert mel spectgrum to db scale by default. This is different
that of librosa.
"""
_check_audio(x, mono=True)
if len(x) <= 0:
raise ParameterError('The input waveform is empty')
if fmax is None:
fmax = sr // 2
if fmin < 0 or fmin >= fmax:
raise ParameterError('fmin and fmax must statisfy 0<fmin<fmax')
s = stft(
x,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
spect_power = np.abs(s)**power
fb_matrix = compute_fbank_matrix(
sr=sr, n_fft=window_size, n_mels=n_mels, fmin=fmin, fmax=fmax)
mel_spect = np.matmul(fb_matrix, spect_power)
if to_db:
return power_to_db(mel_spect, ref=ref, amin=amin, top_db=top_db)
else:
return mel_spect
def spectrogram(x: array,
sr: int=16000,
window_size: int=512,
hop_length: int=320,
window: str='hann',
center: bool=True,
pad_mode: str='reflect',
power: float=2.0) -> array:
"""Compute spectrogram from an input waveform.
This function is a wrapper for librosa.feature.stft, with addition step to
compute the magnitude of the complex spectrogram.
"""
s = stft(
x,
n_fft=window_size,
hop_length=hop_length,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
return np.abs(s)**power
def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array:
"""Mu-law encoding.
Compute the mu-law decoding given an input code.
When quantized is True, the result will be converted to
integer in range [0,mu-1]. Otherwise, the resulting signal
is in range [-1,1]
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
"""
mu = 255
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
if quantized:
y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1]
return y
def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array:
"""Mu-law decoding.
Compute the mu-law decoding given an input code.
it assumes that the input y is in
range [0,mu-1] when quantize is True and [-1,1] otherwise
Reference:
https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
"""
if mu < 1:
raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...')
mu = mu - 1
if quantized: # undo the quantization
y = y * 2 / mu - 1
x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)
return x
def randint(high: int) -> int:
"""Generate one random integer in range [0 high)
This is a helper function for random data augmentaiton
"""
return int(np.random.randint(0, high=high))
def rand() -> float:
"""Generate one floating-point number in range [0 1)
This is a helper function for random data augmentaiton
"""
return float(np.random.rand(1))
def depth_augment(y: array,
choices: List=['int8', 'int16'],
probs: List[float]=[0.5, 0.5]) -> array:
""" Audio depth augmentation
Do audio depth augmentation to simulate the distortion brought by quantization.
"""
assert len(probs) == len(
choices
), 'number of choices {} must be equal to size of probs {}'.format(
len(choices), len(probs))
depth = np.random.choice(choices, p=probs)
src_depth = y.dtype
y1 = depth_convert(y, depth)
y2 = depth_convert(y1, src_depth)
return y2
def adaptive_spect_augment(spect: array, tempo_axis: int=0,
level: float=0.1) -> array:
"""Do adpative spectrogram augmentation
The level of the augmentation is gowern by the paramter level,
ranging from 0 to 1, with 0 represents no augmentation
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
else:
nf, nt = spect.shape
time_mask_width = int(nt * level * 0.5)
freq_mask_width = int(nf * level * 0.5)
num_time_mask = int(10 * level)
num_freq_mask = int(10 * level)
if tempo_axis == 0:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
else:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def spect_augment(spect: array,
tempo_axis: int=0,
max_time_mask: int=3,
max_freq_mask: int=3,
max_time_mask_width: int=30,
max_freq_mask_width: int=20) -> array:
"""Do spectrogram augmentation in both time and freq axis
Reference:
[type]: [description]
"""
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
if tempo_axis == 0:
nt, nf = spect.shape
else:
nf, nt = spect.shape
num_time_mask = randint(max_time_mask)
num_freq_mask = randint(max_freq_mask)
time_mask_width = randint(max_time_mask_width)
freq_mask_width = randint(max_freq_mask_width)
if tempo_axis == 0:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[start:start + time_mask_width, :] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[:, start:start + freq_mask_width] = 0
n = paddle.arange(n_mels, dtype=dtype)
k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1)
dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
k) # size (n_mfcc, n_mels)
if norm is None:
dct *= 2.0
else:
for _ in range(num_time_mask):
start = randint(nt - time_mask_width)
spect[:, start:start + time_mask_width] = 0
for _ in range(num_freq_mask):
start = randint(nf - freq_mask_width)
spect[start:start + freq_mask_width, :] = 0
return spect
def random_crop1d(y: array, crop_len: int) -> array:
""" Do random cropping on 1d input signal
The input is a 1d signal, typically a sound waveform
"""
if y.ndim != 1:
'only accept 1d tensor or numpy array'
n = len(y)
idx = randint(n - crop_len)
return y[idx:idx + crop_len]
def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array:
""" Do random cropping for 2D array, typically a spectrogram.
The cropping is done in temporal direction on the time-freq input signal.
"""
if tempo_axis >= s.ndim:
raise ParameterError('axis out of range')
n = s.shape[tempo_axis]
idx = randint(high=n - crop_len)
sli = [slice(None) for i in range(s.ndim)]
sli[tempo_axis] = slice(idx, idx + crop_len)
out = s[tuple(sli)]
return out
assert norm == "ortho"
dct[0] *= 1.0 / math.sqrt(2.0)
dct *= math.sqrt(2.0 / float(n_mels))
return dct.T

@ -1,4 +1,4 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,9 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio import depth_convert
from .audio import load
from .audio import normalize
from .audio import resample
from .audio import save_wav
from .audio import to_mono

@ -1,303 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import resampy
import soundfile as sf
from numpy import ndarray as array
from scipy.io import wavfile
from ..utils import ParameterError
__all__ = [
'resample',
'to_mono',
'depth_convert',
'normalize',
'save_wav',
'load',
]
NORMALMIZE_TYPES = ['linear', 'gaussian']
MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
EPS = 1e-8
def resample(y: array, src_sr: int, target_sr: int,
mode: str='kaiser_fast') -> array:
""" Audio resampling
This function is the same as using resampy.resample().
Notes:
The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast'
"""
if mode == 'kaiser_best':
warnings.warn(
f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \
we recommend the mode kaiser_fast in large scale audio trainning')
if not isinstance(y, np.ndarray):
raise ParameterError(
'Only support numpy array, but received y in {type(y)}')
if mode not in RESAMPLE_MODES:
raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
return resampy.resample(y, src_sr, target_sr, filter=mode)
def to_mono(y: array, merge_type: str='average') -> array:
""" convert sterior audio to mono
"""
if merge_type not in MERGE_TYPES:
raise ParameterError(
f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
)
if y.ndim > 2:
raise ParameterError(
f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
if y.ndim == 1: # nothing to merge
return y
if merge_type == 'ch0':
return y[0]
if merge_type == 'ch1':
return y[1]
if merge_type == 'random':
return y[np.random.randint(0, 2)]
# need to do averaging according to dtype
if y.dtype == 'float32':
y_out = (y[0] + y[1]) * 0.5
elif y.dtype == 'int16':
y_out = y.astype('int32')
y_out = (y_out[0] + y_out[1]) // 2
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
np.iinfo(y.dtype).max).astype(y.dtype)
elif y.dtype == 'int8':
y_out = y.astype('int16')
y_out = (y_out[0] + y_out[1]) // 2
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
np.iinfo(y.dtype).max).astype(y.dtype)
else:
raise ParameterError(f'Unsupported dtype: {y.dtype}')
return y_out
def _safe_cast(y: array, dtype: Union[type, str]) -> array:
""" data type casting in a safe way, i.e., prevent overflow or underflow
This function is used internally.
"""
return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
def depth_convert(y: array, dtype: Union[type, str],
dithering: bool=True) -> array:
"""Convert audio array to target dtype safely
This function convert audio waveform to a target dtype, with addition steps of
preventing overflow/underflow and preserving audio range.
"""
SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64']
if y.dtype not in SUPPORT_DTYPE:
raise ParameterError(
'Unsupported audio dtype, '
f'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}')
if dtype not in SUPPORT_DTYPE:
raise ParameterError(
'Unsupported audio dtype, '
f'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}')
if dtype == y.dtype:
return y
if dtype == 'float64' and y.dtype == 'float32':
return _safe_cast(y, dtype)
if dtype == 'float32' and y.dtype == 'float64':
return _safe_cast(y, dtype)
if dtype == 'int16' or dtype == 'int8':
if y.dtype in ['float64', 'float32']:
factor = np.iinfo(dtype).max
y = np.clip(y * factor, np.iinfo(dtype).min,
np.iinfo(dtype).max).astype(dtype)
y = y.astype(dtype)
else:
if dtype == 'int16' and y.dtype == 'int8':
factor = np.iinfo('int16').max / np.iinfo('int8').max - EPS
y = y.astype('float32') * factor
y = y.astype('int16')
else: # dtype == 'int8' and y.dtype=='int16':
y = y.astype('int32') * np.iinfo('int8').max / \
np.iinfo('int16').max
y = y.astype('int8')
if dtype in ['float32', 'float64']:
org_dtype = y.dtype
y = y.astype(dtype) / np.iinfo(org_dtype).max
return y
def sound_file_load(file: str,
offset: Optional[float]=None,
dtype: str='int16',
duration: Optional[int]=None) -> Tuple[array, int]:
"""Load audio using soundfile library
This function load audio file using libsndfile.
Reference:
http://www.mega-nerd.com/libsndfile/#Features
"""
with sf.SoundFile(file) as sf_desc:
sr_native = sf_desc.samplerate
if offset:
sf_desc.seek(int(offset * sr_native))
if duration is not None:
frame_duration = int(duration * sr_native)
else:
frame_duration = -1
y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T
return y, sf_desc.samplerate
def audio_file_load():
"""Load audio using audiofile library
This function load audio file using audiofile.
Reference:
https://audiofile.68k.org/
"""
raise NotImplementedError()
def sox_file_load():
"""Load audio using sox library
This function load audio file using sox.
Reference:
http://sox.sourceforge.net/
"""
raise NotImplementedError()
def normalize(y: array, norm_type: str='linear',
mul_factor: float=1.0) -> array:
""" normalize an input audio with additional multiplier.
"""
if norm_type == 'linear':
amax = np.max(np.abs(y))
factor = 1.0 / (amax + EPS)
y = y * factor * mul_factor
elif norm_type == 'gaussian':
amean = np.mean(y)
astd = np.std(y)
astd = max(astd, EPS)
y = mul_factor * (y - amean) / astd
else:
raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}')
return y
def save_wav(y: array, sr: int, file: str) -> None:
"""Save audio file to disk.
This function saves audio to disk using scipy.io.wavfile, with additional step
to convert input waveform to int16 unless it already is int16
Notes:
It only support raw wav format.
"""
if not file.endswith('.wav'):
raise ParameterError(
f'only .wav file supported, but dst file name is: {file}')
if sr <= 0:
raise ParameterError(
f'Sample rate should be larger than 0, recieved sr = {sr}')
if y.dtype not in ['int16', 'int8']:
warnings.warn(
f'input data type is {y.dtype}, will convert data to int16 format before saving'
)
y_out = depth_convert(y, 'int16')
else:
y_out = y
wavfile.write(file, sr, y_out)
def load(
file: str,
sr: Optional[int]=None,
mono: bool=True,
merge_type: str='average', # ch0,ch1,random,average
normal: bool=True,
norm_type: str='linear',
norm_mul_factor: float=1.0,
offset: float=0.0,
duration: Optional[int]=None,
dtype: str='float32',
resample_mode: str='kaiser_fast') -> Tuple[array, int]:
"""Load audio file from disk.
This function loads audio from disk using using audio beackend.
Parameters:
Notes:
"""
y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration)
if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
raise ParameterError(f'audio file {file} looks empty')
if mono:
y = to_mono(y, merge_type)
if sr is not None and sr != r:
y = resample(y, r, sr, mode=resample_mode)
r = sr
if normal:
y = normalize(y, norm_type, norm_mul_factor)
elif dtype in ['int8', 'int16']:
# still need to do normalization, before depth convertion
y = normalize(y, 'linear', 1.0)
y = depth_convert(y, dtype)
return y, r

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import mcd.metrics_fast as mt
import numpy as np
from mcd import dtw
__all__ = [

Loading…
Cancel
Save