From 395e8d1bbaa62d91d789beaf2a2a6429f63a6177 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Thu, 10 Mar 2022 14:05:51 +0800 Subject: [PATCH] Add fbank and mfcc unittest. --- paddleaudio/tests/features/test_kaldi.py | 81 ++++++++++++++++++++++ paddleaudio/tests/features/test_librosa.py | 54 +++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 paddleaudio/tests/features/test_kaldi.py diff --git a/paddleaudio/tests/features/test_kaldi.py b/paddleaudio/tests/features/test_kaldi.py new file mode 100644 index 00000000..6e826aaa --- /dev/null +++ b/paddleaudio/tests/features/test_kaldi.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio +from .base import FeatTest + + +class TestKaldi(FeatTest): + def initParmas(self): + self.window_size = 1024 + self.dtype = 'float32' + + def test_window(self): + t_hann_window = torch.hann_window( + self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}')) + t_hamm_window = torch.hamming_window( + self.window_size, + periodic=False, + alpha=0.54, + beta=0.46, + dtype=eval(f'torch.{self.dtype}')) + t_povey_window = torch.hann_window( + self.window_size, periodic=False, + dtype=eval(f'torch.{self.dtype}')).pow(0.85) + + p_hann_window = paddleaudio.functional.window.get_window( + 'hann', + self.window_size, + fftbins=False, + dtype=eval(f'paddle.{self.dtype}')) + p_hamm_window = paddleaudio.functional.window.get_window( + 'hamming', + self.window_size, + fftbins=False, + dtype=eval(f'paddle.{self.dtype}')) + p_povey_window = paddleaudio.functional.window.get_window( + 'hann', + self.window_size, + fftbins=False, + dtype=eval(f'paddle.{self.dtype}')).pow(0.85) + + np.testing.assert_array_almost_equal(t_hann_window, p_hann_window) + np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window) + np.testing.assert_array_almost_equal(t_povey_window, p_povey_window) + + def test_fbank(self): + ta_features = torchaudio.compliance.kaldi.fbank( + torch.from_numpy(self.waveform.astype(self.dtype))) + pa_features = paddleaudio.compliance.kaldi.fbank( + paddle.to_tensor(self.waveform.astype(self.dtype))) + np.testing.assert_array_almost_equal( + ta_features, pa_features, decimal=4) + + def test_mfcc(self): + ta_features = torchaudio.compliance.kaldi.mfcc( + torch.from_numpy(self.waveform.astype(self.dtype))) + pa_features = paddleaudio.compliance.kaldi.mfcc( + paddle.to_tensor(self.waveform.astype(self.dtype))) + np.testing.assert_array_almost_equal( + ta_features, pa_features, decimal=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_librosa.py b/paddleaudio/tests/features/test_librosa.py index 6500eb2b..cf0c98c7 100644 --- a/paddleaudio/tests/features/test_librosa.py +++ b/paddleaudio/tests/features/test_librosa.py @@ -27,9 +27,11 @@ class TestLibrosa(FeatTest): self.n_fft = 512 self.hop_length = 128 self.n_mels = 40 + self.n_mfcc = 20 self.fmin = 0.0 self.window_str = 'hann' self.pad_mode = 'reflect' + self.top_db = 80.0 def test_stft(self): if len(self.waveform.shape) == 2: # (C, T) @@ -222,6 +224,58 @@ class TestLibrosa(FeatTest): np.testing.assert_array_almost_equal( feature_librosa, feature_layer, decimal=4) + def test_mfcc(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.mfcc( + y=self.waveform, + sr=self.sr, + S=None, + n_mfcc=self.n_mfcc, + dct_type=2, + norm='ortho', + lifter=0, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.mfcc( + x=self.waveform, + sr=self.sr, + n_mfcc=self.n_mfcc, + dct_type=2, + norm='ortho', + lifter=0, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin, + top_db=self.top_db) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. + feature_extractor = paddleaudio.features.MFCC( + sr=self.sr, + n_mfcc=self.n_mfcc, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + top_db=self.top_db, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=4) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=4) + if __name__ == '__main__': unittest.main()