194 lines
6.6 KiB
194 lines
6.6 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass
from dataclasses import fields

from paddle.io import Dataset

from paddlespeech.audio import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
# Module-level logger; used to warn when no label map file is configured.
logger = Log(__name__).getlog()
# the audio meta info in the vector CSVDataset
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# label: the utterance segment's label id
@dataclass
class meta_info:
    """The audio meta info of one row in the vector CSVDataset.

    Must be a dataclass: CSVDataset builds instances positionally and
    enumerates them with ``dataclasses.fields``.

    Args:
        utt_id (str): the utterance segment name
        duration (float): utterance segment time
        wav (str): utterance file path
        start (int): start sample index in the original wav file
        stop (int): stop sample index (exclusive) in the original wav file
        label (str): the utterance segment's label
    """
    utt_id: str
    duration: float
    wav: str
    start: int
    stop: int
    label: str
# Feature types supported by CSVDataset, mapped to the function that
# computes the feature from a waveform:
#   raw: no transform; the pcm sample points are returned as-is
#   melspectrogram: mel spectrogram (fbank) feature
feat_funcs = {
    'raw': None,
    'melspectrogram': melspectrogram,
}
class CSVDataset(Dataset):
    """Dataset over utterance segments listed in a CSV file.

    Each item is a dict holding every ``meta_info`` field of one CSV row plus
    a ``feat`` entry with the waveform slice or the feature computed from it.
    """

    def __init__(self,
                 csv_path,
                 label2id_path=None,
                 config=None,
                 random_chunk=True,
                 feat_type: str="raw",
                 n_train_snts: int=-1,
                 **kwargs):
        """Implement the CSV Dataset

        Args:
            csv_path (str): csv dataset file path
            label2id_path (str): the utterance label to integer id map file
                path. When empty, labels stay as the raw csv strings.
            config (CfgNode): yaml config. When truthy, ``config.random_chunk``
                enables random chunking of ``config.chunk_duration`` seconds.
            random_chunk (bool): stored as ``self.random_chunk``.
            feat_type (str): dataset feature type, a key of ``feat_funcs``.
                If it is raw, the pcm data is returned.
            n_train_snts (int): keep only the first n_train_snts samples.
                If n_train_snts <= 0, all samples are loaded. Default is -1.
            kwargs: extra keyword args forwarded to the feature function.
        """
        super().__init__()
        self.csv_path = csv_path
        self.label2id_path = label2id_path
        self.config = config
        # NOTE(review): self.random_chunk is not consulted by
        # convert_to_record; chunking is driven by config.random_chunk.
        self.random_chunk = random_chunk
        self.feat_type = feat_type
        self.n_train_snts = n_train_snts
        self.feat_config = kwargs
        self.id2label = {}
        self.label2id = {}
        self.data = self.load_data_csv()
        # Populate label2id / id2label so convert_to_record can map labels.
        self.load_speaker_to_label()

    def load_data_csv(self):
        """Load the csv dataset content.

        The first line of the csv is a header and is skipped. Every other
        non-empty line has six comma-separated fields:
        utt_id, duration, wav, start, stop, label.
        Note in training period, the utterance label must have a map to an
        integer id in label2id_path.

        Returns:
            list: the csv rows as ``meta_info`` records, truncated to the
            first ``n_train_snts`` entries when ``n_train_snts > 0``.
        """
        data = []
        with open(self.csv_path, 'r') as rf:
            next(rf, None)  # skip the header line
            for line in rf:
                line = line.strip()
                if not line:
                    # tolerate blank lines, e.g. a trailing newline at EOF
                    continue
                audio_id, duration, wav, start, stop, spk_id = line.split(',')
                data.append(
                    meta_info(audio_id,
                              float(duration), wav,
                              int(start), int(stop), spk_id))

        if self.n_train_snts > 0:
            data = data[:self.n_train_snts]
        return data

    def load_speaker_to_label(self):
        """Load the utterance label map content into label2id / id2label.

        In vector domain, we call the utterance label as speaker label.
        The speaker label is real speaker label in speaker verification domain,
        and in language identification is language label.
        Each non-empty line of the map file is ``<label_name> <label_id>``.
        """
        if not self.label2id_path:
            logger.warning("No speaker id to label file")
            # nothing to load; without this, open() would get a None/empty path
            return

        with open(self.label2id_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                label_name, label_id = line.split()
                self.label2id[label_name] = int(label_id)
                self.id2label[int(label_id)] = label_name

    def convert_to_record(self, idx: int):
        """Convert one dataset sample into a training record.

        Args:
            idx (int): the request index in all the dataset

        Returns:
            dict: every ``meta_info`` field of the sample plus ``feat``. When a
            label map is loaded, ``label`` is replaced by its integer id.
        """
        sample = self.data[idx]

        # copy every dataclass field of the sample into a plain dict
        record = {
            field.name: getattr(sample, field.name)
            for field in fields(sample)
        }

        waveform, sr = load_audio(record['wav'])

        if self.config and self.config.random_chunk:
            # randomly select a chunk of config.chunk_duration seconds
            num_wav_samples = waveform.shape[0]
            num_chunk_samples = int(self.config.chunk_duration * sr)
            # Valid start offsets are [0, num_wav - num_chunk] inclusive; clamp
            # at 0 so clips no longer than one chunk don't make randint raise
            # (such clips are then returned whole).
            max_start = max(0, num_wav_samples - num_chunk_samples)
            start = random.randint(0, max_start)
            stop = start + num_chunk_samples
        else:
            start = record['start']
            stop = record['stop']

        # we only return the waveform as feat
        waveform = waveform[start:stop]

        # all available feature types are in feat_funcs
        assert self.feat_type in feat_funcs.keys(), \
            f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record.update({'feat': feat})
        if self.label2id:
            record.update({'label': self.label2id[record['label']]})

        return record

    def __getitem__(self, idx):
        """Return the record for the sample at ``idx``.

        Args:
            idx (int): the request index in all the dataset
        """
        return self.convert_to_record(idx)

    def __len__(self):
        """Return the dataset length.

        Returns:
            int: the number of loaded samples
        """
        return len(self.data)