Merge pull request #1528 from KPatr1ck/audio

[audio][feature]Refactor and add doc string.
pull/1533/head
Hui Zhang 3 years ago committed by GitHub
commit 5201c59f3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,5 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import compliance
from . import datasets
from . import features
from . import functional
from . import io
from . import metric
from . import sox_effects
from .backends import load
from .backends import save

@ -105,7 +105,7 @@ def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
def _get_waveform_and_window_properties(
waveform: Tensor,
channel: int,
sample_frequency: float,
sr: int,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
@ -115,9 +115,9 @@ def _get_waveform_and_window_properties(
'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
waveform = waveform[channel, :] # size (n)
window_shift = int(
sample_frequency * frame_shift *
sr * frame_shift *
0.001) # pass frame_shift and frame_length in milliseconds
window_size = int(sample_frequency * frame_length * 0.001)
window_size = int(sr * frame_length * 0.001)
padded_window_size = _next_power_of_2(
window_size) if round_to_power_of_two else window_size
@ -128,7 +128,7 @@ def _get_waveform_and_window_properties(
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
' use `round_to_power_of_two` or change `frame_length`'
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
assert sr > 0, '`sr` must be greater than zero'
return waveform, window_shift, window_size, padded_window_size
@ -147,45 +147,38 @@ def _get_window(waveform: Tensor,
dtype = waveform.dtype
epsilon = _get_epsilon(dtype)
# size (m, window_size)
# (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift,
snip_edges)
if dither != 0.0:
# Returns a random number strictly between 0 and 1
x = paddle.maximum(epsilon,
paddle.rand(strided_input.shape, dtype=dtype))
rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
strided_input = strided_input + rand_gauss * dither
if remove_dc_offset:
# Subtract each row/frame by its mean
row_means = paddle.mean(
strided_input, axis=1).unsqueeze(1) # size (m, 1)
row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) # (m, 1)
strided_input = strided_input - row_means
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, epsilon,
energy_floor) # size (m)
energy_floor) # (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = paddle.nn.functional.pad(
strided_input.unsqueeze(0), (1, 0),
data_format='NCL',
mode='replicate').squeeze(0) # size (m, window_size + 1)
mode='replicate').squeeze(0) # (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff,
dtype).unsqueeze(0) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
dtype).unsqueeze(0) # (1, window_size)
strided_input = strided_input * window_function # (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
# (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = paddle.nn.functional.pad(
@ -194,7 +187,6 @@ def _get_window(waveform: Tensor,
mode='constant',
value=0).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, epsilon,
energy_floor) # size (m)
@ -203,8 +195,6 @@ def _get_window(waveform: Tensor,
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
tensor = tensor - col_means
@ -218,61 +208,56 @@ def spectrogram(waveform: Tensor,
energy_floor: float=1.0,
frame_length: float=25.0,
frame_shift: float=10.0,
min_duration: float=0.0,
preemphasis_coefficient: float=0.97,
raw_energy: bool=True,
remove_dc_offset: bool=True,
round_to_power_of_two: bool=True,
sample_frequency: float=16000.0,
sr: int=16000,
snip_edges: bool=True,
subtract_mean: bool=False,
window_type: str=POVEY) -> Tensor:
"""[summary]
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
Args:
waveform (Tensor): [description]
blackman_coeff (float, optional): [description]. Defaults to 0.42.
channel (int, optional): [description]. Defaults to -1.
dither (float, optional): [description]. Defaults to 0.0.
energy_floor (float, optional): [description]. Defaults to 1.0.
frame_length (float, optional): [description]. Defaults to 25.0.
frame_shift (float, optional): [description]. Defaults to 10.0.
min_duration (float, optional): [description]. Defaults to 0.0.
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
raw_energy (bool, optional): [description]. Defaults to True.
remove_dc_offset (bool, optional): [description]. Defaults to True.
round_to_power_of_two (bool, optional): [description]. Defaults to True.
sample_frequency (float, optional): [description]. Defaults to 16000.0.
snip_edges (bool, optional): [description]. Defaults to True.
subtract_mean (bool, optional): [description]. Defaults to False.
window_type (str, optional): [description]. Defaults to POVEY.
waveform (Tensor): A waveform tensor with shape [C, T].
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0.
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. Defaults to True.
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
Returns:
Tensor: [description]
Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1) where m is the number of frames
depends on frame_length and frame_shift.
"""
dtype = waveform.dtype
epsilon = _get_epsilon(dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length,
round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return paddle.empty([0])
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
preemphasis_coefficient)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type,
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1, 2)
# (m, padded_window_size // 2 + 1, 2)
fft = paddle.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = paddle.maximum(
fft.abs().pow(2.),
epsilon).log() # size (m, padded_window_size // 2 + 1)
fft.abs().pow(2.), epsilon).log() # (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
@ -306,25 +291,19 @@ def _vtln_warp_freq(vtln_low_cutoff: float,
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
scale = 1.0 / vtln_warp_factor
Fl = scale * l # F(l)
Fh = scale * h # F(h)
Fl = scale * l
Fh = scale * h
assert l > low_freq and h < high_freq
# slope of left part of the 3-piece linear function
scale_left = (Fl - low_freq) / (l - low_freq)
# [slope of center part is just "scale"]
# slope of right part of the 3-piece linear function
scale_right = (high_freq - Fh) / (high_freq - h)
res = paddle.empty_like(freq)
outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
| paddle.greater_than(freq, paddle.to_tensor(high_freq)) # freq < low_freq || freq > high_freq
before_l = paddle.less_than(freq, paddle.to_tensor(l)) # freq < l
before_h = paddle.less_than(freq, paddle.to_tensor(h)) # freq < h
after_h = paddle.greater_equal(freq, paddle.to_tensor(h)) # freq >= h
| paddle.greater_than(freq, paddle.to_tensor(high_freq))
before_l = paddle.less_than(freq, paddle.to_tensor(l))
before_h = paddle.less_than(freq, paddle.to_tensor(h))
after_h = paddle.greater_equal(freq, paddle.to_tensor(h))
# order of operations matter here (since there is overlapping frequency regions)
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
res[before_h] = scale * freq[before_h]
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
@ -363,13 +342,10 @@ def _get_mel_banks(num_bins: int,
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
mel_low_freq = _mel_scale_scalar(low_freq)
mel_high_freq = _mel_scale_scalar(high_freq)
# divide by num_bins+1 in next line because of end-effects where the bins
# spread out to the sides.
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
if vtln_high < 0.0:
@ -381,10 +357,9 @@ def _get_mel_banks(num_bins: int,
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = paddle.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0
) * mel_freq_delta # size(num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
if vtln_warp_factor != 1.0:
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
@ -395,25 +370,23 @@ def _get_mel_banks(num_bins: int,
right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
high_freq, vtln_warp_factor, right_mel)
center_freqs = _inverse_mel_scale(center_mel) # size (num_bins)
# size(1, num_fft_bins)
center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
# (1, num_fft_bins)
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0)
# size (num_bins, num_fft_bins)
# (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)
if vtln_warp_factor == 1.0:
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
bins = paddle.maximum(
paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
else:
# warping can move the order of left_mel, center_mel, right_mel anywhere
bins = paddle.zeros_like(up_slope)
up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
mel, center_mel) # left_mel < mel <= center_mel
mel, center_mel)
down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
mel, right_mel) # center_mel < mel < right_mel
mel, right_mel)
bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx]
@ -430,13 +403,12 @@ def fbank(waveform: Tensor,
high_freq: float=0.0,
htk_compat: bool=False,
low_freq: float=20.0,
min_duration: float=0.0,
num_mel_bins: int=23,
n_mels: int=23,
preemphasis_coefficient: float=0.97,
raw_energy: bool=True,
remove_dc_offset: bool=True,
round_to_power_of_two: bool=True,
sample_frequency: float=16000.0,
sr: int=16000,
snip_edges: bool=True,
subtract_mean: bool=False,
use_energy: bool=False,
@ -446,83 +418,75 @@ def fbank(waveform: Tensor,
vtln_low: float=100.0,
vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor:
"""[summary]
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's.
Args:
waveform (Tensor): [description]
blackman_coeff (float, optional): [description]. Defaults to 0.42.
channel (int, optional): [description]. Defaults to -1.
dither (float, optional): [description]. Defaults to 0.0.
energy_floor (float, optional): [description]. Defaults to 1.0.
frame_length (float, optional): [description]. Defaults to 25.0.
frame_shift (float, optional): [description]. Defaults to 10.0.
high_freq (float, optional): [description]. Defaults to 0.0.
htk_compat (bool, optional): [description]. Defaults to False.
low_freq (float, optional): [description]. Defaults to 20.0.
min_duration (float, optional): [description]. Defaults to 0.0.
num_mel_bins (int, optional): [description]. Defaults to 23.
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
raw_energy (bool, optional): [description]. Defaults to True.
remove_dc_offset (bool, optional): [description]. Defaults to True.
round_to_power_of_two (bool, optional): [description]. Defaults to True.
sample_frequency (float, optional): [description]. Defaults to 16000.0.
snip_edges (bool, optional): [description]. Defaults to True.
subtract_mean (bool, optional): [description]. Defaults to False.
use_energy (bool, optional): [description]. Defaults to False.
use_log_fbank (bool, optional): [description]. Defaults to True.
use_power (bool, optional): [description]. Defaults to True.
vtln_high (float, optional): [description]. Defaults to -500.0.
vtln_low (float, optional): [description]. Defaults to 100.0.
vtln_warp (float, optional): [description]. Defaults to 1.0.
window_type (str, optional): [description]. Defaults to POVEY.
waveform (Tensor): A waveform tensor with shape [C, T].
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0.
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
n_mels (int, optional): Number of output mel bins. Defaults to 23.
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. Defaults to True.
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
use_log_fbank (bool, optional): Return log fbank when it is set True. Defaults to True.
use_power (bool, optional): Whether to use power instead of magnitude. Defaults to True.
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
Returns:
Tensor: [description]
Tensor: A filter banks tensor with shape (m, n_mels).
"""
dtype = waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length,
round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return paddle.empty([0], dtype=dtype)
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
preemphasis_coefficient)
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type,
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1)
# (m, padded_window_size // 2 + 1)
spectrum = paddle.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.)
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = _get_mel_banks(num_mel_bins, padded_window_size,
sample_frequency, low_freq, high_freq,
vtln_low, vtln_high, vtln_warp)
# (n_mels, padded_window_size // 2)
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.astype(dtype)
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
# (n_mels, padded_window_size // 2 + 1)
mel_energies = paddle.nn.functional.pad(
mel_energies.unsqueeze(0), (0, 1),
data_format='NCL',
mode='constant',
value=0).squeeze(0)
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
# (m, n_mels)
mel_energies = paddle.mm(spectrum, mel_energies.T)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
signal_log_energy = signal_log_energy.unsqueeze(1)
if htk_compat:
mel_energies = paddle.concat(
(mel_energies, signal_log_energy), axis=1)
@ -530,28 +494,20 @@ def fbank(waveform: Tensor,
mel_energies = paddle.concat(
(signal_log_energy, mel_energies), axis=1)
# (m, n_mels + 1)
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
return mel_energies
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = create_dct(num_mel_bins, num_mel_bins, 'ortho')
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
# expects a left multiply e.g. dct_matrix * vector).
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
dct_matrix = dct_matrix[:, :num_ceps]
def _get_dct_matrix(n_mfcc: int, n_mels: int) -> Tensor:
dct_matrix = create_dct(n_mels, n_mels, 'ortho')
dct_matrix[:, 0] = math.sqrt(1 / float(n_mels))
dct_matrix = dct_matrix[:, :n_mfcc] # (n_mels, n_mfcc)
return dct_matrix
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = paddle.arange(num_ceps)
def _get_lifter_coeffs(n_mfcc: int, cepstral_lifter: float) -> Tensor:
i = paddle.arange(n_mfcc)
return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
cepstral_lifter)
@ -567,14 +523,13 @@ def mfcc(waveform: Tensor,
high_freq: float=0.0,
htk_compat: bool=False,
low_freq: float=20.0,
num_ceps: int=13,
min_duration: float=0.0,
num_mel_bins: int=23,
n_mfcc: int=13,
n_mels: int=23,
preemphasis_coefficient: float=0.97,
raw_energy: bool=True,
remove_dc_offset: bool=True,
round_to_power_of_two: bool=True,
sample_frequency: float=16000.0,
sr: int=16000,
snip_edges: bool=True,
subtract_mean: bool=False,
use_energy: bool=False,
@ -582,47 +537,47 @@ def mfcc(waveform: Tensor,
vtln_low: float=100.0,
vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor:
"""[summary]
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is
identical to Kaldi's.
Args:
waveform (Tensor): [description]
blackman_coeff (float, optional): [description]. Defaults to 0.42.
cepstral_lifter (float, optional): [description]. Defaults to 22.0.
channel (int, optional): [description]. Defaults to -1.
dither (float, optional): [description]. Defaults to 0.0.
energy_floor (float, optional): [description]. Defaults to 1.0.
frame_length (float, optional): [description]. Defaults to 25.0.
frame_shift (float, optional): [description]. Defaults to 10.0.
high_freq (float, optional): [description]. Defaults to 0.0.
htk_compat (bool, optional): [description]. Defaults to False.
low_freq (float, optional): [description]. Defaults to 20.0.
num_ceps (int, optional): [description]. Defaults to 13.
min_duration (float, optional): [description]. Defaults to 0.0.
num_mel_bins (int, optional): [description]. Defaults to 23.
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
raw_energy (bool, optional): [description]. Defaults to True.
remove_dc_offset (bool, optional): [description]. Defaults to True.
round_to_power_of_two (bool, optional): [description]. Defaults to True.
sample_frequency (float, optional): [description]. Defaults to 16000.0.
snip_edges (bool, optional): [description]. Defaults to True.
subtract_mean (bool, optional): [description]. Defaults to False.
use_energy (bool, optional): [description]. Defaults to False.
vtln_high (float, optional): [description]. Defaults to -500.0.
vtln_low (float, optional): [description]. Defaults to 100.0.
vtln_warp (float, optional): [description]. Defaults to 1.0.
window_type (str, optional): [description]. Defaults to POVEY.
waveform (Tensor): A waveform tensor with shape [C, T].
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0.
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 13.
n_mels (int, optional): Number of output mel bins. Defaults to 23.
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. Defaults to True.
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
Returns:
Tensor: [description]
Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc).
"""
assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (
num_ceps, num_mel_bins)
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
n_mfcc, n_mels)
dtype = waveform.dtype
# The mel_energies should not be squared (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
# (m, n_mels + use_energy)
feature = fbank(
waveform=waveform,
blackman_coeff=blackman_coeff,
@ -634,13 +589,12 @@ def mfcc(waveform: Tensor,
high_freq=high_freq,
htk_compat=htk_compat,
low_freq=low_freq,
min_duration=min_duration,
num_mel_bins=num_mel_bins,
n_mels=n_mels,
preemphasis_coefficient=preemphasis_coefficient,
raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset,
round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency,
sr=sr,
snip_edges=snip_edges,
subtract_mean=False,
use_energy=use_energy,
@ -652,34 +606,29 @@ def mfcc(waveform: Tensor,
window_type=window_type)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
# (m)
signal_log_energy = feature[:, n_mels if htk_compat else 0]
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
feature = feature[:, mel_offset:(n_mels + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).astype(dtype=dtype)
# (n_mels, n_mfcc)
dct_matrix = _get_dct_matrix(n_mfcc, n_mels).astype(dtype=dtype)
# size (m, num_ceps)
# (m, n_mfcc)
feature = feature.matmul(dct_matrix)
if cepstral_lifter != 0.0:
# size (1, num_ceps)
lifter_coeffs = _get_lifter_coeffs(num_ceps,
cepstral_lifter).unsqueeze(0)
# (1, n_mfcc)
lifter_coeffs = _get_lifter_coeffs(n_mfcc, cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs.astype(dtype=dtype)
# if use_energy then replace the last column for htk_compat == true else first column
if use_energy:
feature[:, 0] = signal_log_energy
if htk_compat:
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
feature = feature[:, 1:] # size (m, num_ceps - 1)
energy = feature[:, 0].unsqueeze(1) # (m, 1)
feature = feature[:, 1:] # (m, n_mfcc - 1)
if not use_energy:
# scale on C0 (actually removing a scale we previously added that's
# part of one common definition of the cosine transform.)
energy *= math.sqrt(2)
feature = paddle.concat((feature, energy), axis=1)

@ -71,15 +71,17 @@ class Spectrogram(nn.Layer):
if win_length is None:
win_length = n_fft
fft_window = get_window(window, win_length, fftbins=True, dtype=dtype)
self.fft_window = get_window(
window, win_length, fftbins=True, dtype=dtype)
self._stft = partial(
paddle.signal.stft,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=fft_window,
window=self.fft_window,
center=center,
pad_mode=pad_mode)
self.register_buffer('fft_window', self.fft_window)
def forward(self, x):
stft = self._stft(x)
@ -259,12 +261,18 @@ class MFCC(nn.Layer):
sr: int=22050,
n_mfcc: int=40,
norm: str='ortho',
dtype: str=paddle.float32,
**kwargs):
"""[summary]
"""Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms.
Parameters:
sr (int, optional): [description]. Defaults to 22050.
n_mfcc (int, optional): [description]. Defaults to 40.
norm (str, optional): [description]. Defaults to 'ortho'.
sr(int): the audio sample rate.
The default value is 22050.
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 40.
norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
You can specify norm=1.0/2.0 to use customized p-norm normalization.
dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
"""
super(MFCC, self).__init__()
self._log_melspectrogram = LogMelSpectrogram(sr=sr, **kwargs)

Loading…
Cancel
Save