From 567286add3186030027a9934daa463fdf4537446 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Sun, 10 Apr 2022 13:52:05 +0800
Subject: [PATCH] wrap the embedding mean and std norm, test=doc

---
 paddlespeech/vector/exps/ecapa_tdnn/test.py  | 206 ++++++++++++------
 paddlespeech/vector/exps/ecapa_tdnn/train.py |   8 +-
 paddlespeech/vector/io/dataset.py            |  15 ++
 paddlespeech/vector/io/embedding_norm.py     | 214 +++++++++++++++++++
 paddlespeech/vector/utils/time.py            |   6 +
 5 files changed, 379 insertions(+), 70 deletions(-)
 create mode 100644 paddlespeech/vector/io/embedding_norm.py

diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py
index 4d78cfd3..70b1521e 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/test.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py
@@ -25,6 +25,7 @@ from paddleaudio.metric import compute_eer
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.batch import batch_feature_normalize
 from paddlespeech.vector.io.dataset import CSVDataset
+from paddlespeech.vector.io.embedding_norm import InputNormalization
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
 from paddlespeech.vector.training.seeding import seed_everything
@@ -32,6 +33,91 @@ from paddlespeech.vector.training.seeding import seed_everything
 logger = Log(__name__).getlog()
 
 
+def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
+                              id2embedding):
+    """Compute the embedding of every utterance in a dataset and store it in id2embedding.
+
+    Args:
+        data_loader (paddle.io.DataLoader): data loader of the enroll, test or cohort dataset
+        model (SpeakerIdetification): the trained speaker verification model
+        mean_var_norm_emb (InputNormalization): global embedding mean/std normalizer, or None to skip it
+        config (CfgNode): yaml config
+    """
+    logger.info(
+        f'Computing embeddings on {data_loader.dataset.csv_path} dataset')
+    with paddle.no_grad():
+        for batch_idx, batch in enumerate(tqdm(data_loader)):
+
+            # stage 8-1: extract the audio embedding
+            ids, feats, lengths = batch['ids'], batch['feats'], batch['lengths']
+            embeddings = model.backbone(feats, lengths).squeeze(
+                -1)  # (N, emb_size, 1) -> (N, emb_size)
+
+            # Global embedding normalization.
+            # Using the global embedding norm reduces
+            # the EER by roughly 10% relative.
+            if config.global_embedding_norm and mean_var_norm_emb:
+                lengths = paddle.ones([embeddings.shape[0]])
+                embeddings = mean_var_norm_emb(embeddings, lengths)
+
+            # Update embedding dict.
+ id2embedding.update(dict(zip(ids, embeddings))) + + +def compute_verification_scores(id2embedding, train_cohort, config): + labels = [] + enroll_ids = [] + test_ids = [] + logger.info(f"read the trial from {config.verification_file}") + cos_sim_func = paddle.nn.CosineSimilarity(axis=-1) + scores = [] + with open(config.verification_file, 'r') as f: + for line in f.readlines(): + label, enroll_id, test_id = line.strip().split(' ') + enroll_id = enroll_id.split('.')[0].replace('/', '-') + test_id = test_id.split('.')[0].replace('/', '-') + labels.append(int(label)) + + enroll_emb = id2embedding[enroll_id] + test_emb = id2embedding[test_id] + score = cos_sim_func(enroll_emb, test_emb).item() + + if "score_norm" in config: + # Getting norm stats for enroll impostors + enroll_rep = paddle.tile( + enroll_emb, repeat_times=[train_cohort.shape[0], 1]) + score_e_c = cos_sim_func(enroll_rep, train_cohort) + if "cohort_size" in config: + score_e_c, _ = paddle.topk( + score_e_c, k=config.cohort_size, axis=0) + mean_e_c = paddle.mean(score_e_c, axis=0) + std_e_c = paddle.std(score_e_c, axis=0) + + # Getting norm stats for test impostors + test_rep = paddle.tile( + test_emb, repeat_times=[train_cohort.shape[0], 1]) + score_t_c = cos_sim_func(test_rep, train_cohort) + if "cohort_size" in config: + score_t_c, _ = paddle.topk( + score_t_c, k=config.cohort_size, axis=0) + mean_t_c = paddle.mean(score_t_c, axis=0) + std_t_c = paddle.std(score_t_c, axis=0) + + if config.score_norm == "s-norm": + score_e = (score - mean_e_c) / std_e_c + score_t = (score - mean_t_c) / std_t_c + + score = 0.5 * (score_e + score_t) + elif config.score_norm == "z-norm": + score = (score - mean_e_c) / std_e_c + elif config.score_norm == "t-norm": + score = (score - mean_t_c) / std_t_c + + scores.append(score) + + return scores, labels + + def main(args, config): # stage0: set the training device, cpu or gpu paddle.set_device(args.device) @@ -67,7 +153,7 @@ def main(args, config): hop_length=config.hop_size) enroll_sampler = BatchSampler( enroll_dataset, batch_size=config.batch_size, - shuffle=True) # Shuffle to make embedding normalization more robust. + shuffle=False) # Shuffle to make embedding normalization more robust. enroll_loader = DataLoader(enroll_dataset, batch_sampler=enroll_sampler, collate_fn=lambda x: batch_feature_normalize( @@ -83,7 +169,7 @@ def main(args, config): hop_length=config.hop_size) test_sampler = BatchSampler( - test_dataset, batch_size=config.batch_size, shuffle=True) + test_dataset, batch_size=config.batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_sampler=test_sampler, collate_fn=lambda x: batch_feature_normalize( @@ -95,75 +181,65 @@ def main(args, config): # stage6: global embedding norm to imporve the performance logger.info(f"global embedding norm: {config.global_embedding_norm}") - if config.global_embedding_norm: - global_embedding_mean = None - global_embedding_std = None - mean_norm_flag = config.embedding_mean_norm - std_norm_flag = config.embedding_std_norm - batch_count = 0 # stage7: Compute embeddings of audios in enrol and test dataset from model. 
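The InputNormalization helper used below replaces the inline running mean/std bookkeeping that this patch removes from main(). A minimal standalone sketch of this kind of running mean/std update, with toy shapes (illustrative only, not the class's exact implementation):

import paddle

# Each batch's statistics are folded into the global estimate with
# weight 1 / batch_count, i.e. a running average over batches.
glob_mean, glob_std, count = None, None, 0
for _ in range(3):
    batch_emb = paddle.randn([8, 192])  # toy batch of embeddings (N, emb_size)
    cur_mean = paddle.mean(batch_emb, axis=0)
    cur_std = paddle.std(batch_emb, axis=0)
    if count == 0:
        glob_mean, glob_std = cur_mean, cur_std
    else:
        weight = 1.0 / (count + 1)
        glob_mean = (1 - weight) * glob_mean + weight * cur_mean
        glob_std = (1 - weight) * glob_std + weight * cur_std
    count += 1

# Apply the accumulated global statistics, as the normalizer does.
normalized = (batch_emb - glob_mean) / glob_std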
+ + if config.global_embedding_norm: + mean_var_norm_emb = InputNormalization( + norm_type="global", + mean_norm=config.embedding_mean_norm, + std_norm=config.embedding_std_norm) + + if "score_norm" in config: + logger.info(f"we will do score norm: {config.score_norm}") + train_dataset = CSVDataset( + os.path.join(args.data_dir, "vox/csv/train.csv"), + feat_type='melspectrogram', + n_train_snts=config.n_train_snts, + random_chunk=False, + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + train_sampler = BatchSampler( + train_dataset, batch_size=config.batch_size, shuffle=False) + train_loader = DataLoader(train_dataset, + batch_sampler=train_sampler, + collate_fn=lambda x: batch_feature_normalize( + x, mean_norm=True, std_norm=False), + num_workers=config.num_workers, + return_list=True,) + id2embedding = {} # Run multi times to make embedding normalization more stable. - for i in range(2): - for dl in [enroll_loader, test_loader]: - logger.info( - f'Loop {[i+1]}: Computing embeddings on {dl.dataset.csv_path} dataset' - ) - with paddle.no_grad(): - for batch_idx, batch in enumerate(tqdm(dl)): - - # stage 8-1: extrac the audio embedding - ids, feats, lengths = batch['ids'], batch['feats'], batch[ - 'lengths'] - embeddings = model.backbone(feats, lengths).squeeze( - -1).numpy() # (N, emb_size, 1) -> (N, emb_size) - - # Global embedding normalization. - # if we use the global embedding norm - # eer can reduece about relative 10% - if config.global_embedding_norm: - batch_count += 1 - current_mean = embeddings.mean( - axis=0) if mean_norm_flag else 0 - current_std = embeddings.std( - axis=0) if std_norm_flag else 1 - # Update global mean and std. - if global_embedding_mean is None and global_embedding_std is None: - global_embedding_mean, global_embedding_std = current_mean, current_std - else: - weight = 1 / batch_count # Weight decay by batches. - global_embedding_mean = ( - 1 - weight - ) * global_embedding_mean + weight * current_mean - global_embedding_std = ( - 1 - weight - ) * global_embedding_std + weight * current_std - # Apply global embedding normalization. - embeddings = (embeddings - global_embedding_mean - ) / global_embedding_std - - # Update embedding dict. - id2embedding.update(dict(zip(ids, embeddings))) + logger.info("First loop for enroll and test dataset") + compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config, + id2embedding) + compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config, + id2embedding) + + logger.info("Second loop for enroll and test dataset") + compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config, + id2embedding) + compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config, + id2embedding) + mean_var_norm_emb.save( + os.path.join(args.load_checkpoint, "mean_var_norm_emb")) # stage 8: Compute cosine scores. 
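When score_norm is configured, compute_verification_scores above normalizes each raw cosine score against a cohort of training embeddings. A small self-contained sketch of the z-norm / t-norm / s-norm arithmetic on toy tensors (shapes, seed and cohort size are illustrative only):

import paddle

paddle.seed(0)
cos_sim = paddle.nn.CosineSimilarity(axis=-1)

emb_size, cohort_size = 192, 5
enroll_emb = paddle.randn([emb_size])
test_emb = paddle.randn([emb_size])
train_cohort = paddle.randn([cohort_size, emb_size])  # stacked cohort embeddings

raw_score = cos_sim(enroll_emb, test_emb).item()

# Impostor statistics of the enroll and test embeddings against the cohort.
score_e_c = cos_sim(paddle.tile(enroll_emb, [cohort_size, 1]), train_cohort)
score_t_c = cos_sim(paddle.tile(test_emb, [cohort_size, 1]), train_cohort)
mean_e_c, std_e_c = paddle.mean(score_e_c), paddle.std(score_e_c)
mean_t_c, std_t_c = paddle.mean(score_t_c), paddle.std(score_t_c)

# z-norm uses the enroll-side statistics, t-norm the test-side statistics,
# and s-norm averages the two normalized scores.
z_norm_score = (raw_score - mean_e_c) / std_e_c
t_norm_score = (raw_score - mean_t_c) / std_t_c
s_norm_score = 0.5 * (z_norm_score + t_norm_score)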
-    labels = []
-    enroll_ids = []
-    test_ids = []
-    logger.info(f"read the trial from {config.verification_file}")
-    with open(config.verification_file, 'r') as f:
-        for line in f.readlines():
-            label, enroll_id, test_id = line.strip().split(' ')
-            labels.append(int(label))
-            enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
-            test_ids.append(test_id.split('.')[0].replace('/', '-'))
-
-    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
-    enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
-        np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')),
-                                            [enroll_ids, test_ids
-                                             ])  # (N, emb_size)
-    scores = cos_sim_func(enrol_embeddings, test_embeddings)
+    train_cohort = None
+    if "score_norm" in config:
+        train_embeddings = {}
+        # The cohort embeddings are not mean/std normalized.
+        compute_dataset_embedding(train_loader, model, None, config,
+                                  train_embeddings)
+        train_cohort = paddle.stack(list(train_embeddings.values()))
+
+    # compute the scores
+    scores, labels = compute_verification_scores(id2embedding, train_cohort,
+                                                 config)
+
+    # compute the EER and threshold
+    scores = paddle.to_tensor(scores)
     EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
     logger.info(
         f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
index c1590c8f..adbb3285 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -197,17 +197,15 @@ def main(args, config):
                               paddle.optimizer.lr.LRScheduler):
             optimizer._learning_rate.step()
         optimizer.clear_grad()
-        train_run_cost += time.time() - train_start
 
         # stage 9-8: Calculate average loss per batch
-        train_misce_start = time.time()
         avg_loss = loss.item()
 
         # stage 9-9: Calculate metrics, which is one-best accuracy
         preds = paddle.argmax(logits, axis=1)
         num_corrects += (preds == labels).numpy().sum()
         num_samples += feats.shape[0]
-
+        train_run_cost += time.time() - train_start
         timer.count()  # step plus one in timer
 
         # stage 9-10: print the log information only on 0-rank per log-freq batchs
@@ -227,8 +225,8 @@
                 print_msg += ' avg_train_cost: {:.5f} sec,'.format(
                     train_run_cost / config.log_interval)
 
-                print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
-                    lr, timer.timing, timer.eta)
+                print_msg += ' lr={:.4E} step/sec={:.2f} ips={:.2f} | ETA {}'.format(
+                    lr, timer.timing, timer.ips, timer.eta)
                 logger.info(print_msg)
 
                 avg_loss = 0
diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py
index e7a8445b..316c8ac3 100644
--- a/paddlespeech/vector/io/dataset.py
+++ b/paddlespeech/vector/io/dataset.py
@@ -65,6 +65,7 @@ class CSVDataset(Dataset):
                  config=None,
                  random_chunk=True,
                  feat_type: str="raw",
+                 n_train_snts: int=-1,
                  **kwargs):
         """Implement the CSV Dataset
 
@@ -73,6 +74,9 @@
         Args:
             csv_path (str): csv dataset path
             label2id_path (str): the utterance label to integer id map file path
             config (CfgNode): yaml config
             feat_type (str): dataset feature type. if it is raw, it return pcm data.
+            n_train_snts (int): the number of samples to select from the dataset.
+                                If n_train_snts is -1, all samples are loaded.
+                                Defaults to -1.
kwargs : feature type args """ super().__init__() @@ -81,6 +85,7 @@ class CSVDataset(Dataset): self.config = config self.random_chunk = random_chunk self.feat_type = feat_type + self.n_train_snts = n_train_snts self.feat_config = kwargs self.id2label = {} self.label2id = {} @@ -93,6 +98,9 @@ class CSVDataset(Dataset): that is audio_id or utt_id, audio duration, segment start point, segment stop point and utterance label. Note in training period, the utterance label must has a map to integer id in label2id_path + + Returns: + list: the csv data with meta_info type """ data = [] @@ -104,6 +112,10 @@ class CSVDataset(Dataset): meta_info(audio_id, float(duration), wav, int(start), int(stop), spk_id)) + if self.n_train_snts > 0: + sample_num = min(self.n_train_snts, len(data)) + data = data[0:sample_num] + return data def load_speaker_to_label(self): @@ -173,5 +185,8 @@ class CSVDataset(Dataset): def __len__(self): """Return the dataset length + + Returns: + int: the length num of the dataset """ return len(self.data) diff --git a/paddlespeech/vector/io/embedding_norm.py b/paddlespeech/vector/io/embedding_norm.py new file mode 100644 index 00000000..619f3710 --- /dev/null +++ b/paddlespeech/vector/io/embedding_norm.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +import paddle + + +class InputNormalization: + spk_dict_mean: Dict[int, paddle.Tensor] + spk_dict_std: Dict[int, paddle.Tensor] + spk_dict_count: Dict[int, int] + + def __init__( + self, + mean_norm=True, + std_norm=True, + norm_type="global", ): + """Do feature or embedding mean and std norm + + Args: + mean_norm (bool, optional): mean norm flag. Defaults to True. + std_norm (bool, optional): std norm flag. Defaults to True. + norm_type (str, optional): norm type. Defaults to "global". + """ + super().__init__() + self.training = True + self.mean_norm = mean_norm + self.std_norm = std_norm + self.norm_type = norm_type + self.glob_mean = paddle.to_tensor([0], dtype="float32") + self.glob_std = paddle.to_tensor([0], dtype="float32") + self.spk_dict_mean = {} + self.spk_dict_std = {} + self.spk_dict_count = {} + self.weight = 1.0 + self.count = 0 + self.eps = 1e-10 + + def __call__(self, + x, + lengths, + spk_ids=paddle.to_tensor([], dtype="float32")): + """Returns the tensor with the surrounding context. + Args: + x (paddle.Tensor): A batch of tensors. + lengths (paddle.Tensor): A batch of tensors containing the relative length of each + sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid + computing stats on zero-padded steps. + spk_ids (_type_, optional): tensor containing the ids of each speaker (e.g, [0 10 6]). + It is used to perform per-speaker normalization when + norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32"). 
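+
+        Example (shapes are illustrative):
+            >>> norm = InputNormalization(mean_norm=True, std_norm=False)
+            >>> embeddings = paddle.randn([4, 192])   # (batch, emb_size)
+            >>> lengths = paddle.ones([4])            # full length, no padding
+            >>> normalized = norm(embeddings, lengths)
+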
+ Returns: + paddle.Tensor: The normalized feature or embedding + """ + N_batches = x.shape[0] + # print(f"x shape: {x.shape[1]}") + current_means = [] + current_stds = [] + + for snt_id in range(N_batches): + + # Avoiding padded time steps + # actual size is the actual time data length + actual_size = paddle.round(lengths[snt_id] * + x.shape[1]).astype("int32") + # computing actual time data statistics + current_mean, current_std = self._compute_current_stats( + x[snt_id, 0:actual_size, ...].unsqueeze(0)) + current_means.append(current_mean) + current_stds.append(current_std) + + if self.norm_type == "global": + current_mean = paddle.mean(paddle.stack(current_means), axis=0) + current_std = paddle.mean(paddle.stack(current_stds), axis=0) + + if self.norm_type == "global": + + if self.training: + if self.count == 0: + self.glob_mean = current_mean + self.glob_std = current_std + + else: + self.weight = 1 / (self.count + 1) + + self.glob_mean = ( + 1 - self.weight + ) * self.glob_mean + self.weight * current_mean + + self.glob_std = ( + 1 - self.weight + ) * self.glob_std + self.weight * current_std + + self.glob_mean.detach() + self.glob_std.detach() + + self.count = self.count + 1 + x = (x - self.glob_mean) / (self.glob_std) + return x + + def _compute_current_stats(self, x): + """Returns the tensor with the surrounding context. + + Args: + x (paddle.Tensor): A batch of tensors. + + Returns: + the statistics of the data + """ + # Compute current mean + if self.mean_norm: + current_mean = paddle.mean(x, axis=0).detach() + else: + current_mean = paddle.to_tensor([0.0], dtype="float32") + + # Compute current std + if self.std_norm: + current_std = paddle.std(x, axis=0).detach() + else: + current_std = paddle.to_tensor([1.0], dtype="float32") + + # Improving numerical stability of std + current_std = paddle.maximum(current_std, + self.eps * paddle.ones_like(current_std)) + + return current_mean, current_std + + def _statistics_dict(self): + """Fills the dictionary containing the normalization statistics. + """ + state = {} + state["count"] = self.count + state["glob_mean"] = self.glob_mean + state["glob_std"] = self.glob_std + state["spk_dict_mean"] = self.spk_dict_mean + state["spk_dict_std"] = self.spk_dict_std + state["spk_dict_count"] = self.spk_dict_count + + return state + + def _load_statistics_dict(self, state): + """Loads the dictionary containing the statistics. + + Arguments + --------- + state : dict + A dictionary containing the normalization statistics. + """ + self.count = state["count"] + if isinstance(state["glob_mean"], int): + self.glob_mean = state["glob_mean"] + self.glob_std = state["glob_std"] + else: + self.glob_mean = state["glob_mean"] # .to(self.device_inp) + self.glob_std = state["glob_std"] # .to(self.device_inp) + + # Loading the spk_dict_mean in the right device + self.spk_dict_mean = {} + for spk in state["spk_dict_mean"]: + self.spk_dict_mean[spk] = state["spk_dict_mean"][spk] + + # Loading the spk_dict_std in the right device + self.spk_dict_std = {} + for spk in state["spk_dict_std"]: + self.spk_dict_std[spk] = state["spk_dict_std"][spk] + + self.spk_dict_count = state["spk_dict_count"] + + return state + + def to(self, device): + """Puts the needed tensors in the right device. 
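+
+        Args:
+            device (str): the device to move the statistics to, e.g. "cpu" or "gpu:0".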
+        """
+        # This class is not a paddle.nn.Layer, so just move the stored statistics.
+        self.glob_mean = self.glob_mean.to(device)
+        self.glob_std = self.glob_std.to(device)
+        for spk in self.spk_dict_mean:
+            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
+            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
+        return self
+
+    def save(self, path):
+        """Save the statistics dictionary.
+
+        Args:
+            path (str): A path where to save the dictionary.
+        """
+        stats = self._statistics_dict()
+        paddle.save(stats, path)
+
+    def _load(self, path, end_of_epoch=False, device=None):
+        """Load the statistics dictionary.
+
+        Arguments
+        ---------
+        path : str
+            The path of the statistics dictionary
+        device : str, None
+            Unused here; kept for interface compatibility.
+        """
+        del end_of_epoch  # Unused here.
+        stats = paddle.load(path)  # paddle.load does not take a map_location argument
+        self._load_statistics_dict(stats)
diff --git a/paddlespeech/vector/utils/time.py b/paddlespeech/vector/utils/time.py
index 8e85b0e1..f91b5156 100644
--- a/paddlespeech/vector/utils/time.py
+++ b/paddlespeech/vector/utils/time.py
@@ -23,6 +23,7 @@ class Timer(object):
         self.last_start_step = 0
         self.current_step = 0
         self._is_running = True
+        self._ips = 0
 
     def start(self):
         self.last_time = time.time()
@@ -43,12 +44,17 @@
         self.last_start_step = self.current_step
         time_used = time.time() - self.last_time
         self.last_time = time.time()
+        self._ips = run_steps / time_used
         return time_used / run_steps
 
     @property
     def is_running(self) -> bool:
         return self._is_running
 
+    @property
+    def ips(self) -> float:
+        return self._ips
+
     @property
     def eta(self) -> str:
         if not self.is_running:
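For reference, a minimal usage sketch of the new InputNormalization class, mirroring how test.py uses it (standalone and illustrative; the shapes and the save path are arbitrary):

import paddle

from paddlespeech.vector.io.embedding_norm import InputNormalization

# Global embedding mean norm; std norm is usually disabled for embeddings
# (embedding_mean_norm / embedding_std_norm in the yaml config).
mean_var_norm_emb = InputNormalization(
    norm_type="global", mean_norm=True, std_norm=False)

embeddings = paddle.randn([4, 192])  # (batch, emb_size), toy values
lengths = paddle.ones([4])           # relative lengths, i.e. no padding
normalized = mean_var_norm_emb(embeddings, lengths)

# Persist the accumulated statistics (test.py saves them next to the
# checkpoint) and reload them later for scoring.
mean_var_norm_emb.save("mean_var_norm_emb")
restored = InputNormalization(norm_type="global", mean_norm=True, std_norm=False)
restored._load("mean_var_norm_emb")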