diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml index 333040548..720326f8d 100644 --- a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml @@ -1,3 +1,12 @@ +########################################### +# Data # +########################################### +batch_size: 32 +num_workers: 2 +num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 +shuffle: True +random_chunk: True + ########################################################### # FEATURE EXTRACTION SETTING # ########################################################### @@ -7,7 +16,6 @@ feature: window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 hop_length: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 - ########################################################### # MODEL SETTING # ########################################################### @@ -15,9 +23,8 @@ feature: # if we want use another model, please choose another configuration yaml file model: input_size: 80 - ##"channels": [1024, 1024, 1024, 1024, 3072], # "channels": [512, 512, 512, 512, 1536], - channels: [512, 512, 512, 512, 1536] + channels: [1024, 1024, 1024, 1024, 3072] kernel_sizes: [5, 3, 3, 3, 1] dilations: [1, 2, 3, 4, 1] attention_channels: 128 @@ -26,10 +33,9 @@ model: ########################################### # Training # ########################################### -seed: 0 +seed: 1986 # same seed as the SpeechBrain configuration epochs: 10 -batch_size: 32 -num_workers: 2 -save_freq: 10 -log_freq: 10 +save_interval: 10 # read by train.py as config.save_interval +log_interval: 10 learning_rate: 1e-8 + diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh index 769332eb7..c5dc3dd29 100755 --- a/examples/voxceleb/sv0/run.sh +++ b/examples/voxceleb/sv0/run.sh @@ -47,7 +47,8 @@ mkdir -p ${exp_dir} if [ $stage -le 0 ]; then # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav python3 local/data_prepare.py \ - --data-dir 
${dir} --augment --vox2-base-path ${vox2_base_path} + --data-dir ${dir} --augment --vox2-base-path ${vox2_base_path} \ + --config conf/ecapa_tdnn.yaml fi if [ $stage -le 1 ]; then diff --git a/paddleaudio/paddleaudio/metric/__init__.py b/paddleaudio/paddleaudio/metric/__init__.py index b435571d4..8e5ca9f75 100644 --- a/paddleaudio/paddleaudio/metric/__init__.py +++ b/paddleaudio/paddleaudio/metric/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from .dtw import dtw_distance -from .mcd import mcd_distance from .eer import compute_eer +from .eer import compute_minDCF +from .mcd import mcd_distance diff --git a/paddleaudio/paddleaudio/metric/eer.py b/paddleaudio/paddleaudio/metric/eer.py index 65dc7a3c4..7738987eb 100644 --- a/paddleaudio/paddleaudio/metric/eer.py +++ b/paddleaudio/paddleaudio/metric/eer.py @@ -14,6 +14,7 @@ from typing import List import numpy as np +import paddle from sklearn.metrics import roc_curve @@ -26,3 +27,68 @@ def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]: eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] return eer, eer_threshold + + +def compute_minDCF(positive_scores, + negative_scores, + c_miss=1.0, + c_fa=1.0, + p_target=0.01): + """ + This is modified from SpeechBrain + https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509 + Computes the minDCF metric normally used to evaluate speaker verification + systems. The min_DCF is the minimum of the following C_det function computed + over all candidate thresholds derived from the scores: + + C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target) + + where p_miss is the miss (false rejection) probability and p_fa is the probability of having + a false alarm. + + Args: + positive_scores (paddle.Tensor): The scores from entries of the same class. 
+ negative_scores (paddle.Tensor): The scores from entries of different classes. + c_miss (float, optional): Cost assigned to a missing error (default 1.0). + c_fa (float, optional): Cost assigned to a false alarm (default 1.0). + p_target (float, optional): Prior probability of having a target (default 0.01). + + Returns: + tuple: (min_dcf, threshold) as Python floats; threshold is the candidate at which C_det is minimal. + """ + # Squeeze multi-dimensional inputs, then compute the candidate thresholds (all unique scores) + if len(positive_scores.shape) > 1: + positive_scores = positive_scores.squeeze() + + if len(negative_scores.shape) > 1: + negative_scores = negative_scores.squeeze() + + thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores])) + thresholds = paddle.unique(thresholds) + + # Adding intermediate thresholds (midpoints between adjacent unique scores) + interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2 + thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds])) + + # Computing False Rejection Rate (miss detection): fraction of positive scores <= each threshold + positive_scores = paddle.concat( + len(thresholds) * [positive_scores.unsqueeze(0)]) + pos_scores_threshold = positive_scores.transpose(perm=[1, 0]) <= thresholds + p_miss = (pos_scores_threshold.sum(0) + ).astype("float32") / positive_scores.shape[1] + del positive_scores + del pos_scores_threshold + + # Computing False Acceptance Rate (false alarm): fraction of negative scores > each threshold + negative_scores = paddle.concat( + len(thresholds) * [negative_scores.unsqueeze(0)]) + neg_scores_threshold = negative_scores.transpose(perm=[1, 0]) > thresholds + p_fa = (neg_scores_threshold.sum(0) + ).astype("float32") / negative_scores.shape[1] + del negative_scores + del neg_scores_threshold + + c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target) + c_min = paddle.min(c_det, axis=0) + min_index = paddle.argmin(c_det, axis=0) + return float(c_min), float(thresholds[min_index]) diff --git a/paddlespeech/vector/exps/ecapa_tdnn/extract_speaker_embedding.py b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py similarity index 100% rename from paddlespeech/vector/exps/ecapa_tdnn/extract_speaker_embedding.py rename to 
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py diff --git a/paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py b/paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py index 01a3506a2..781bf2a5e 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py @@ -45,7 +45,7 @@ def main(args, config): # stage2: build the speaker verification eval instance with backbone model model = SpeakerIdetification( - backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers) + backbone=ecapa_tdnn, num_class=config.num_speakers) # stage3: load the pre-trained model args.load_checkpoint = os.path.abspath( @@ -93,6 +93,7 @@ def main(args, config): model.eval() # stage7: global embedding norm to imporve the performance + print("global embedding norm: {}".format(args.global_embedding_norm)) if args.global_embedding_norm: global_embedding_mean = None global_embedding_std = None @@ -118,6 +119,8 @@ def main(args, config): -1).numpy() # (N, emb_size, 1) -> (N, emb_size) # Global embedding normalization. 
+ # if we use the global embedding norm, + # the EER can be reduced by about 10% relative if args.global_embedding_norm: batch_count += 1 current_mean = embeddings.mean( @@ -150,8 +153,8 @@ def main(args, config): for line in f.readlines(): label, enrol_id, test_id = line.strip().split(' ') labels.append(int(label)) - enrol_ids.append(enrol_id.split('.')[0].replace('/', '-')) - test_ids.append(test_id.split('.')[0].replace('/', '-')) + enrol_ids.append(enrol_id.split('.')[0].replace('/', '--')) + test_ids.append(test_id.split('.')[0].replace('/', '--')) cos_sim_func = paddle.nn.CosineSimilarity(axis=1) enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor( @@ -185,11 +188,10 @@ if __name__ == "__main__": default='', help="Directory to load model checkpoint to contiune trainning.") parser.add_argument("--global-embedding-norm", - type=bool, - default=True, + default=False, + action="store_true", help="Apply global normalization on speaker embeddings.") parser.add_argument("--embedding-mean-norm", - type=bool, default=True, help="Apply mean normalization on speaker embeddings.") parser.add_argument("--embedding-std-norm", diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py index 6e6e5ab24..cb20ef167 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/train.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py @@ -178,9 +178,9 @@ def main(args, config): timer.count() # step plus one in timer # stage 9-10: print the log information only on 0-rank per log-freq batchs - if (batch_idx + 1) % config.log_freq == 0 and local_rank == 0: + if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0: lr = optimizer.get_lr() - avg_loss /= config.log_freq + avg_loss /= config.log_interval avg_acc = num_corrects / num_samples print_msg = 'Train Epoch={}/{}, Step={}/{}'.format( @@ -196,7 +196,7 @@ def main(args, config): num_samples = 0 # stage 9-11: save the model parameters only on 0-rank per save-freq batchs - if epoch 
% config.save_freq == 0 and batch_idx + 1 == steps_per_epoch: + if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch: if local_rank != 0: paddle.distributed.barrier( ) # Wait for valid step in main process diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py index 1b9d1fbd8..f40ce41b7 100644 --- a/paddlespeech/vector/io/augment.py +++ b/paddlespeech/vector/io/augment.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# this is modified from https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py +# this is modified from SpeechBrain +# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py import math import os from typing import List diff --git a/paddlespeech/vector/io/batch.py b/paddlespeech/vector/io/batch.py index 811775e20..85f2ab8b6 100644 --- a/paddlespeech/vector/io/batch.py +++ b/paddlespeech/vector/io/batch.py @@ -75,6 +75,9 @@ def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True): i]:].sum() == 0 # Padding valus should all be 0. # Converts into ratios. + # the longest utterance needs no padding; + # the remaining utterances are all padded to the max length, + # so we convert the original length of each utterance to its ratio of the max length lengths = (lengths / lengths.max()).astype(np.float32) return {'ids': ids, 'feats': feats, 'lengths': lengths} \ No newline at end of file