diff --git a/paddlespeech/vector/io/batch.py b/paddlespeech/vector/io/batch.py index 92ca990cf..b85563e7a 100644 --- a/paddlespeech/vector/io/batch.py +++ b/paddlespeech/vector/io/batch.py @@ -60,7 +60,7 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs): def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True): - ids = [item['id'] for item in batch] + ids = [item['utt_id'] for item in batch] lengths = np.asarray([item['feat'].shape[1] for item in batch]) feats = list( map(lambda x: pad_right_2d(x, lengths.max()), diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py index 0f87b88ec..e7a8445be 100644 --- a/paddlespeech/vector/io/dataset.py +++ b/paddlespeech/vector/io/dataset.py @@ -16,6 +16,7 @@ from dataclasses import fields from paddle.io import Dataset from paddleaudio import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -48,19 +49,39 @@ class meta_info: label: str +# csv dataset support feature type +# raw: return the pcm data sample point +# melspectrogram: fbank feature +feat_funcs = { + 'raw': None, + 'melspectrogram': melspectrogram, +} + + class CSVDataset(Dataset): - def __init__(self, csv_path, label2id_path=None, config=None): + def __init__(self, + csv_path, + label2id_path=None, + config=None, + random_chunk=True, + feat_type: str="raw", + **kwargs): """Implement the CSV Dataset Args: csv_path (str): csv dataset file path label2id_path (str): the utterance label to integer id map file path config (CfgNode): yaml config + feat_type (str): dataset feature type. if it is raw, it return pcm data. 
+ kwargs : feature type args """ super().__init__() self.csv_path = csv_path self.label2id_path = label2id_path self.config = config + self.random_chunk = random_chunk + self.feat_type = feat_type + self.feat_config = kwargs self.id2label = {} self.label2id = {} self.data = self.load_data_csv() @@ -128,7 +149,15 @@ class CSVDataset(Dataset): # we only return the waveform as feat waveform = waveform[start:stop] - record.update({'feat': waveform}) + + # all available feature types are in feat_funcs + assert self.feat_type in feat_funcs.keys(), \ + f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}" + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) if self.label2id: record.update({'label': self.label2id[record['label']]})