From acebfad7b7fcd007c8e6e27e31d372490e226fea Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Sun, 3 Apr 2022 15:30:55 +0800
Subject: [PATCH] change the vector csv.spk_id to csv.label, test=doc

---
 .../make_rirs_noise_csv_dataset_from_json.py |  8 +--
 .../local/make_vox_csv_dataset_from_json.py  | 52 +++++++++++--------
 paddlespeech/vector/exps/ecapa_tdnn/test.py  | 24 ++++-----
 paddlespeech/vector/exps/ecapa_tdnn/train.py |  7 +--
 paddlespeech/vector/io/dataset.py            | 33 ++++++------
 .../utils/{utils.py => vector_utils.py}      |  8 +--
 6 files changed, 68 insertions(+), 64 deletions(-)
 rename paddlespeech/vector/utils/{utils.py => vector_utils.py} (80%)

diff --git a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
index 26015aed..b25a9d49 100644
--- a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
+++ b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
@@ -25,7 +25,7 @@ from yacs.config import CfgNode
 from paddleaudio import load as load_audio
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.utils.utils import get_chunks
+from paddlespeech.vector.utils.vector_utils import get_chunks

 logger = Log(__name__).getlog()

@@ -57,7 +57,9 @@ def get_chunks_list(wav_file: str,
         end_sample = int(float(e) * sr)

         # currently, all vector csv data format use one representation
-        # id, duration, wav, start, stop, spk_id
+        # id, duration, wav, start, stop, label
+        # in the rirs noise dataset, every label name is 'noise'
+        # the label is a string and will be converted to an integer id in training
         ret.append([
             chunk, audio_duration, wav_file, start_sample, end_sample,
             "noise"
@@ -81,7 +83,7 @@ def generate_csv(wav_files,
         split_chunks (bool): audio split flag
     """
     logger.info(f'Generating csv: {output_file}')
-    header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"]
+    header = ["utt_id", "duration", "wav", "start", "stop", "label"]
     csv_lines = []
     for item in tqdm.tqdm(wav_files):
         csv_lines.extend(
diff --git a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
index 6c33aba5..4e64c306 100644
--- a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
+++ b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
@@ -26,7 +26,7 @@ from yacs.config import CfgNode
 from paddleaudio import load as load_audio
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.utils.utils import get_chunks
+from paddlespeech.vector.utils.vector_utils import get_chunks

 logger = Log(__name__).getlog()

@@ -38,28 +38,31 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True):
         wav_files (list): all the audio list to prepare the csv file
         output_file (str): the output csv file
         config (CfgNode): yaml configuration content
-        split_chunks (bool): audio split flag
+        split_chunks (bool, optional): audio split flag. Defaults to True.
""" if not os.path.exists(os.path.dirname(output_file)): os.makedirs(os.path.dirname(output_file)) csv_lines = [] - header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"] + header = ["utt_id", "duration", "wav", "start", "stop", "label"] # voxceleb meta info for each training utterance segment # we extract a segment from a utterance to train # and the segment' period is between start and stop time point in the original wav file - # each field in the meta means as follows: - # utt_id: the utterance segment name - # duration: utterance segment time - # wav: utterance file path - # start: start point in the original wav file - # stop: stop point in the original wav file - # lab_id: the utterance segment's speaker name + # each field in the meta info means as follows: + # utt_id: the utterance segment name, which is uniq in training dataset + # duration: the total utterance time + # wav: utterance file path, which should be absoulute path + # start: start point in the original wav file sample point range + # stop: stop point in the original wav file sample point range + # label: the utterance segment's label name, + # which is speaker name in speaker verification domain for item in tqdm.tqdm(wav_files, total=len(wav_files)): item = json.loads(item.strip()) - audio_id = item['utt'].replace(".wav", "") + audio_id = item['utt'].replace(".wav", + "") # we remove the wav suffix name audio_duration = item['feat_shape'][0] wav_file = item['feat'] - spk_id = audio_id.split('-')[0] + label = audio_id.split('-')[ + 0] # speaker name in speaker verification domain waveform, sr = load_audio(wav_file) if split_chunks: uniq_chunks_list = get_chunks(config.chunk_duration, audio_id, @@ -68,14 +71,15 @@ def prepare_csv(wav_files, output_file, config, split_chunks=True): s, e = chunk.split("_")[-2:] # Timestamps of start and end start_sample = int(float(s) * sr) end_sample = int(float(e) * sr) - # id, duration, wav, start, stop, spk_id + # id, duration, wav, start, stop, label + # in vector, the label in speaker id csv_lines.append([ chunk, audio_duration, wav_file, start_sample, end_sample, - spk_id + label ]) else: csv_lines.append([ - audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id + audio_id, audio_duration, wav_file, 0, waveform.shape[0], label ]) with open(output_file, mode="w") as csv_f: @@ -113,6 +117,9 @@ def get_enroll_test_list(dataset_list, verification_file): for dataset in dataset_list: with open(dataset, 'r') as f: for line in f: + # audio_id may be in enroll and test at the same time + # eg: 1 a.wav a.wav + # the audio a.wav is enroll and test file at the same time audio_id = json.loads(line.strip())['utt'] if audio_id in enroll_audios: enroll_files.append(line) @@ -145,17 +152,18 @@ def get_train_dev_list(dataset_list, target_dir, split_ratio): for dataset in dataset_list: with open(dataset, 'r') as f: for line in f: - spk_id = json.loads(line.strip())['utt2spk'] - speakers.add(spk_id) + # the label is speaker name + label_name = json.loads(line.strip())['utt2spk'] + speakers.add(label_name) audio_files.append(line.strip()) speakers = sorted(speakers) logger.info(f"we get {len(speakers)} speakers from all the train dataset") - with open(os.path.join(target_dir, "meta", "spk_id2label.txt"), 'w') as f: - for label, spk_id in enumerate(speakers): - f.write(f'{spk_id} {label}\n') + with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f: + for label_id, label_name in enumerate(speakers): + f.write(f'{label_name} {label_id}\n') logger.info( - f'we store the 
speakers to {os.path.join(target_dir, "meta", "spk_id2label.txt")}' + f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}' ) # the split_ratio is for train dataset @@ -185,7 +193,7 @@ def prepare_data(args, config): return # stage 1: prepare the enroll and test csv file - # And we generate the speaker to label file spk_id2label.txt + # And we generate the speaker to label file label2id.txt logger.info("start to prepare the data csv file") enroll_files, test_files = get_enroll_test_list( [args.test], verification_file=config.verification_file) diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py index d0de6dc5..4d78cfd3 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/test.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py @@ -21,10 +21,10 @@ from paddle.io import DataLoader from tqdm import tqdm from yacs.config import CfgNode -from paddleaudio.datasets import VoxCeleb from paddleaudio.metric import compute_eer from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.io.dataset import CSVDataset from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.training.seeding import seed_everything @@ -58,9 +58,8 @@ def main(args, config): # stage4: construct the enroll and test dataloader - enroll_dataset = VoxCeleb( - subset='enroll', - target_dir=args.data_dir, + enroll_dataset = CSVDataset( + os.path.join(args.data_dir, "vox/csv/enroll.csv"), feat_type='melspectrogram', random_chunk=False, n_mels=config.n_mels, @@ -69,15 +68,14 @@ def main(args, config): enroll_sampler = BatchSampler( enroll_dataset, batch_size=config.batch_size, shuffle=True) # Shuffle to make embedding normalization more robust. - enrol_loader = DataLoader(enroll_dataset, + enroll_loader = DataLoader(enroll_dataset, batch_sampler=enroll_sampler, collate_fn=lambda x: batch_feature_normalize( - x, mean_norm=True, std_norm=False), + x, mean_norm=True, std_norm=False), num_workers=config.num_workers, return_list=True,) - test_dataset = VoxCeleb( - subset='test', - target_dir=args.data_dir, + test_dataset = CSVDataset( + os.path.join(args.data_dir, "vox/csv/test.csv"), feat_type='melspectrogram', random_chunk=False, n_mels=config.n_mels, @@ -108,9 +106,9 @@ def main(args, config): id2embedding = {} # Run multi times to make embedding normalization more stable. 
     for i in range(2):
-        for dl in [enrol_loader, test_loader]:
+        for dl in [enroll_loader, test_loader]:
             logger.info(
-                f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
+                f'Loop {i+1}: Computing embeddings on {dl.dataset.csv_path} dataset'
             )
             with paddle.no_grad():
                 for batch_idx, batch in enumerate(tqdm(dl)):
@@ -152,8 +150,8 @@ def main(args, config):
     labels = []
     enroll_ids = []
     test_ids = []
-    logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
-    with open(VoxCeleb.veri_test_file, 'r') as f:
+    logger.info(f"read the trial from {config.verification_file}")
+    with open(config.verification_file, 'r') as f:
         for line in f.readlines():
             label, enroll_id, test_id = line.strip().split(' ')
             labels.append(int(label))
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
index d6919d23..7ff6cb69 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -57,12 +57,10 @@ def main(args, config):
     # note: some cmd must do in rank==0, so wo will refactor the data prepare code
     train_dataset = CSVDataset(
         csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
-        spk_id2label_path=os.path.join(args.data_dir,
-                                       "vox/meta/spk_id2label.txt"))
+        label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))

     dev_dataset = CSVDataset(
         csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
-        spk_id2label_path=os.path.join(args.data_dir,
-                                       "vox/meta/spk_id2label.txt"))
+        label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))

     if config.augment:
         augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@@ -148,7 +146,6 @@ def main(args, config):
         train_reader_cost = 0.0
         train_feat_cost = 0.0
         train_run_cost = 0.0
-        train_misce_cost = 0.0

         reader_start = time.time()
         for batch_idx, batch in enumerate(train_loader):
diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py
index e70c8d3c..0f87b88e 100644
--- a/paddlespeech/vector/io/dataset.py
+++ b/paddlespeech/vector/io/dataset.py
@@ -25,7 +25,7 @@ logger = Log(__name__).getlog()
 # wav: utterance file path
 # start: start point in the original wav file
 # stop: stop point in the original wav file
-# lab_id: the utterance segment's label id
+# label: the utterance segment's label id

 @dataclass
@@ -45,24 +45,24 @@ class meta_info:
     wav: str
     start: int
     stop: int
-    lab_id: str
+    label: str

 class CSVDataset(Dataset):
-    def __init__(self, csv_path, spk_id2label_path=None, config=None):
+    def __init__(self, csv_path, label2id_path=None, config=None):
         """Implement the CSV Dataset

         Args:
             csv_path (str): csv dataset file path
-            spk_id2label_path (str): the utterance label to integer id map file path
+            label2id_path (str): the utterance label to integer id map file path
             config (CfgNode): yaml config
         """
         super().__init__()
         self.csv_path = csv_path
-        self.spk_id2label_path = spk_id2label_path
+        self.label2id_path = label2id_path
         self.config = config
-        self.spk_id2label = {}
-        self.label2spk_id = {}
+        self.id2label = {}
+        self.label2id = {}
         self.data = self.load_data_csv()
         self.load_speaker_to_label()

@@ -71,7 +71,7 @@ class CSVDataset(Dataset):
         the csv dataset's format has six fields,
         that is audio_id or utt_id, audio duration, segment start point, segment stop point
         and utterance label.
-        Note in training period, the utterance label must has a map to integer id in spk_id2label_path
+        Note in training period, the utterance label must have a mapping to an integer id in label2id_path
         """
         data = []
@@ -91,16 +91,15 @@ class CSVDataset(Dataset):
         The speaker label is real speaker label in speaker verification domain,
         and in language identification is language label.
         """
-        if not self.spk_id2label_path:
+        if not self.label2id_path:
             logger.warning("No speaker id to label file")
             return
-        self.spk_id2label = {}
-        self.label2spk_id = {}
-        with open(self.spk_id2label_path, 'r') as f:
+
+        with open(self.label2id_path, 'r') as f:
             for line in f.readlines():
-                spk_id, label = line.strip().split(' ')
-                self.spk_id2label[spk_id] = int(label)
-                self.label2spk_id[int(label)] = spk_id
+                label_name, label_id = line.strip().split(' ')
+                self.label2id[label_name] = int(label_id)
+                self.id2label[int(label_id)] = label_name

     def convert_to_record(self, idx: int):
         """convert the dataset sample to training record the CSV Dataset
@@ -130,8 +129,8 @@
             # we only return the waveform as feat
             waveform = waveform[start:stop]
             record.update({'feat': waveform})
-        if self.spk_id2label:
-            record.update({'label': self.spk_id2label[record['lab_id']]})
+        if self.label2id:
+            record.update({'label': self.label2id[record['label']]})
         return record

diff --git a/paddlespeech/vector/utils/utils.py b/paddlespeech/vector/utils/vector_utils.py
similarity index 80%
rename from paddlespeech/vector/utils/utils.py
rename to paddlespeech/vector/utils/vector_utils.py
index 892b19c7..46de7ffa 100644
--- a/paddlespeech/vector/utils/utils.py
+++ b/paddlespeech/vector/utils/vector_utils.py
@@ -17,14 +17,14 @@ def get_chunks(seg_dur, audio_id, audio_duration):
     """Get all chunk segments from a utterance

     Args:
-        seg_dur (float): segment chunk duration
-        audio_id (str): utterance name
-        audio_duration (float): utterance duration
+        seg_dur (float): segment chunk duration, in seconds
+        audio_id (str): utterance name
+        audio_duration (float): utterance duration, in seconds

     Returns:
         List: all the chunk segments
     """
-    num_chunks = int(audio_duration / seg_dur)  # all in milliseconds
+    num_chunks = int(audio_duration / seg_dur)  # all in seconds
     chunk_lst = [
         audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
         for i in range(num_chunks)
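
For reference, a minimal standalone sketch of the conventions this patch settles on: get_chunks (mirroring the renamed paddlespeech/vector/utils/vector_utils.py) names each chunk "<utt_id>_<start_sec>_<stop_sec>", and every csv row ends with a string label that label2id.txt later maps to an integer class id. The utterance id, wav path, and 16 kHz sample rate below are made-up example values, not data from the patch.

def get_chunks(seg_dur, audio_id, audio_duration):
    # split an utterance into fixed-length chunk ids; all durations in seconds
    num_chunks = int(audio_duration / seg_dur)
    return [
        f"{audio_id}_{i * seg_dur}_{i * seg_dur + seg_dur}"
        for i in range(num_chunks)
    ]

# hypothetical VoxCeleb utterance: speaker id10001, 8.5 s long, 16 kHz audio
sr = 16000
chunks = get_chunks(3.0, "id10001-1zcIwhmdeo4-00001", 8.5)
print(chunks)
# ['id10001-1zcIwhmdeo4-00001_0.0_3.0', 'id10001-1zcIwhmdeo4-00001_3.0_6.0']

# each chunk becomes one csv row: utt_id, duration, wav, start, stop, label
for chunk in chunks:
    s, e = chunk.split("_")[-2:]  # timestamps encoded in the chunk id
    label = chunk.split("-")[0]   # the speaker name, e.g. 'id10001'
    row = [chunk, 8.5, "/data/wav/00001.wav",
           int(float(s) * sr), int(float(e) * sr), label]
    print(row)
# ['id10001-1zcIwhmdeo4-00001_0.0_3.0', 8.5, '/data/wav/00001.wav', 0, 48000, 'id10001']
# at training time, CSVDataset.load_speaker_to_label reads label2id.txt
# (one "id10001 0" pair per line) and maps the string label to its integer id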