simplify feature pipeline graph

pull/2212/head
Hui Zhang 2 years ago
parent 8690a00bd8
commit 0d7d87120b

@ -357,10 +357,13 @@ def _get_mel_banks(num_bins: int,
('Bad values in options: vtln-low {} and vtln-high {}, versus ' ('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = paddle.arange(num_bins).unsqueeze(1) bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1)
# left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
# center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
# right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) center_mel = left_mel + mel_freq_delta
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) right_mel = center_mel + mel_freq_delta
if vtln_warp_factor != 1.0: if vtln_warp_factor != 1.0:
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
@ -373,7 +376,7 @@ def _get_mel_banks(num_bins: int,
center_freqs = _inverse_mel_scale(center_mel) # (num_bins) center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
# (1, num_fft_bins) # (1, num_fft_bins)
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
# (num_bins, num_fft_bins) # (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel) up_slope = (mel - left_mel) / (center_mel - left_mel)
@ -472,7 +475,8 @@ def fbank(waveform: Tensor,
# (n_mels, padded_window_size // 2) # (n_mels, padded_window_size // 2)
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq, mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
high_freq, vtln_low, vtln_high, vtln_warp) high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.astype(dtype) # mel_energies = mel_energies.astype(dtype)
assert mel_energies.dtype == dtype
# (n_mels, padded_window_size // 2 + 1) # (n_mels, padded_window_size // 2 + 1)
mel_energies = paddle.nn.functional.pad( mel_energies = paddle.nn.functional.pad(

Loading…
Cancel
Save