test.py update the CSVDataset, test=doc

pull/1630/head
xiongxinlei 3 years ago
parent acebfad7b7
commit ebfe3e6b13

@ -60,7 +60,7 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
ids = [item['id'] for item in batch]
ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[1] for item in batch])
feats = list(
map(lambda x: pad_right_2d(x, lengths.max()),

@ -16,6 +16,7 @@ from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
@ -48,19 +49,39 @@ class meta_info:
label: str
# csv dataset support feature type
# raw: return the pcm data sample point
# melspectrogram: fbank feature
feat_funcs = {
'raw': None,
'melspectrogram': melspectrogram,
}
class CSVDataset(Dataset):
def __init__(self, csv_path, label2id_path=None, config=None):
def __init__(self,
csv_path,
label2id_path=None,
config=None,
random_chunk=True,
feat_type: str="raw",
**kwargs):
"""Implement the CSV Dataset
Args:
csv_path (str): csv dataset file path
label2id_path (str): the utterance label to integer id map file path
config (CfgNode): yaml config
feat_type (str): dataset feature type. if it is raw, it return pcm data.
kwargs : feature type args
"""
super().__init__()
self.csv_path = csv_path
self.label2id_path = label2id_path
self.config = config
self.random_chunk = random_chunk
self.feat_type = feat_type
self.feat_config = kwargs
self.id2label = {}
self.label2id = {}
self.data = self.load_data_csv()
@ -128,7 +149,15 @@ class CSVDataset(Dataset):
# we only return the waveform as feat
waveform = waveform[start:stop]
record.update({'feat': waveform})
# all availabel feature type is in feat_funcs
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
if self.label2id:
record.update({'label': self.label2id[record['label']]})

Loading…
Cancel
Save