diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml index 4715c5a3..3e3a1307 100644 --- a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml @@ -4,7 +4,7 @@ augment: True batch_size: 32 num_workers: 2 -num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 +num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 shuffle: True skip_prep: False split_ratio: 0.9 @@ -42,8 +42,16 @@ epochs: 10 save_interval: 10 log_interval: 10 learning_rate: 1e-8 +max_lr: 1e-3 +step_size: 140000 +########################################### +# loss # +########################################### +margin: 0.2 +scale: 30 + ########################################### # Testing # ########################################### diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml index 5ad5ea28..5925e573 100644 --- a/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml @@ -2,7 +2,7 @@ # Data # ########################################### augment: True -batch_size: 16 +batch_size: 32 num_workers: 2 num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 shuffle: True @@ -42,7 +42,14 @@ epochs: 100 save_interval: 10 log_interval: 10 learning_rate: 1e-8 +max_lr: 1e-3 +step_size: 140000 +########################################### +# loss # +########################################### +margin: 0.2 +scale: 30 ########################################### # Testing # diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py index 70b1521e..1b38075d 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/test.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py @@ -38,10 +38,10 @@ def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config, """compute the dataset embeddings Args: - data_loader (_type_): _description_ - model (_type_): _description_ - mean_var_norm_emb (_type_): _description_ - config (_type_): _description_ + data_loader (paddle.io.Dataloader): the dataset loader to be compute the embedding + model (paddle.nn.Layer): the speaker verification model + mean_var_norm_emb : compute the embedding mean and std norm + config (yacs.config.CfgNode): the yaml config """ logger.info( f'Computing embeddings on {data_loader.dataset.csv_path} dataset') @@ -65,6 +65,17 @@ def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config, def compute_verification_scores(id2embedding, train_cohort, config): + """Compute the verification trial scores + + Args: + id2embedding (dict): the utterance embedding + train_cohort (paddle.tensor): the cohort dataset embedding + config (yacs.config.CfgNode): the yaml config + + Returns: + the scores and the trial labels, + 1 refers the target and 0 refers the nontarget in labels + """ labels = [] enroll_ids = [] test_ids = [] @@ -119,20 +130,32 @@ def compute_verification_scores(id2embedding, train_cohort, config): def main(args, config): + """The main process for test the speaker verification model + + Args: + args (argparse.Namespace): the command line args namespace + config (yacs.config.CfgNode): the yaml config + """ + # stage0: set the training device, cpu or gpu + # if set the gpu, paddlespeech will select a gpu according the env CUDA_VISIBLE_DEVICES paddle.set_device(args.device) - # set the random seed, it is a must for multiprocess training + # set the random seed, it is the necessary measures for multiprocess training seed_everything(config.seed) # stage1: build the dnn backbone model network + # we will extract the audio embedding from the backbone model ecapa_tdnn = EcapaTdnn(**config.model) # stage2: build the speaker verification eval instance with backbone model + # because the checkpoint dict name has the SpeakerIdetification prefix + # so we need to create the SpeakerIdetification instance + # but we acutally use the backbone model to extact the audio embedding model = SpeakerIdetification( backbone=ecapa_tdnn, num_class=config.num_speakers) # stage3: load the pre-trained model - # we get the last model from the epoch and save_interval + # generally, we get the last model from the epoch args.load_checkpoint = os.path.abspath( os.path.expanduser(args.load_checkpoint)) @@ -143,7 +166,8 @@ def main(args, config): logger.info(f'Checkpoint loaded from {args.load_checkpoint}') # stage4: construct the enroll and test dataloader - + # Now, wo think the enroll dataset is in the {args.data_dir}/vox/csv/enroll.csv, + # and the test dataset is in the {args.data_dir}/vox/csv/test.csv enroll_dataset = CSVDataset( os.path.join(args.data_dir, "vox/csv/enroll.csv"), feat_type='melspectrogram', @@ -152,14 +176,14 @@ def main(args, config): window_size=config.window_size, hop_length=config.hop_size) enroll_sampler = BatchSampler( - enroll_dataset, batch_size=config.batch_size, - shuffle=False) # Shuffle to make embedding normalization more robust. + enroll_dataset, batch_size=config.batch_size, shuffle=False) enroll_loader = DataLoader(enroll_dataset, batch_sampler=enroll_sampler, collate_fn=lambda x: batch_feature_normalize( x, mean_norm=True, std_norm=False), num_workers=config.num_workers, return_list=True,) + test_dataset = CSVDataset( os.path.join(args.data_dir, "vox/csv/test.csv"), feat_type='melspectrogram', @@ -167,7 +191,6 @@ def main(args, config): n_mels=config.n_mels, window_size=config.window_size, hop_length=config.hop_size) - test_sampler = BatchSampler( test_dataset, batch_size=config.batch_size, shuffle=False) test_loader = DataLoader(test_dataset, @@ -180,16 +203,17 @@ def main(args, config): model.eval() # stage6: global embedding norm to imporve the performance + # and we create the InputNormalization instance to process the embedding mean and std norm logger.info(f"global embedding norm: {config.global_embedding_norm}") - - # stage7: Compute embeddings of audios in enrol and test dataset from model. - if config.global_embedding_norm: mean_var_norm_emb = InputNormalization( norm_type="global", mean_norm=config.embedding_mean_norm, std_norm=config.embedding_std_norm) + # stage 7: score norm need the imposters dataset + # we select the train dataset as the idea imposters dataset + # and we select the config.n_train_snts utterance to as the final imposters dataset if "score_norm" in config: logger.info(f"we will do score norm: {config.score_norm}") train_dataset = CSVDataset( @@ -209,6 +233,7 @@ def main(args, config): num_workers=config.num_workers, return_list=True,) + # stage 8: Compute embeddings of audios in enrol and test dataset from model. id2embedding = {} # Run multi times to make embedding normalization more stable. logger.info("First loop for enroll and test dataset") @@ -225,7 +250,7 @@ def main(args, config): mean_var_norm_emb.save( os.path.join(args.load_checkpoint, "mean_var_norm_emb")) - # stage 8: Compute cosine scores. + # stage 9: Compute cosine scores. train_cohort = None if "score_norm" in config: train_embeddings = {} @@ -234,11 +259,11 @@ def main(args, config): train_embeddings) train_cohort = paddle.stack(list(train_embeddings.values())) - # compute the scores + # stage 10: compute the scores scores, labels = compute_verification_scores(id2embedding, train_cohort, config) - # compute the EER and threshold + # stage 11: compute the EER and threshold scores = paddle.to_tensor(scores) EER, threshold = compute_eer(np.asarray(labels), scores.numpy()) logger.info( diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py index b777dae8..8855689d 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/train.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py @@ -42,6 +42,12 @@ logger = Log(__name__).getlog() def main(args, config): + """The main process for test the speaker verification model + + Args: + args (argparse.Namespace): the command line args namespace + config (yacs.config.CfgNode): the yaml config + """ # stage0: set the training device, cpu or gpu paddle.set_device(args.device) @@ -49,11 +55,11 @@ def main(args, config): paddle.distributed.init_parallel_env() nranks = paddle.distributed.get_world_size() local_rank = paddle.distributed.get_rank() - # set the random seed, it is a must for multiprocess training + # set the random seed, it is the necessary measures for multiprocess training seed_everything(config.seed) # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline - # note: some cmd must do in rank==0, so wo will refactor the data prepare code + # note: some operations must be done in rank==0 train_dataset = CSVDataset( csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"), label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt")) @@ -61,12 +67,14 @@ def main(args, config): csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"), label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt")) + # we will build the augment pipeline process list if config.augment: augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) else: augment_pipeline = [] # stage3: build the dnn backbone model network + # in speaker verification period, we use the backbone mode to extract the audio embedding ecapa_tdnn = EcapaTdnn(**config.model) # stage4: build the speaker verification train instance with backbone model @@ -77,13 +85,15 @@ def main(args, config): # 140000 is single gpu steps # so, in multi-gpu mode, wo reduce the step_size to 140000//nranks to enable CyclicLRScheduler lr_schedule = CyclicLRScheduler( - base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks) + base_lr=config.learning_rate, + max_lr=config.max_lr, + step_size=config.step_size // nranks) optimizer = paddle.optimizer.AdamW( learning_rate=lr_schedule, parameters=model.parameters()) # stage6: build the loss function, we now only support LogSoftmaxWrapper criterion = LogSoftmaxWrapper( - loss_fn=AdditiveAngularMargin(margin=0.2, scale=30)) + loss_fn=AdditiveAngularMargin(margin=config.margin, scale=config.scale)) # stage7: confirm training start epoch # if pre-trained model exists, start epoch confirmed by the pre-trained model @@ -225,7 +235,7 @@ def main(args, config): print_msg += ' avg_train_cost: {:.5f} sec,'.format( train_run_cost / config.log_interval) - print_msg += ' lr={:.4E} step/sec={:.2f} ips:{:.5f}| ETA {}'.format( + print_msg += ' lr={:.4E} step/sec={:.2f} ips={:.5f}| ETA {}'.format( lr, timer.timing, timer.ips, timer.eta) logger.info(print_msg) diff --git a/paddlespeech/vector/io/embedding_norm.py b/paddlespeech/vector/io/embedding_norm.py index 619f3710..429bb1af 100644 --- a/paddlespeech/vector/io/embedding_norm.py +++ b/paddlespeech/vector/io/embedding_norm.py @@ -57,14 +57,14 @@ class InputNormalization: lengths (paddle.Tensor): A batch of tensors containing the relative length of each sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid computing stats on zero-padded steps. - spk_ids (_type_, optional): tensor containing the ids of each speaker (e.g, [0 10 6]). + spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]). It is used to perform per-speaker normalization when norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32"). Returns: paddle.Tensor: The normalized feature or embedding """ N_batches = x.shape[0] - # print(f"x shape: {x.shape[1]}") + current_means = [] current_stds = [] @@ -75,6 +75,9 @@ class InputNormalization: actual_size = paddle.round(lengths[snt_id] * x.shape[1]).astype("int32") # computing actual time data statistics + # we extract the snt_id embedding from the x + # and the target paddle.Tensor will reduce an 0-axis + # so we need unsqueeze operation to recover the all axis current_mean, current_std = self._compute_current_stats( x[snt_id, 0:actual_size, ...].unsqueeze(0)) current_means.append(current_mean)