|
|
|
@ -105,7 +105,7 @@ def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
|
|
|
|
|
def _get_waveform_and_window_properties(
|
|
|
|
|
waveform: Tensor,
|
|
|
|
|
channel: int,
|
|
|
|
|
sample_frequency: float,
|
|
|
|
|
sr: int,
|
|
|
|
|
frame_shift: float,
|
|
|
|
|
frame_length: float,
|
|
|
|
|
round_to_power_of_two: bool,
|
|
|
|
@ -115,9 +115,9 @@ def _get_waveform_and_window_properties(
|
|
|
|
|
'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
|
|
|
|
|
waveform = waveform[channel, :] # size (n)
|
|
|
|
|
window_shift = int(
|
|
|
|
|
sample_frequency * frame_shift *
|
|
|
|
|
sr * frame_shift *
|
|
|
|
|
0.001) # pass frame_shift and frame_length in milliseconds
|
|
|
|
|
window_size = int(sample_frequency * frame_length * 0.001)
|
|
|
|
|
window_size = int(sr * frame_length * 0.001)
|
|
|
|
|
padded_window_size = _next_power_of_2(
|
|
|
|
|
window_size) if round_to_power_of_two else window_size
|
|
|
|
|
|
|
|
|
@ -128,7 +128,7 @@ def _get_waveform_and_window_properties(
|
|
|
|
|
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
|
|
|
|
|
' use `round_to_power_of_two` or change `frame_length`'
|
|
|
|
|
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
|
|
|
|
|
assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
|
|
|
|
|
assert sr > 0, '`sr` must be greater than zero'
|
|
|
|
|
return waveform, window_shift, window_size, padded_window_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -147,45 +147,38 @@ def _get_window(waveform: Tensor,
|
|
|
|
|
dtype = waveform.dtype
|
|
|
|
|
epsilon = _get_epsilon(dtype)
|
|
|
|
|
|
|
|
|
|
# size (m, window_size)
|
|
|
|
|
# (m, window_size)
|
|
|
|
|
strided_input = _get_strided(waveform, window_size, window_shift,
|
|
|
|
|
snip_edges)
|
|
|
|
|
|
|
|
|
|
if dither != 0.0:
|
|
|
|
|
# Returns a random number strictly between 0 and 1
|
|
|
|
|
x = paddle.maximum(epsilon,
|
|
|
|
|
paddle.rand(strided_input.shape, dtype=dtype))
|
|
|
|
|
rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
|
|
|
|
|
strided_input = strided_input + rand_gauss * dither
|
|
|
|
|
|
|
|
|
|
if remove_dc_offset:
|
|
|
|
|
# Subtract each row/frame by its mean
|
|
|
|
|
row_means = paddle.mean(
|
|
|
|
|
strided_input, axis=1).unsqueeze(1) # size (m, 1)
|
|
|
|
|
row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) # (m, 1)
|
|
|
|
|
strided_input = strided_input - row_means
|
|
|
|
|
|
|
|
|
|
if raw_energy:
|
|
|
|
|
# Compute the log energy of each row/frame before applying preemphasis and
|
|
|
|
|
# window function
|
|
|
|
|
signal_log_energy = _get_log_energy(strided_input, epsilon,
|
|
|
|
|
energy_floor) # size (m)
|
|
|
|
|
energy_floor) # (m)
|
|
|
|
|
|
|
|
|
|
if preemphasis_coefficient != 0.0:
|
|
|
|
|
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
|
|
|
|
|
offset_strided_input = paddle.nn.functional.pad(
|
|
|
|
|
strided_input.unsqueeze(0), (1, 0),
|
|
|
|
|
data_format='NCL',
|
|
|
|
|
mode='replicate').squeeze(0) # size (m, window_size + 1)
|
|
|
|
|
mode='replicate').squeeze(0) # (m, window_size + 1)
|
|
|
|
|
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
|
|
|
|
|
-1]
|
|
|
|
|
|
|
|
|
|
# Apply window_function to each row/frame
|
|
|
|
|
window_function = _feature_window_function(
|
|
|
|
|
window_type, window_size, blackman_coeff,
|
|
|
|
|
dtype).unsqueeze(0) # size (1, window_size)
|
|
|
|
|
strided_input = strided_input * window_function # size (m, window_size)
|
|
|
|
|
dtype).unsqueeze(0) # (1, window_size)
|
|
|
|
|
strided_input = strided_input * window_function # (m, window_size)
|
|
|
|
|
|
|
|
|
|
# Pad columns with zero until we reach size (m, padded_window_size)
|
|
|
|
|
# (m, padded_window_size)
|
|
|
|
|
if padded_window_size != window_size:
|
|
|
|
|
padding_right = padded_window_size - window_size
|
|
|
|
|
strided_input = paddle.nn.functional.pad(
|
|
|
|
@ -194,7 +187,6 @@ def _get_window(waveform: Tensor,
|
|
|
|
|
mode='constant',
|
|
|
|
|
value=0).squeeze(0)
|
|
|
|
|
|
|
|
|
|
# Compute energy after window function (not the raw one)
|
|
|
|
|
if not raw_energy:
|
|
|
|
|
signal_log_energy = _get_log_energy(strided_input, epsilon,
|
|
|
|
|
energy_floor) # size (m)
|
|
|
|
@ -203,8 +195,6 @@ def _get_window(waveform: Tensor,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
|
|
|
|
|
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
|
|
|
|
|
# it returns size (m, n)
|
|
|
|
|
if subtract_mean:
|
|
|
|
|
col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
|
|
|
|
|
tensor = tensor - col_means
|
|
|
|
@ -218,61 +208,56 @@ def spectrogram(waveform: Tensor,
|
|
|
|
|
energy_floor: float=1.0,
|
|
|
|
|
frame_length: float=25.0,
|
|
|
|
|
frame_shift: float=10.0,
|
|
|
|
|
min_duration: float=0.0,
|
|
|
|
|
preemphasis_coefficient: float=0.97,
|
|
|
|
|
raw_energy: bool=True,
|
|
|
|
|
remove_dc_offset: bool=True,
|
|
|
|
|
round_to_power_of_two: bool=True,
|
|
|
|
|
sample_frequency: float=16000.0,
|
|
|
|
|
sr: int=16000,
|
|
|
|
|
snip_edges: bool=True,
|
|
|
|
|
subtract_mean: bool=False,
|
|
|
|
|
window_type: str=POVEY) -> Tensor:
|
|
|
|
|
"""[summary]
|
|
|
|
|
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
waveform (Tensor): [description]
|
|
|
|
|
blackman_coeff (float, optional): [description]. Defaults to 0.42.
|
|
|
|
|
channel (int, optional): [description]. Defaults to -1.
|
|
|
|
|
dither (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
energy_floor (float, optional): [description]. Defaults to 1.0.
|
|
|
|
|
frame_length (float, optional): [description]. Defaults to 25.0.
|
|
|
|
|
frame_shift (float, optional): [description]. Defaults to 10.0.
|
|
|
|
|
min_duration (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
|
|
|
|
|
raw_energy (bool, optional): [description]. Defaults to True.
|
|
|
|
|
remove_dc_offset (bool, optional): [description]. Defaults to True.
|
|
|
|
|
round_to_power_of_two (bool, optional): [description]. Defaults to True.
|
|
|
|
|
sample_frequency (float, optional): [description]. Defaults to 16000.0.
|
|
|
|
|
snip_edges (bool, optional): [description]. Defaults to True.
|
|
|
|
|
subtract_mean (bool, optional): [description]. Defaults to False.
|
|
|
|
|
window_type (str, optional): [description]. Defaults to POVEY.
|
|
|
|
|
waveform (Tensor): A waveform tensor with shape [C, T].
|
|
|
|
|
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
|
|
|
|
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
|
|
|
|
dither (float, optional): Dithering constant . Defaults to 0.0.
|
|
|
|
|
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
|
|
|
|
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
|
|
|
|
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
|
|
|
|
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
|
|
|
|
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
|
|
|
|
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
|
|
|
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
|
|
|
|
to FFT. Defaults to True.
|
|
|
|
|
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
|
|
|
|
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
|
|
|
|
|
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
|
|
|
|
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
|
|
|
|
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: [description]
|
|
|
|
|
Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1) where m is the number of frames
|
|
|
|
|
depends on frame_length and frame_shift.
|
|
|
|
|
"""
|
|
|
|
|
dtype = waveform.dtype
|
|
|
|
|
epsilon = _get_epsilon(dtype)
|
|
|
|
|
|
|
|
|
|
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
|
|
|
|
waveform, channel, sample_frequency, frame_shift, frame_length,
|
|
|
|
|
round_to_power_of_two, preemphasis_coefficient)
|
|
|
|
|
|
|
|
|
|
if len(waveform) < min_duration * sample_frequency:
|
|
|
|
|
# signal is too short
|
|
|
|
|
return paddle.empty([0])
|
|
|
|
|
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
|
|
|
|
|
preemphasis_coefficient)
|
|
|
|
|
|
|
|
|
|
strided_input, signal_log_energy = _get_window(
|
|
|
|
|
waveform, padded_window_size, window_size, window_shift, window_type,
|
|
|
|
|
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
|
|
|
|
|
remove_dc_offset, preemphasis_coefficient)
|
|
|
|
|
|
|
|
|
|
# size (m, padded_window_size // 2 + 1, 2)
|
|
|
|
|
# (m, padded_window_size // 2 + 1, 2)
|
|
|
|
|
fft = paddle.fft.rfft(strided_input)
|
|
|
|
|
|
|
|
|
|
# Convert the FFT into a power spectrum
|
|
|
|
|
power_spectrum = paddle.maximum(
|
|
|
|
|
fft.abs().pow(2.),
|
|
|
|
|
epsilon).log() # size (m, padded_window_size // 2 + 1)
|
|
|
|
|
fft.abs().pow(2.), epsilon).log() # (m, padded_window_size // 2 + 1)
|
|
|
|
|
power_spectrum[:, 0] = signal_log_energy
|
|
|
|
|
|
|
|
|
|
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
|
|
|
|
@ -306,25 +291,19 @@ def _vtln_warp_freq(vtln_low_cutoff: float,
|
|
|
|
|
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
|
|
|
|
|
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
|
|
|
|
|
scale = 1.0 / vtln_warp_factor
|
|
|
|
|
Fl = scale * l # F(l)
|
|
|
|
|
Fh = scale * h # F(h)
|
|
|
|
|
Fl = scale * l
|
|
|
|
|
Fh = scale * h
|
|
|
|
|
assert l > low_freq and h < high_freq
|
|
|
|
|
# slope of left part of the 3-piece linear function
|
|
|
|
|
scale_left = (Fl - low_freq) / (l - low_freq)
|
|
|
|
|
# [slope of center part is just "scale"]
|
|
|
|
|
|
|
|
|
|
# slope of right part of the 3-piece linear function
|
|
|
|
|
scale_right = (high_freq - Fh) / (high_freq - h)
|
|
|
|
|
|
|
|
|
|
res = paddle.empty_like(freq)
|
|
|
|
|
|
|
|
|
|
outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
|
|
|
|
|
| paddle.greater_than(freq, paddle.to_tensor(high_freq)) # freq < low_freq || freq > high_freq
|
|
|
|
|
before_l = paddle.less_than(freq, paddle.to_tensor(l)) # freq < l
|
|
|
|
|
before_h = paddle.less_than(freq, paddle.to_tensor(h)) # freq < h
|
|
|
|
|
after_h = paddle.greater_equal(freq, paddle.to_tensor(h)) # freq >= h
|
|
|
|
|
| paddle.greater_than(freq, paddle.to_tensor(high_freq))
|
|
|
|
|
before_l = paddle.less_than(freq, paddle.to_tensor(l))
|
|
|
|
|
before_h = paddle.less_than(freq, paddle.to_tensor(h))
|
|
|
|
|
after_h = paddle.greater_equal(freq, paddle.to_tensor(h))
|
|
|
|
|
|
|
|
|
|
# order of operations matter here (since there is overlapping frequency regions)
|
|
|
|
|
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
|
|
|
|
|
res[before_h] = scale * freq[before_h]
|
|
|
|
|
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
|
|
|
|
@ -363,13 +342,10 @@ def _get_mel_banks(num_bins: int,
|
|
|
|
|
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
|
|
|
|
|
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
|
|
|
|
|
|
|
|
|
|
# fft-bin width [think of it as Nyquist-freq / half-window-length]
|
|
|
|
|
fft_bin_width = sample_freq / window_length_padded
|
|
|
|
|
mel_low_freq = _mel_scale_scalar(low_freq)
|
|
|
|
|
mel_high_freq = _mel_scale_scalar(high_freq)
|
|
|
|
|
|
|
|
|
|
# divide by num_bins+1 in next line because of end-effects where the bins
|
|
|
|
|
# spread out to the sides.
|
|
|
|
|
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
|
|
|
|
|
|
|
|
|
if vtln_high < 0.0:
|
|
|
|
@ -381,10 +357,9 @@ def _get_mel_banks(num_bins: int,
|
|
|
|
|
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
|
|
|
|
|
|
|
|
|
|
bin = paddle.arange(num_bins).unsqueeze(1)
|
|
|
|
|
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
|
|
|
|
|
center_mel = mel_low_freq + (bin + 1.0
|
|
|
|
|
) * mel_freq_delta # size(num_bins, 1)
|
|
|
|
|
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
|
|
|
|
|
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
|
|
|
|
|
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
|
|
|
|
|
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
|
|
|
|
|
|
|
|
|
|
if vtln_warp_factor != 1.0:
|
|
|
|
|
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
|
|
|
|
@ -395,25 +370,23 @@ def _get_mel_banks(num_bins: int,
|
|
|
|
|
right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
|
|
|
|
|
high_freq, vtln_warp_factor, right_mel)
|
|
|
|
|
|
|
|
|
|
center_freqs = _inverse_mel_scale(center_mel) # size (num_bins)
|
|
|
|
|
# size(1, num_fft_bins)
|
|
|
|
|
center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
|
|
|
|
|
# (1, num_fft_bins)
|
|
|
|
|
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
# size (num_bins, num_fft_bins)
|
|
|
|
|
# (num_bins, num_fft_bins)
|
|
|
|
|
up_slope = (mel - left_mel) / (center_mel - left_mel)
|
|
|
|
|
down_slope = (right_mel - mel) / (right_mel - center_mel)
|
|
|
|
|
|
|
|
|
|
if vtln_warp_factor == 1.0:
|
|
|
|
|
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
|
|
|
|
|
bins = paddle.maximum(
|
|
|
|
|
paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
|
|
|
|
|
else:
|
|
|
|
|
# warping can move the order of left_mel, center_mel, right_mel anywhere
|
|
|
|
|
bins = paddle.zeros_like(up_slope)
|
|
|
|
|
up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
|
|
|
|
|
mel, center_mel) # left_mel < mel <= center_mel
|
|
|
|
|
mel, center_mel)
|
|
|
|
|
down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
|
|
|
|
|
mel, right_mel) # center_mel < mel < right_mel
|
|
|
|
|
mel, right_mel)
|
|
|
|
|
bins[up_idx] = up_slope[up_idx]
|
|
|
|
|
bins[down_idx] = down_slope[down_idx]
|
|
|
|
|
|
|
|
|
@ -430,13 +403,12 @@ def fbank(waveform: Tensor,
|
|
|
|
|
high_freq: float=0.0,
|
|
|
|
|
htk_compat: bool=False,
|
|
|
|
|
low_freq: float=20.0,
|
|
|
|
|
min_duration: float=0.0,
|
|
|
|
|
num_mel_bins: int=23,
|
|
|
|
|
n_mels: int=23,
|
|
|
|
|
preemphasis_coefficient: float=0.97,
|
|
|
|
|
raw_energy: bool=True,
|
|
|
|
|
remove_dc_offset: bool=True,
|
|
|
|
|
round_to_power_of_two: bool=True,
|
|
|
|
|
sample_frequency: float=16000.0,
|
|
|
|
|
sr: int=16000,
|
|
|
|
|
snip_edges: bool=True,
|
|
|
|
|
subtract_mean: bool=False,
|
|
|
|
|
use_energy: bool=False,
|
|
|
|
@ -446,83 +418,75 @@ def fbank(waveform: Tensor,
|
|
|
|
|
vtln_low: float=100.0,
|
|
|
|
|
vtln_warp: float=1.0,
|
|
|
|
|
window_type: str=POVEY) -> Tensor:
|
|
|
|
|
"""[summary]
|
|
|
|
|
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
waveform (Tensor): [description]
|
|
|
|
|
blackman_coeff (float, optional): [description]. Defaults to 0.42.
|
|
|
|
|
channel (int, optional): [description]. Defaults to -1.
|
|
|
|
|
dither (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
energy_floor (float, optional): [description]. Defaults to 1.0.
|
|
|
|
|
frame_length (float, optional): [description]. Defaults to 25.0.
|
|
|
|
|
frame_shift (float, optional): [description]. Defaults to 10.0.
|
|
|
|
|
high_freq (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
htk_compat (bool, optional): [description]. Defaults to False.
|
|
|
|
|
low_freq (float, optional): [description]. Defaults to 20.0.
|
|
|
|
|
min_duration (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
num_mel_bins (int, optional): [description]. Defaults to 23.
|
|
|
|
|
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
|
|
|
|
|
raw_energy (bool, optional): [description]. Defaults to True.
|
|
|
|
|
remove_dc_offset (bool, optional): [description]. Defaults to True.
|
|
|
|
|
round_to_power_of_two (bool, optional): [description]. Defaults to True.
|
|
|
|
|
sample_frequency (float, optional): [description]. Defaults to 16000.0.
|
|
|
|
|
snip_edges (bool, optional): [description]. Defaults to True.
|
|
|
|
|
subtract_mean (bool, optional): [description]. Defaults to False.
|
|
|
|
|
use_energy (bool, optional): [description]. Defaults to False.
|
|
|
|
|
use_log_fbank (bool, optional): [description]. Defaults to True.
|
|
|
|
|
use_power (bool, optional): [description]. Defaults to True.
|
|
|
|
|
vtln_high (float, optional): [description]. Defaults to -500.0.
|
|
|
|
|
vtln_low (float, optional): [description]. Defaults to 100.0.
|
|
|
|
|
vtln_warp (float, optional): [description]. Defaults to 1.0.
|
|
|
|
|
window_type (str, optional): [description]. Defaults to POVEY.
|
|
|
|
|
waveform (Tensor): A waveform tensor with shape [C, T].
|
|
|
|
|
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
|
|
|
|
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
|
|
|
|
dither (float, optional): Dithering constant . Defaults to 0.0.
|
|
|
|
|
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
|
|
|
|
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
|
|
|
|
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
|
|
|
|
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
|
|
|
|
|
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
|
|
|
|
|
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
|
|
|
|
|
n_mels (int, optional): Number of output mel bins. Defaults to 23.
|
|
|
|
|
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
|
|
|
|
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
|
|
|
|
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
|
|
|
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
|
|
|
|
to FFT. Defaults to True.
|
|
|
|
|
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
|
|
|
|
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
|
|
|
|
|
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
|
|
|
|
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
|
|
|
|
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
|
|
|
|
|
use_log_fbank (bool, optional): Return log fbank when it is set True. Defaults to True.
|
|
|
|
|
use_power (bool, optional): Whether to use power instead of magnitude. Defaults to True.
|
|
|
|
|
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
|
|
|
|
|
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
|
|
|
|
|
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
|
|
|
|
|
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: [description]
|
|
|
|
|
Tensor: A filter banks tensor with shape (m, n_mels).
|
|
|
|
|
"""
|
|
|
|
|
dtype = waveform.dtype
|
|
|
|
|
|
|
|
|
|
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
|
|
|
|
waveform, channel, sample_frequency, frame_shift, frame_length,
|
|
|
|
|
round_to_power_of_two, preemphasis_coefficient)
|
|
|
|
|
|
|
|
|
|
if len(waveform) < min_duration * sample_frequency:
|
|
|
|
|
# signal is too short
|
|
|
|
|
return paddle.empty([0], dtype=dtype)
|
|
|
|
|
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
|
|
|
|
|
preemphasis_coefficient)
|
|
|
|
|
|
|
|
|
|
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
|
|
|
|
|
strided_input, signal_log_energy = _get_window(
|
|
|
|
|
waveform, padded_window_size, window_size, window_shift, window_type,
|
|
|
|
|
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
|
|
|
|
|
remove_dc_offset, preemphasis_coefficient)
|
|
|
|
|
|
|
|
|
|
# size (m, padded_window_size // 2 + 1)
|
|
|
|
|
# (m, padded_window_size // 2 + 1)
|
|
|
|
|
spectrum = paddle.fft.rfft(strided_input).abs()
|
|
|
|
|
if use_power:
|
|
|
|
|
spectrum = spectrum.pow(2.)
|
|
|
|
|
|
|
|
|
|
# size (num_mel_bins, padded_window_size // 2)
|
|
|
|
|
mel_energies, _ = _get_mel_banks(num_mel_bins, padded_window_size,
|
|
|
|
|
sample_frequency, low_freq, high_freq,
|
|
|
|
|
vtln_low, vtln_high, vtln_warp)
|
|
|
|
|
# (n_mels, padded_window_size // 2)
|
|
|
|
|
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
|
|
|
|
|
high_freq, vtln_low, vtln_high, vtln_warp)
|
|
|
|
|
mel_energies = mel_energies.astype(dtype)
|
|
|
|
|
|
|
|
|
|
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
|
|
|
|
|
# (n_mels, padded_window_size // 2 + 1)
|
|
|
|
|
mel_energies = paddle.nn.functional.pad(
|
|
|
|
|
mel_energies.unsqueeze(0), (0, 1),
|
|
|
|
|
data_format='NCL',
|
|
|
|
|
mode='constant',
|
|
|
|
|
value=0).squeeze(0)
|
|
|
|
|
|
|
|
|
|
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
|
|
|
|
|
# (m, n_mels)
|
|
|
|
|
mel_energies = paddle.mm(spectrum, mel_energies.T)
|
|
|
|
|
if use_log_fbank:
|
|
|
|
|
# avoid log of zero (which should be prevented anyway by dithering)
|
|
|
|
|
mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
|
|
|
|
|
|
|
|
|
|
# if use_energy then add it as the last column for htk_compat == true else first column
|
|
|
|
|
if use_energy:
|
|
|
|
|
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
|
|
|
|
|
# returns size (m, num_mel_bins + 1)
|
|
|
|
|
signal_log_energy = signal_log_energy.unsqueeze(1)
|
|
|
|
|
if htk_compat:
|
|
|
|
|
mel_energies = paddle.concat(
|
|
|
|
|
(mel_energies, signal_log_energy), axis=1)
|
|
|
|
@ -530,28 +494,20 @@ def fbank(waveform: Tensor,
|
|
|
|
|
mel_energies = paddle.concat(
|
|
|
|
|
(signal_log_energy, mel_energies), axis=1)
|
|
|
|
|
|
|
|
|
|
# (m, n_mels + 1)
|
|
|
|
|
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
|
|
|
|
|
return mel_energies
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
|
|
|
|
|
# returns a dct matrix of size (num_mel_bins, num_ceps)
|
|
|
|
|
# size (num_mel_bins, num_mel_bins)
|
|
|
|
|
dct_matrix = create_dct(num_mel_bins, num_mel_bins, 'ortho')
|
|
|
|
|
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
|
|
|
|
|
# this would be the first column in the dct_matrix for torchaudio as it expects a
|
|
|
|
|
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
|
|
|
|
|
# expects a left multiply e.g. dct_matrix * vector).
|
|
|
|
|
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
|
|
|
|
|
dct_matrix = dct_matrix[:, :num_ceps]
|
|
|
|
|
def _get_dct_matrix(n_mfcc: int, n_mels: int) -> Tensor:
|
|
|
|
|
dct_matrix = create_dct(n_mels, n_mels, 'ortho')
|
|
|
|
|
dct_matrix[:, 0] = math.sqrt(1 / float(n_mels))
|
|
|
|
|
dct_matrix = dct_matrix[:, :n_mfcc] # (n_mels, n_mfcc)
|
|
|
|
|
return dct_matrix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
|
|
|
|
|
# returns size (num_ceps)
|
|
|
|
|
# Compute liftering coefficients (scaling on cepstral coeffs)
|
|
|
|
|
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
|
|
|
|
|
i = paddle.arange(num_ceps)
|
|
|
|
|
def _get_lifter_coeffs(n_mfcc: int, cepstral_lifter: float) -> Tensor:
|
|
|
|
|
i = paddle.arange(n_mfcc)
|
|
|
|
|
return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
|
|
|
|
|
cepstral_lifter)
|
|
|
|
|
|
|
|
|
@ -567,14 +523,13 @@ def mfcc(waveform: Tensor,
|
|
|
|
|
high_freq: float=0.0,
|
|
|
|
|
htk_compat: bool=False,
|
|
|
|
|
low_freq: float=20.0,
|
|
|
|
|
num_ceps: int=13,
|
|
|
|
|
min_duration: float=0.0,
|
|
|
|
|
num_mel_bins: int=23,
|
|
|
|
|
n_mfcc: int=13,
|
|
|
|
|
n_mels: int=23,
|
|
|
|
|
preemphasis_coefficient: float=0.97,
|
|
|
|
|
raw_energy: bool=True,
|
|
|
|
|
remove_dc_offset: bool=True,
|
|
|
|
|
round_to_power_of_two: bool=True,
|
|
|
|
|
sample_frequency: float=16000.0,
|
|
|
|
|
sr: int=16000,
|
|
|
|
|
snip_edges: bool=True,
|
|
|
|
|
subtract_mean: bool=False,
|
|
|
|
|
use_energy: bool=False,
|
|
|
|
@ -582,47 +537,47 @@ def mfcc(waveform: Tensor,
|
|
|
|
|
vtln_low: float=100.0,
|
|
|
|
|
vtln_warp: float=1.0,
|
|
|
|
|
window_type: str=POVEY) -> Tensor:
|
|
|
|
|
"""[summary]
|
|
|
|
|
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is
|
|
|
|
|
identical to Kaldi's.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
waveform (Tensor): [description]
|
|
|
|
|
blackman_coeff (float, optional): [description]. Defaults to 0.42.
|
|
|
|
|
cepstral_lifter (float, optional): [description]. Defaults to 22.0.
|
|
|
|
|
channel (int, optional): [description]. Defaults to -1.
|
|
|
|
|
dither (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
energy_floor (float, optional): [description]. Defaults to 1.0.
|
|
|
|
|
frame_length (float, optional): [description]. Defaults to 25.0.
|
|
|
|
|
frame_shift (float, optional): [description]. Defaults to 10.0.
|
|
|
|
|
high_freq (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
htk_compat (bool, optional): [description]. Defaults to False.
|
|
|
|
|
low_freq (float, optional): [description]. Defaults to 20.0.
|
|
|
|
|
num_ceps (int, optional): [description]. Defaults to 13.
|
|
|
|
|
min_duration (float, optional): [description]. Defaults to 0.0.
|
|
|
|
|
num_mel_bins (int, optional): [description]. Defaults to 23.
|
|
|
|
|
preemphasis_coefficient (float, optional): [description]. Defaults to 0.97.
|
|
|
|
|
raw_energy (bool, optional): [description]. Defaults to True.
|
|
|
|
|
remove_dc_offset (bool, optional): [description]. Defaults to True.
|
|
|
|
|
round_to_power_of_two (bool, optional): [description]. Defaults to True.
|
|
|
|
|
sample_frequency (float, optional): [description]. Defaults to 16000.0.
|
|
|
|
|
snip_edges (bool, optional): [description]. Defaults to True.
|
|
|
|
|
subtract_mean (bool, optional): [description]. Defaults to False.
|
|
|
|
|
use_energy (bool, optional): [description]. Defaults to False.
|
|
|
|
|
vtln_high (float, optional): [description]. Defaults to -500.0.
|
|
|
|
|
vtln_low (float, optional): [description]. Defaults to 100.0.
|
|
|
|
|
vtln_warp (float, optional): [description]. Defaults to 1.0.
|
|
|
|
|
window_type (str, optional): [description]. Defaults to POVEY.
|
|
|
|
|
waveform (Tensor): A waveform tensor with shape [C, T].
|
|
|
|
|
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
|
|
|
|
cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
|
|
|
|
|
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
|
|
|
|
dither (float, optional): Dithering constant . Defaults to 0.0.
|
|
|
|
|
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
|
|
|
|
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
|
|
|
|
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
|
|
|
|
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
|
|
|
|
|
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
|
|
|
|
|
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
|
|
|
|
|
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 13.
|
|
|
|
|
n_mels (int, optional): Number of output mel bins. Defaults to 23.
|
|
|
|
|
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
|
|
|
|
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
|
|
|
|
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
|
|
|
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
|
|
|
|
to FFT. Defaults to True.
|
|
|
|
|
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
|
|
|
|
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
|
|
|
|
|
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
|
|
|
|
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
|
|
|
|
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
|
|
|
|
|
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
|
|
|
|
|
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
|
|
|
|
|
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
|
|
|
|
|
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: [description]
|
|
|
|
|
Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc).
|
|
|
|
|
"""
|
|
|
|
|
assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (
|
|
|
|
|
num_ceps, num_mel_bins)
|
|
|
|
|
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
|
|
|
|
|
n_mfcc, n_mels)
|
|
|
|
|
|
|
|
|
|
dtype = waveform.dtype
|
|
|
|
|
|
|
|
|
|
# The mel_energies should not be squared (use_power=True), not have mean subtracted
|
|
|
|
|
# (subtract_mean=False), and use log (use_log_fbank=True).
|
|
|
|
|
# size (m, num_mel_bins + use_energy)
|
|
|
|
|
# (m, n_mels + use_energy)
|
|
|
|
|
feature = fbank(
|
|
|
|
|
waveform=waveform,
|
|
|
|
|
blackman_coeff=blackman_coeff,
|
|
|
|
@ -634,13 +589,12 @@ def mfcc(waveform: Tensor,
|
|
|
|
|
high_freq=high_freq,
|
|
|
|
|
htk_compat=htk_compat,
|
|
|
|
|
low_freq=low_freq,
|
|
|
|
|
min_duration=min_duration,
|
|
|
|
|
num_mel_bins=num_mel_bins,
|
|
|
|
|
n_mels=n_mels,
|
|
|
|
|
preemphasis_coefficient=preemphasis_coefficient,
|
|
|
|
|
raw_energy=raw_energy,
|
|
|
|
|
remove_dc_offset=remove_dc_offset,
|
|
|
|
|
round_to_power_of_two=round_to_power_of_two,
|
|
|
|
|
sample_frequency=sample_frequency,
|
|
|
|
|
sr=sr,
|
|
|
|
|
snip_edges=snip_edges,
|
|
|
|
|
subtract_mean=False,
|
|
|
|
|
use_energy=use_energy,
|
|
|
|
@ -652,34 +606,29 @@ def mfcc(waveform: Tensor,
|
|
|
|
|
window_type=window_type)
|
|
|
|
|
|
|
|
|
|
if use_energy:
|
|
|
|
|
# size (m)
|
|
|
|
|
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
|
|
|
|
|
# offset is 0 if htk_compat==True else 1
|
|
|
|
|
# (m)
|
|
|
|
|
signal_log_energy = feature[:, n_mels if htk_compat else 0]
|
|
|
|
|
mel_offset = int(not htk_compat)
|
|
|
|
|
feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
|
|
|
|
|
feature = feature[:, mel_offset:(n_mels + mel_offset)]
|
|
|
|
|
|
|
|
|
|
# size (num_mel_bins, num_ceps)
|
|
|
|
|
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).astype(dtype=dtype)
|
|
|
|
|
# (n_mels, n_mfcc)
|
|
|
|
|
dct_matrix = _get_dct_matrix(n_mfcc, n_mels).astype(dtype=dtype)
|
|
|
|
|
|
|
|
|
|
# size (m, num_ceps)
|
|
|
|
|
# (m, n_mfcc)
|
|
|
|
|
feature = feature.matmul(dct_matrix)
|
|
|
|
|
|
|
|
|
|
if cepstral_lifter != 0.0:
|
|
|
|
|
# size (1, num_ceps)
|
|
|
|
|
lifter_coeffs = _get_lifter_coeffs(num_ceps,
|
|
|
|
|
cepstral_lifter).unsqueeze(0)
|
|
|
|
|
# (1, n_mfcc)
|
|
|
|
|
lifter_coeffs = _get_lifter_coeffs(n_mfcc, cepstral_lifter).unsqueeze(0)
|
|
|
|
|
feature *= lifter_coeffs.astype(dtype=dtype)
|
|
|
|
|
|
|
|
|
|
# if use_energy then replace the last column for htk_compat == true else first column
|
|
|
|
|
if use_energy:
|
|
|
|
|
feature[:, 0] = signal_log_energy
|
|
|
|
|
|
|
|
|
|
if htk_compat:
|
|
|
|
|
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
|
|
|
|
|
feature = feature[:, 1:] # size (m, num_ceps - 1)
|
|
|
|
|
energy = feature[:, 0].unsqueeze(1) # (m, 1)
|
|
|
|
|
feature = feature[:, 1:] # (m, n_mfcc - 1)
|
|
|
|
|
if not use_energy:
|
|
|
|
|
# scale on C0 (actually removing a scale we previously added that's
|
|
|
|
|
# part of one common definition of the cosine transform.)
|
|
|
|
|
energy *= math.sqrt(2)
|
|
|
|
|
|
|
|
|
|
feature = paddle.concat((feature, energy), axis=1)
|
|
|
|
|