diff --git a/paddlespeech/audio/datasets/dataset.py b/paddlespeech/audio/datasets/dataset.py new file mode 100644 index 000000000..170e91669 --- /dev/null +++ b/paddlespeech/audio/datasets/dataset.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import numpy as np +import paddle + +from ..backends.soundfile_backend import soundfile_load as load_audio +from ..compliance.kaldi import fbank as kaldi_fbank +from ..compliance.kaldi import mfcc as kaldi_mfcc +from ..compliance.librosa import melspectrogram +from ..compliance.librosa import mfcc + +feat_funcs = { + 'raw': None, + 'melspectrogram': melspectrogram, + 'mfcc': mfcc, + 'kaldi_fbank': kaldi_fbank, + 'kaldi_mfcc': kaldi_mfcc, +} + + +class AudioClassificationDataset(paddle.io.Dataset): + """ + Base class of audio classification dataset. + """ + + def __init__(self, + files: List[str], + labels: List[int], + feat_type: str='raw', + sample_rate: int=None, + **kwargs): + """ + Ags: + files (:obj:`List[str]`): A list of absolute path of audio files. + labels (:obj:`List[int]`): Labels of audio files. + feat_type (:obj:`str`, `optional`, defaults to `raw`): + It identifies the feature type that user wants to extract of an audio file. + """ + super(AudioClassificationDataset, self).__init__() + + if feat_type not in feat_funcs.keys(): + raise RuntimeError( + f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" + ) + + self.files = files + self.labels = labels + + self.feat_type = feat_type + self.sample_rate = sample_rate + self.feat_config = kwargs # Pass keyword arguments to customize feature config + + def _get_data(self, input_file: str): + raise NotImplementedError + + def _convert_to_record(self, idx): + file, label = self.files[idx], self.labels[idx] + + if self.sample_rate is None: + waveform, sample_rate = load_audio(file) + else: + waveform, sample_rate = load_audio(file, sr=self.sample_rate) + + feat_func = feat_funcs[self.feat_type] + + record = {} + if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']: + waveform = paddle.to_tensor(waveform).unsqueeze(0) # (C, T) + record['feat'] = feat_func( + waveform=waveform, sr=self.sample_rate, **self.feat_config) + else: + record['feat'] = feat_func( + waveform, sample_rate, + **self.feat_config) if feat_func else waveform + record['label'] = label + return record + + def __getitem__(self, idx): + record = self._convert_to_record(idx) + if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']: + return self.keys[idx], record['feat'], record['label'] + else: + return np.array(record['feat']).transpose(), np.array( + record['label'], dtype=np.int64) + + def __len__(self): + return len(self.files) diff --git a/setup.py b/setup.py index 49d9188a6..25db54064 100644 --- a/setup.py +++ b/setup.py @@ -121,7 +121,7 @@ base = [ "webrtcvad", "yacs>=0.1.8", "zhon", - "sklearn", + "scikit-learn", ] server = ["pattern_singleton", "websockets"]