From 0d7d87120b79b71259a2d42c8a33f0e93adf67ee Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 14 Sep 2022 16:44:12 +0000 Subject: [PATCH] simplify feature pipeline graph --- paddlespeech/audio/compliance/kaldi.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/paddlespeech/audio/compliance/kaldi.py b/paddlespeech/audio/compliance/kaldi.py index beb2d86b9..24415058c 100644 --- a/paddlespeech/audio/compliance/kaldi.py +++ b/paddlespeech/audio/compliance/kaldi.py @@ -357,10 +357,13 @@ def _get_mel_banks(num_bins: int, ('Bad values in options: vtln-low {} and vtln-high {}, versus ' 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) - bin = paddle.arange(num_bins).unsqueeze(1) + bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1) + # left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) + # center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) + # right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) - center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) - right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) + center_mel = left_mel + mel_freq_delta + right_mel = center_mel + mel_freq_delta if vtln_warp_factor != 1.0: left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, @@ -373,7 +376,7 @@ def _get_mel_banks(num_bins: int, center_freqs = _inverse_mel_scale(center_mel) # (num_bins) # (1, num_fft_bins) - mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) + mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins, dtype=paddle.float32)).unsqueeze(0) # (num_bins, num_fft_bins) up_slope = (mel - left_mel) / (center_mel - left_mel) @@ -472,7 +475,8 @@ def fbank(waveform: Tensor, # (n_mels, padded_window_size // 2) mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq, high_freq, vtln_low, vtln_high, vtln_warp) - mel_energies = mel_energies.astype(dtype) + # mel_energies = mel_energies.astype(dtype) + assert mel_energies.dtype == dtype # (n_mels, padded_window_size // 2 + 1) mel_energies = paddle.nn.functional.pad(