change the vector csv.spk_id to csv.label, test=doc

pull/1630/head
xiongxinlei 2 years ago
parent 57c11dcab0
commit acebfad7b7

@ -25,7 +25,7 @@ from yacs.config import CfgNode
from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.utils import get_chunks
from paddlespeech.vector.utils.vector_utils import get_chunks
logger = Log(__name__).getlog()
@ -57,7 +57,9 @@ def get_chunks_list(wav_file: str,
end_sample = int(float(e) * sr)
# currently, all vector csv data format use one representation
# id, duration, wav, start, stop, spk_id
# id, duration, wav, start, stop, label
# in rirs noise, all the label name is 'noise'
# the label is string type and we will convert it to integer type in training
ret.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
"noise"
@ -81,7 +83,7 @@ def generate_csv(wav_files,
split_chunks (bool): audio split flag
"""
logger.info(f'Generating csv: {output_file}')
header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"]
header = ["utt_id", "duration", "wav", "start", "stop", "label"]
csv_lines = []
for item in tqdm.tqdm(wav_files):
csv_lines.extend(

@ -26,7 +26,7 @@ from yacs.config import CfgNode
from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.utils import get_chunks
from paddlespeech.vector.utils.vector_utils import get_chunks
logger = Log(__name__).getlog()
@ -38,28 +38,31 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True):
wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag
split_chunks (bool, optional): audio split flag. Defaults to True.
"""
if not os.path.exists(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file))
csv_lines = []
header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"]
header = ["utt_id", "duration", "wav", "start", "stop", "label"]
# voxceleb meta info for each training utterance segment
# we extract a segment from a utterance to train
# and the segment' period is between start and stop time point in the original wav file
# each field in the meta means as follows:
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# lab_id: the utterance segment's speaker name
# each field in the meta info means as follows:
# utt_id: the utterance segment name, which is uniq in training dataset
# duration: the total utterance time
# wav: utterance file path, which should be absoulute path
# start: start point in the original wav file sample point range
# stop: stop point in the original wav file sample point range
# label: the utterance segment's label name,
# which is speaker name in speaker verification domain
for item in tqdm.tqdm(wav_files, total=len(wav_files)):
item = json.loads(item.strip())
audio_id = item['utt'].replace(".wav", "")
audio_id = item['utt'].replace(".wav",
"") # we remove the wav suffix name
audio_duration = item['feat_shape'][0]
wav_file = item['feat']
spk_id = audio_id.split('-')[0]
label = audio_id.split('-')[
0] # speaker name in speaker verification domain
waveform, sr = load_audio(wav_file)
if split_chunks:
uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
@ -68,14 +71,15 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True):
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
# id, duration, wav, start, stop, spk_id
# id, duration, wav, start, stop, label
# in vector, the label in speaker id
csv_lines.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
spk_id
label
])
else:
csv_lines.append([
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
])
with open(output_file, mode="w") as csv_f:
@ -113,6 +117,9 @@ def get_enroll_test_list(dataset_list, verification_file):
for dataset in dataset_list:
with open(dataset, 'r') as f:
for line in f:
# audio_id may be in enroll and test at the same time
# eg: 1 a.wav a.wav
# the audio a.wav is enroll and test file at the same time
audio_id = json.loads(line.strip())['utt']
if audio_id in enroll_audios:
enroll_files.append(line)
@ -145,17 +152,18 @@ def get_train_dev_list(dataset_list, target_dir, split_ratio):
for dataset in dataset_list:
with open(dataset, 'r') as f:
for line in f:
spk_id = json.loads(line.strip())['utt2spk']
speakers.add(spk_id)
# the label is speaker name
label_name = json.loads(line.strip())['utt2spk']
speakers.add(label_name)
audio_files.append(line.strip())
speakers = sorted(speakers)
logger.info(f"we get {len(speakers)} speakers from all the train dataset")
with open(os.path.join(target_dir, "meta", "spk_id2label.txt"), 'w') as f:
for label, spk_id in enumerate(speakers):
f.write(f'{spk_id} {label}\n')
with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f:
for label_id, label_name in enumerate(speakers):
f.write(f'{label_name} {label_id}\n')
logger.info(
f'we store the speakers to {os.path.join(target_dir, "meta", "spk_id2label.txt")}'
f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}'
)
# the split_ratio is for train dataset
@ -185,7 +193,7 @@ def prepare_data(args, config):
return
# stage 1: prepare the enroll and test csv file
# And we generate the speaker to label file spk_id2label.txt
# And we generate the speaker to label file label2id.txt
logger.info("start to prepare the data csv file")
enroll_files, test_files = get_enroll_test_list(
[args.test], verification_file=config.verification_file)

