add some annotations, test=doc

pull/1630/head
xiongxinlei 2 years ago
parent 30b5b3cb9e
commit 57c11dcab0

@ -53,7 +53,7 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True):
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# spk_id: the utterance segment's speaker name
# lab_id: the utterance segment's speaker name
for item in tqdm.tqdm(wav_files, total=len(wav_files)):
item = json.loads(item.strip())
audio_id = item['utt'].replace(".wav", "")

@ -30,6 +30,16 @@ logger = Log(__name__).getlog()
@dataclass
class meta_info:
"""the audio meta info in the vector CSVDataset
Args:
utt_id (str): the utterance segment name
duration (float): utterance segment time
wav (str): utterance file path
start (int): start point in the original wav file
stop (int): stop point in the original wav file
lab_id (str): the utterance segment's label id
"""
utt_id: str
duration: float
wav: str
@ -39,18 +49,30 @@ class meta_info:
class CSVDataset(Dataset):
# meta_info = collections.namedtuple(
# 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
def __init__(self, csv_path, spk_id2label_path=None, config=None):
"""Implement the CSV Dataset
Args:
csv_path (str): csv dataset file path
spk_id2label_path (str): the utterance label to integer id map file path
config (CfgNode): yaml config
"""
super().__init__()
self.csv_path = csv_path
self.spk_id2label_path = spk_id2label_path
self.config = config
self.spk_id2label = {}
self.label2spk_id = {}
self.data = self.load_data_csv()
self.spk_id2label = self.load_speaker_to_label()
self.load_speaker_to_label()
def load_data_csv(self):
"""Load the csv dataset content and store them in the data property
the csv dataset's format has six fields,
that is audio_id or utt_id, audio duration, segment start point, segment stop point
and utterance label.
Note in training period, the utterance label must has a map to integer id in spk_id2label_path
"""
data = []
with open(self.csv_path, 'r') as rf:
@ -64,18 +86,28 @@ class CSVDataset(Dataset):
return data
def load_speaker_to_label(self):
"""Load the utterance label map content.
In vector domain, we call the utterance label as speaker label.
The speaker label is real speaker label in speaker verification domain,
and in language identification is language label.
"""
if not self.spk_id2label_path:
logger.warning("No speaker id to label file")
return
spk_id2label = {}
self.spk_id2label = {}
self.label2spk_id = {}
with open(self.spk_id2label_path, 'r') as f:
for line in f.readlines():
spk_id, label = line.strip().split(' ')
spk_id2label[spk_id] = int(label)
return spk_id2label
self.spk_id2label[spk_id] = int(label)
self.label2spk_id[int(label)] = spk_id
def convert_to_record(self, idx: int):
"""convert the dataset sample to training record the CSV Dataset
Args:
idx (int) : the request index in all the dataset
"""
sample = self.data[idx]
record = {}
@ -104,7 +136,14 @@ class CSVDataset(Dataset):
return record
def __getitem__(self, idx):
"""Return the specific index sample
Args:
idx (int) : the request index in all the dataset
"""
return self.convert_to_record(idx)
def __len__(self):
"""Return the dataset length
"""
return len(self.data)

Loading…
Cancel
Save