|
|
|
@ -16,6 +16,7 @@ from dataclasses import fields
|
|
|
|
|
from paddle.io import Dataset
|
|
|
|
|
|
|
|
|
|
from paddleaudio import load as load_audio
|
|
|
|
|
from paddleaudio.compliance.librosa import melspectrogram
|
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
|
# Module-level logger obtained from the project's Log wrapper, keyed by this module's name.
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
@ -48,19 +49,39 @@ class meta_info:
|
|
|
|
|
label: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Feature types supported by the CSV dataset, mapped to their extractor:
#   raw: no extractor; the pcm sample points are returned as-is
#   melspectrogram: fbank feature computed by ``melspectrogram``
feat_funcs = dict(raw=None, melspectrogram=melspectrogram)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CSVDataset(Dataset):
|
|
|
|
|
def __init__(self, csv_path, label2id_path=None, config=None):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
csv_path,
|
|
|
|
|
label2id_path=None,
|
|
|
|
|
config=None,
|
|
|
|
|
random_chunk=True,
|
|
|
|
|
feat_type: str="raw",
|
|
|
|
|
**kwargs):
|
|
|
|
|
"""Implement the CSV Dataset
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
csv_path (str): csv dataset file path
|
|
|
|
|
label2id_path (str): the utterance label to integer id map file path
|
|
|
|
|
config (CfgNode): yaml config
|
|
|
|
|
feat_type (str): dataset feature type. If it is raw, the pcm data is returned.
|
|
|
|
|
kwargs : feature type args
|
|
|
|
|
"""
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.csv_path = csv_path
|
|
|
|
|
self.label2id_path = label2id_path
|
|
|
|
|
self.config = config
|
|
|
|
|
self.random_chunk = random_chunk
|
|
|
|
|
self.feat_type = feat_type
|
|
|
|
|
self.feat_config = kwargs
|
|
|
|
|
self.id2label = {}
|
|
|
|
|
self.label2id = {}
|
|
|
|
|
self.data = self.load_data_csv()
|
|
|
|
@ -128,7 +149,15 @@ class CSVDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
# we only return the waveform as feat
|
|
|
|
|
waveform = waveform[start:stop]
|
|
|
|
|
record.update({'feat': waveform})
|
|
|
|
|
|
|
|
|
|
# all available feature types are listed in feat_funcs
|
|
|
|
|
assert self.feat_type in feat_funcs.keys(), \
|
|
|
|
|
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
|
|
|
|
|
feat_func = feat_funcs[self.feat_type]
|
|
|
|
|
feat = feat_func(
|
|
|
|
|
waveform, sr=sr, **self.feat_config) if feat_func else waveform
|
|
|
|
|
|
|
|
|
|
record.update({'feat': feat})
|
|
|
|
|
if self.label2id:
|
|
|
|
|
record.update({'label': self.label2id[record['label']]})
|
|
|
|
|
|
|
|
|
|