change the vector csv.spk_id to csv.label, test=doc

pull/1630/head
xiongxinlei 2 years ago
parent 57c11dcab0
commit acebfad7b7

@ -25,7 +25,7 @@ from yacs.config import CfgNode
from paddleaudio import load as load_audio from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.utils import get_chunks from paddlespeech.vector.utils.vector_utils import get_chunks
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -57,7 +57,9 @@ def get_chunks_list(wav_file: str,
end_sample = int(float(e) * sr) end_sample = int(float(e) * sr)
# currently, all vector csv data format use one representation # currently, all vector csv data format use one representation
# id, duration, wav, start, stop, spk_id # id, duration, wav, start, stop, label
# in rirs noise, all the label names are 'noise'
# the label is string type and we will convert it to integer type in training
ret.append([ ret.append([
chunk, audio_duration, wav_file, start_sample, end_sample, chunk, audio_duration, wav_file, start_sample, end_sample,
"noise" "noise"
@ -81,7 +83,7 @@ def generate_csv(wav_files,
split_chunks (bool): audio split flag split_chunks (bool): audio split flag
""" """
logger.info(f'Generating csv: {output_file}') logger.info(f'Generating csv: {output_file}')
header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"] header = ["utt_id", "duration", "wav", "start", "stop", "label"]
csv_lines = [] csv_lines = []
for item in tqdm.tqdm(wav_files): for item in tqdm.tqdm(wav_files):
csv_lines.extend( csv_lines.extend(

@ -26,7 +26,7 @@ from yacs.config import CfgNode
from paddleaudio import load as load_audio from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.utils import get_chunks from paddlespeech.vector.utils.vector_utils import get_chunks
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -38,28 +38,31 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True):
wav_files (list): all the audio list to prepare the csv file wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file output_file (str): the output csv file
config (CfgNode): yaml configuration content config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag split_chunks (bool, optional): audio split flag. Defaults to True.
""" """
if not os.path.exists(os.path.dirname(output_file)): if not os.path.exists(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file)) os.makedirs(os.path.dirname(output_file))
csv_lines = [] csv_lines = []
header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"] header = ["utt_id", "duration", "wav", "start", "stop", "label"]
# voxceleb meta info for each training utterance segment # voxceleb meta info for each training utterance segment
# we extract a segment from an utterance to train # we extract a segment from an utterance to train
# and the segment's period is between the start and stop time points in the original wav file # and the segment's period is between the start and stop time points in the original wav file
# each field in the meta means as follows: # each field in the meta info means as follows:
# utt_id: the utterance segment name # utt_id: the utterance segment name, which is unique in the training dataset
# duration: utterance segment time # duration: the total utterance time
# wav: utterance file path # wav: utterance file path, which should be an absolute path
# start: start point in the original wav file # start: start point in the original wav file sample point range
# stop: stop point in the original wav file # stop: stop point in the original wav file sample point range
# lab_id: the utterance segment's speaker name # label: the utterance segment's label name,
# which is speaker name in speaker verification domain
for item in tqdm.tqdm(wav_files, total=len(wav_files)): for item in tqdm.tqdm(wav_files, total=len(wav_files)):
item = json.loads(item.strip()) item = json.loads(item.strip())
audio_id = item['utt'].replace(".wav", "") audio_id = item['utt'].replace(".wav",
"") # we remove the wav suffix name
audio_duration = item['feat_shape'][0] audio_duration = item['feat_shape'][0]
wav_file = item['feat'] wav_file = item['feat']
spk_id = audio_id.split('-')[0] label = audio_id.split('-')[
0] # speaker name in speaker verification domain
waveform, sr = load_audio(wav_file) waveform, sr = load_audio(wav_file)
if split_chunks: if split_chunks:
uniq_chunks_list = get_chunks(config.chunk_duration, audio_id, uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
@ -68,14 +71,15 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True):
s, e = chunk.split("_")[-2:] # Timestamps of start and end s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr) start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr) end_sample = int(float(e) * sr)
# id, duration, wav, start, stop, spk_id # id, duration, wav, start, stop, label
# in vector, the label is the speaker id
csv_lines.append([ csv_lines.append([
chunk, audio_duration, wav_file, start_sample, end_sample, chunk, audio_duration, wav_file, start_sample, end_sample,
spk_id label
]) ])
else: else:
csv_lines.append([ csv_lines.append([
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
]) ])
with open(output_file, mode="w") as csv_f: with open(output_file, mode="w") as csv_f:
@ -113,6 +117,9 @@ def get_enroll_test_list(dataset_list, verification_file):
for dataset in dataset_list: for dataset in dataset_list:
with open(dataset, 'r') as f: with open(dataset, 'r') as f:
for line in f: for line in f:
# audio_id may be in enroll and test at the same time
# eg: 1 a.wav a.wav
# the audio a.wav is enroll and test file at the same time
audio_id = json.loads(line.strip())['utt'] audio_id = json.loads(line.strip())['utt']
if audio_id in enroll_audios: if audio_id in enroll_audios:
enroll_files.append(line) enroll_files.append(line)
@ -145,17 +152,18 @@ def get_train_dev_list(dataset_list, target_dir, split_ratio):
for dataset in dataset_list: for dataset in dataset_list:
with open(dataset, 'r') as f: with open(dataset, 'r') as f:
for line in f: for line in f:
spk_id = json.loads(line.strip())['utt2spk'] # the label is speaker name
speakers.add(spk_id) label_name = json.loads(line.strip())['utt2spk']
speakers.add(label_name)
audio_files.append(line.strip()) audio_files.append(line.strip())
speakers = sorted(speakers) speakers = sorted(speakers)
logger.info(f"we get {len(speakers)} speakers from all the train dataset") logger.info(f"we get {len(speakers)} speakers from all the train dataset")
with open(os.path.join(target_dir, "meta", "spk_id2label.txt"), 'w') as f: with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f:
for label, spk_id in enumerate(speakers): for label_id, label_name in enumerate(speakers):
f.write(f'{spk_id} {label}\n') f.write(f'{label_name} {label_id}\n')
logger.info( logger.info(
f'we store the speakers to {os.path.join(target_dir, "meta", "spk_id2label.txt")}' f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}'
) )
# the split_ratio is for train dataset # the split_ratio is for train dataset
@ -185,7 +193,7 @@ def prepare_data(args, config):
return return
# stage 1: prepare the enroll and test csv file # stage 1: prepare the enroll and test csv file
# And we generate the speaker to label file spk_id2label.txt # And we generate the speaker to label file label2id.txt
logger.info("start to prepare the data csv file") logger.info("start to prepare the data csv file")
enroll_files, test_files = get_enroll_test_list( enroll_files, test_files = get_enroll_test_list(
[args.test], verification_file=config.verification_file) [args.test], verification_file=config.verification_file)

@ -21,10 +21,10 @@ from paddle.io import DataLoader
from tqdm import tqdm from tqdm import tqdm
from yacs.config import CfgNode from yacs.config import CfgNode
from paddleaudio.datasets import VoxCeleb
from paddleaudio.metric import compute_eer from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything from paddlespeech.vector.training.seeding import seed_everything
@ -58,9 +58,8 @@ def main(args, config):
# stage4: construct the enroll and test dataloader # stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb( enroll_dataset = CSVDataset(
subset='enroll', os.path.join(args.data_dir, "vox/csv/enroll.csv"),
target_dir=args.data_dir,
feat_type='melspectrogram', feat_type='melspectrogram',
random_chunk=False, random_chunk=False,
n_mels=config.n_mels, n_mels=config.n_mels,
@ -69,15 +68,14 @@ def main(args, config):
enroll_sampler = BatchSampler( enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size, enroll_dataset, batch_size=config.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust. shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enroll_dataset, enroll_loader = DataLoader(enroll_dataset,
batch_sampler=enroll_sampler, batch_sampler=enroll_sampler,
collate_fn=lambda x: batch_feature_normalize( collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False), x, mean_norm=True, std_norm=False),
num_workers=config.num_workers, num_workers=config.num_workers,
return_list=True,) return_list=True,)
test_dataset = VoxCeleb( test_dataset = CSVDataset(
subset='test', os.path.join(args.data_dir, "vox/csv/test.csv"),
target_dir=args.data_dir,
feat_type='melspectrogram', feat_type='melspectrogram',
random_chunk=False, random_chunk=False,
n_mels=config.n_mels, n_mels=config.n_mels,
@ -108,9 +106,9 @@ def main(args, config):
id2embedding = {} id2embedding = {}
# Run multiple times to make embedding normalization more stable. # Run multiple times to make embedding normalization more stable.
for i in range(2): for i in range(2):
for dl in [enrol_loader, test_loader]: for dl in [enroll_loader, test_loader]:
logger.info( logger.info(
f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset' f'Loop {[i+1]}: Computing embeddings on {dl.dataset.csv_path} dataset'
) )
with paddle.no_grad(): with paddle.no_grad():
for batch_idx, batch in enumerate(tqdm(dl)): for batch_idx, batch in enumerate(tqdm(dl)):
@ -152,8 +150,8 @@ def main(args, config):
labels = [] labels = []
enroll_ids = [] enroll_ids = []
test_ids = [] test_ids = []
logger.info(f"read the trial from {VoxCeleb.veri_test_file}") logger.info(f"read the trial from {config.verification_file}")
with open(VoxCeleb.veri_test_file, 'r') as f: with open(config.verification_file, 'r') as f:
for line in f.readlines(): for line in f.readlines():
label, enroll_id, test_id = line.strip().split(' ') label, enroll_id, test_id = line.strip().split(' ')
labels.append(int(label)) labels.append(int(label))

@ -57,12 +57,10 @@ def main(args, config):
# note: some commands must run on rank==0, so we will refactor the data preparation code # note: some commands must run on rank==0, so we will refactor the data preparation code
train_dataset = CSVDataset( train_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"), csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
spk_id2label_path=os.path.join(args.data_dir, label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
"vox/meta/spk_id2label.txt"))
dev_dataset = CSVDataset( dev_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"), csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
spk_id2label_path=os.path.join(args.data_dir, label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
"vox/meta/spk_id2label.txt"))
if config.augment: if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@ -148,7 +146,6 @@ def main(args, config):
train_reader_cost = 0.0 train_reader_cost = 0.0
train_feat_cost = 0.0 train_feat_cost = 0.0
train_run_cost = 0.0 train_run_cost = 0.0
train_misce_cost = 0.0
reader_start = time.time() reader_start = time.time()
for batch_idx, batch in enumerate(train_loader): for batch_idx, batch in enumerate(train_loader):

@ -25,7 +25,7 @@ logger = Log(__name__).getlog()
# wav: utterance file path # wav: utterance file path
# start: start point in the original wav file # start: start point in the original wav file
# stop: stop point in the original wav file # stop: stop point in the original wav file
# lab_id: the utterance segment's label id # label: the utterance segment's label id
@dataclass @dataclass
@ -45,24 +45,24 @@ class meta_info:
wav: str wav: str
start: int start: int
stop: int stop: int
lab_id: str label: str
class CSVDataset(Dataset): class CSVDataset(Dataset):
def __init__(self, csv_path, spk_id2label_path=None, config=None): def __init__(self, csv_path, label2id_path=None, config=None):
"""Implement the CSV Dataset """Implement the CSV Dataset
Args: Args:
csv_path (str): csv dataset file path csv_path (str): csv dataset file path
spk_id2label_path (str): the utterance label to integer id map file path label2id_path (str): the utterance label to integer id map file path
config (CfgNode): yaml config config (CfgNode): yaml config
""" """
super().__init__() super().__init__()
self.csv_path = csv_path self.csv_path = csv_path
self.spk_id2label_path = spk_id2label_path self.label2id_path = label2id_path
self.config = config self.config = config
self.spk_id2label = {} self.id2label = {}
self.label2spk_id = {} self.label2id = {}
self.data = self.load_data_csv() self.data = self.load_data_csv()
self.load_speaker_to_label() self.load_speaker_to_label()
@ -71,7 +71,7 @@ class CSVDataset(Dataset):
the csv dataset's format has six fields, the csv dataset's format has six fields,
that is audio_id or utt_id, audio duration, segment start point, segment stop point that is audio_id or utt_id, audio duration, segment start point, segment stop point
and utterance label. and utterance label.
Note in training period, the utterance label must have a map to integer id in spk_id2label_path Note in training period, the utterance label must have a map to integer id in label2id_path
""" """
data = [] data = []
@ -91,16 +91,15 @@ class CSVDataset(Dataset):
The speaker label is real speaker label in speaker verification domain, The speaker label is real speaker label in speaker verification domain,
and in language identification is language label. and in language identification is language label.
""" """
if not self.spk_id2label_path: if not self.label2id_path:
logger.warning("No speaker id to label file") logger.warning("No speaker id to label file")
return return
self.spk_id2label = {}
self.label2spk_id = {} with open(self.label2id_path, 'r') as f:
with open(self.spk_id2label_path, 'r') as f:
for line in f.readlines(): for line in f.readlines():
spk_id, label = line.strip().split(' ') label_name, label_id = line.strip().split(' ')
self.spk_id2label[spk_id] = int(label) self.label2id[label_name] = int(label_id)
self.label2spk_id[int(label)] = spk_id self.id2label[int(label_id)] = label_name
def convert_to_record(self, idx: int): def convert_to_record(self, idx: int):
"""convert the dataset sample to training record the CSV Dataset """convert the dataset sample to training record the CSV Dataset
@ -130,8 +129,8 @@ class CSVDataset(Dataset):
# we only return the waveform as feat # we only return the waveform as feat
waveform = waveform[start:stop] waveform = waveform[start:stop]
record.update({'feat': waveform}) record.update({'feat': waveform})
if self.spk_id2label: if self.label2id:
record.update({'label': self.spk_id2label[record['lab_id']]}) record.update({'label': self.label2id[record['label']]})
return record return record

@ -17,14 +17,14 @@ def get_chunks(seg_dur, audio_id, audio_duration):
"""Get all chunk segments from a utterance """Get all chunk segments from a utterance
Args: Args:
seg_dur (float): segment chunk duration seg_dur (float): segment chunk duration, seconds
audio_id (str): utterance name audio_id (str): utterance name,
audio_duration (float): utterance duration audio_duration (float): utterance duration, seconds
Returns: Returns:
List: all the chunk segments List: all the chunk segments
""" """
num_chunks = int(audio_duration / seg_dur) # all in milliseconds num_chunks = int(audio_duration / seg_dur) # all in seconds
chunk_lst = [ chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur) audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks) for i in range(num_chunks)
Loading…
Cancel
Save