@ -21,10 +21,10 @@ from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode
from paddleaudio.datasets import VoxCeleb
from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
@ -58,9 +58,8 @@ def main(args, config):
# stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb(
subset='enroll',
target_dir=args.data_dir,
enroll_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/enroll.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
@ -69,15 +68,14 @@ def main(args, config):
enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enroll_dataset,
enroll_loader = DataLoader(enroll_dataset,
batch_sampler=enroll_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)
test_dataset = VoxCeleb(
subset='test',
target_dir=args.data_dir,
test_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/test.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
@ -108,9 +106,9 @@ def main(args, config):
id2embedding = {}
# Run multi times to make embedding normalization more stable.
for i in range(2):
for dl in [enrol_loader, test_loader]:
for dl in [enroll_loader, test_loader]:
logger.info(
f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
f'Loop {[i+1]}: Computing embeddings on {dl.dataset.csv_path} dataset'
)
with paddle.no_grad():
for batch_idx, batch in enumerate(tqdm(dl)):
@ -152,8 +150,8 @@ def main(args, config):
labels = []
enroll_ids = []
test_ids = []
logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
with open(VoxCeleb.veri_test_file, 'r') as f:
logger.info(f"read the trial from {config.verification_file}")
with open(config.verification_file, 'r') as f:
for line in f.readlines():
label, enroll_id, test_id = line.strip().split(' ')
labels.append(int(label))

@ -57,12 +57,10 @@ def main(args, config):
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
spk_id2label_path=os.path.join(args.data_dir,
"vox/meta/spk_id2label.txt"))
label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
dev_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
spk_id2label_path=os.path.join(args.data_dir,
"vox/meta/spk_id2label.txt"))
label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@ -148,7 +146,6 @@ def main(args, config):
train_reader_cost = 0.0
train_feat_cost = 0.0
train_run_cost = 0.0
train_misce_cost = 0.0
reader_start = time.time()
for batch_idx, batch in enumerate(train_loader):

@ -25,7 +25,7 @@ logger = Log(__name__).getlog()
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# lab_id: the utterance segment's label id
# label: the utterance segment's label id
@dataclass
@ -45,24 +45,24 @@ class meta_info:
wav: str
start: int
stop: int
lab_id: str
label: str
class CSVDataset(Dataset):
def __init__(self, csv_path, spk_id2label_path=None, config=None):
def __init__(self, csv_path, label2id_path=None, config=None):
"""Implement the CSV Dataset
Args:
csv_path (str): csv dataset file path
spk_id2label_path (str): the utterance label to integer id map file path
label2id_path (str): the utterance label to integer id map file path
config (CfgNode): yaml config
"""
super().__init__()
self.csv_path = csv_path
self.spk_id2label_path = spk_id2label_path
self.label2id_path = label2id_path
self.config = config
self.spk_id2label = {}
self.label2spk_id = {}
self.id2label = {}
self.label2id = {}
self.data = self.load_data_csv()
self.load_speaker_to_label()
@ -71,7 +71,7 @@ class CSVDataset(Dataset):
the csv dataset's format has six fields,
that is audio_id or utt_id, audio duration, segment start point, segment stop point
and utterance label.
Note in training period, the utterance label must has a map to integer id in spk_id2label_path
Note in training period, the utterance label must has a map to integer id in label2id_path
"""
data = []
@ -91,16 +91,15 @@ class CSVDataset(Dataset):
The speaker label is real speaker label in speaker verification domain,
and in language identification is language label.
"""
if not self.spk_id2label_path:
if not self.label2id_path:
logger.warning("No speaker id to label file")
return
self.spk_id2label = {}
self.label2spk_id = {}
with open(self.spk_id2label_path, 'r') as f:
with open(self.label2id_path, 'r') as f:
for line in f.readlines():
spk_id, label = line.strip().split(' ')
self.spk_id2label[spk_id] = int(label)
self.label2spk_id[int(label)] = spk_id
label_name, label_id = line.strip().split(' ')
self.label2id[label_name] = int(label_id)
self.id2label[int(label_id)] = label_name
def convert_to_record(self, idx: int):
"""convert the dataset sample to training record the CSV Dataset
@ -130,8 +129,8 @@ class CSVDataset(Dataset):
# we only return the waveform as feat
waveform = waveform[start:stop]
record.update({'feat': waveform})
if self.spk_id2label:
record.update({'label': self.spk_id2label[record['lab_id']]})
if self.label2id:
record.update({'label': self.label2id[record['label']]})
return record

@ -17,14 +17,14 @@ def get_chunks(seg_dur, audio_id, audio_duration):
"""Get all chunk segments from a utterance
Args:
seg_dur (float): segment chunk duration
audio_id (str): utterance name
audio_duration (float): utterance duration
seg_dur (float): segment chunk duration, seconds
audio_id (str): utterance name,
audio_duration (float): utterance duration, seconds
Returns:
List: all the chunk segments
"""
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
num_chunks = int(audio_duration / seg_dur) # all in seconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
Loading…
Cancel
Save