You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
303 lines
12 KiB
303 lines
12 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
import functools
import os

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode

from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.io.embedding_norm import InputNormalization
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
|
|
|
|
# Module-level logger used by the evaluation functions in this script.
logger = Log(__name__).getlog()
|
|
|
|
|
|
def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
                              id2embedding):
    """Compute the embeddings of every utterance in a dataset.

    The embeddings are written into ``id2embedding`` in place, keyed by the
    utterance ids found in ``batch['ids']``; an existing entry with the same
    id is overwritten.

    Args:
        data_loader (paddle.io.DataLoader): loader over the dataset to embed;
            each batch is a dict with 'ids', 'feats' and 'lengths' entries
        model (paddle.nn.Layer): the speaker verification model; only its
            ``backbone`` is called to extract the embeddings
        mean_var_norm_emb (InputNormalization or None): the global embedding
            mean/std normalizer; pass None to skip the normalization
        config (yacs.config.CfgNode): the yaml config; the normalization is
            only applied when ``config.global_embedding_norm`` is truthy
        id2embedding (dict): output mapping from utterance id to embedding
    """
    logger.info(
        f'Computing embeddings on {data_loader.dataset.csv_path} dataset')
    with paddle.no_grad():
        for batch in tqdm(data_loader):
            # stage 8-1: extract the audio embedding
            ids, feats, lengths = batch['ids'], batch['feats'], batch['lengths']
            embeddings = model.backbone(feats, lengths).squeeze(
                -1)  # (N, emb_size, 1) -> (N, emb_size)

            # Global embedding normalization. Per the authors' measurements,
            # using the global embedding norm reduces the EER by about 10%
            # relative.
            if config.global_embedding_norm and mean_var_norm_emb is not None:
                # One length of 1 per embedding row; presumably these are
                # relative lengths, i.e. every embedding is used in full.
                emb_lengths = paddle.ones([embeddings.shape[0]])
                embeddings = mean_var_norm_emb(embeddings, emb_lengths)

            # Update embedding dict.
            id2embedding.update(zip(ids, embeddings))
|
|
|
|
|
|
def _cohort_norm_stats(emb, train_cohort, cos_sim_func, cohort_size=None):
    """Return the (mean, std) of the cosine scores of ``emb`` vs. the cohort.

    Args:
        emb (paddle.Tensor): a single utterance embedding
        train_cohort (paddle.Tensor): the cohort embeddings, one per row
        cos_sim_func (paddle.nn.CosineSimilarity): the similarity function
        cohort_size (int, optional): when given, only the ``cohort_size``
            highest cohort scores contribute to the statistics

    Returns:
        tuple: (mean, std) tensors of the selected cohort scores
    """
    num_cohort = train_cohort.shape[0]
    emb_rep = paddle.tile(emb, repeat_times=[num_cohort, 1])
    cohort_scores = cos_sim_func(emb_rep, train_cohort)
    if cohort_size is not None:
        # paddle.topk fails when k exceeds the number of cohort entries,
        # so clamp k to the cohort size.
        cohort_scores, _ = paddle.topk(
            cohort_scores, k=min(cohort_size, num_cohort), axis=0)
    mean = paddle.mean(cohort_scores, axis=0)
    std = paddle.std(cohort_scores, axis=0)
    return mean, std


def compute_verification_scores(id2embedding, train_cohort, config):
    """Compute the verification trial scores

    Each non-blank line of ``config.verification_file`` holds
    ``<label> <enroll_path> <test_path>`` separated by whitespace. A path is
    mapped to an embedding id by dropping everything from its first '.' and
    replacing '/' with '-'.

    When ``config`` has a ``score_norm`` entry, the cosine score is
    normalized against the cohort: "z-norm" uses the enroll-side statistics,
    "t-norm" the test-side ones and "s-norm" the average of both; any other
    value leaves the score unnormalized. When ``cohort_size`` is set, only
    the top ``cohort_size`` cohort scores (at most the cohort size) are used.

    Args:
        id2embedding (dict): the utterance embedding, keyed by utterance id
        train_cohort (paddle.Tensor): the cohort dataset embedding, one per
            row; only used when score norm is enabled
        config (yacs.config.CfgNode): the yaml config

    Returns:
        tuple: (scores, labels) -- a list of float scores and a list of int
        trial labels, where 1 refers to target and 0 to nontarget trials
    """
    labels = []
    scores = []
    logger.info(f"read the trial from {config.verification_file}")
    cos_sim_func = paddle.nn.CosineSimilarity(axis=-1)

    norm_type = config.score_norm if "score_norm" in config else None
    cohort_size = config.cohort_size if "cohort_size" in config else None

    # The cohort statistics depend only on the utterance, and an utterance
    # usually appears in many trials, so compute them once per id.
    stats_cache = {}

    def cohort_stats(utt_id):
        if utt_id not in stats_cache:
            stats_cache[utt_id] = _cohort_norm_stats(
                id2embedding[utt_id], train_cohort, cos_sim_func, cohort_size)
        return stats_cache[utt_id]

    with open(config.verification_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # tolerate blank lines, e.g. trailing empty lines
                continue
            label, enroll_id, test_id = fields
            enroll_id = enroll_id.split('.')[0].replace('/', '-')
            test_id = test_id.split('.')[0].replace('/', '-')
            labels.append(int(label))

            enroll_emb = id2embedding[enroll_id]
            test_emb = id2embedding[test_id]
            score = cos_sim_func(enroll_emb, test_emb).item()

            if norm_type == "s-norm":
                mean_e_c, std_e_c = cohort_stats(enroll_id)
                mean_t_c, std_t_c = cohort_stats(test_id)
                score_e = (score - mean_e_c) / std_e_c
                score_t = (score - mean_t_c) / std_t_c
                score = 0.5 * (score_e + score_t)
            elif norm_type == "z-norm":
                mean_e_c, std_e_c = cohort_stats(enroll_id)
                score = (score - mean_e_c) / std_e_c
            elif norm_type == "t-norm":
                mean_t_c, std_t_c = cohort_stats(test_id)
                score = (score - mean_t_c) / std_t_c

            # Score norm turns the score into a single-element tensor; store
            # plain floats so every returned score has the same type.
            scores.append(float(score))

    return scores, labels
|
|
|
|
|
|
def _build_eval_loader(csv_path, config, **dataset_kwargs):
    """Build a non-shuffled mel-spectrogram DataLoader for evaluation.

    Args:
        csv_path (str): path of the dataset csv file
        config (yacs.config.CfgNode): the yaml config, providing the feature
            settings (n_mels, window_size, hop_size) and the loader settings
            (batch_size, num_workers)
        **dataset_kwargs: extra keyword arguments forwarded to CSVDataset,
            e.g. ``n_train_snts``

    Returns:
        paddle.io.DataLoader: the loader over the csv dataset
    """
    dataset = CSVDataset(
        csv_path,
        feat_type='melspectrogram',
        random_chunk=False,
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_size,
        **dataset_kwargs)
    sampler = BatchSampler(
        dataset, batch_size=config.batch_size, shuffle=False)
    # A functools.partial, unlike a lambda, is picklable, so the collate
    # function still works if worker processes are started with "spawn".
    collate_fn = functools.partial(
        batch_feature_normalize, mean_norm=True, std_norm=False)
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
        return_list=True)


def main(args, config):
    """The main process for testing the speaker verification model

    Loads ``{args.load_checkpoint}/model.pdparams`` and embeds the enroll and
    test sets from ``{args.data_dir}/vox/csv/``. It then scores the trials in
    ``config.verification_file``, optionally with score norm against a cohort
    drawn from the train set, and logs the EER and score threshold. When the
    global embedding norm is enabled, the normalizer is saved with its
    ``save`` method to ``{args.load_checkpoint}/mean_var_norm_emb``.

    Args:
        args (argparse.Namespace): the command line args namespace;
            ``args.load_checkpoint`` is rewritten to an absolute path
        config (yacs.config.CfgNode): the yaml config
    """
    # stage0: set the device, cpu or gpu
    # if set the gpu, paddlespeech will select a gpu according the env CUDA_VISIBLE_DEVICES
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    # we will extract the audio embedding from the backbone model
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    # because the checkpoint dict names carry the SpeakerIdetification prefix,
    # we need to create the SpeakerIdetification instance,
    # but we actually only use its backbone to extract the audio embedding
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=config.num_speakers)

    # stage3: load the pre-trained model
    # generally, we get the last model from the epoch
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load model checkpoint to sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage4: construct the enroll and test dataloader
    # the enroll dataset is expected at {args.data_dir}/vox/csv/enroll.csv
    # and the test dataset at {args.data_dir}/vox/csv/test.csv
    enroll_loader = _build_eval_loader(
        os.path.join(args.data_dir, "vox/csv/enroll.csv"), config)
    test_loader = _build_eval_loader(
        os.path.join(args.data_dir, "vox/csv/test.csv"), config)

    # stage5: we must set the model to eval mode
    model.eval()

    # stage6: global embedding norm to improve the performance
    # the InputNormalization instance does the embedding mean and std norm;
    # it stays None when the global embedding norm is disabled, so the
    # embedding computation below never sees an unbound name
    logger.info(f"global embedding norm: {config.global_embedding_norm}")
    mean_var_norm_emb = None
    if config.global_embedding_norm:
        mean_var_norm_emb = InputNormalization(
            norm_type="global",
            mean_norm=config.embedding_mean_norm,
            std_norm=config.embedding_std_norm)

    # stage 7: score norm needs an imposters dataset
    # we use the train dataset as the imposters dataset, limited to
    # config.n_train_snts utterances
    train_loader = None
    if "score_norm" in config:
        logger.info(f"we will do score norm: {config.score_norm}")
        train_loader = _build_eval_loader(
            os.path.join(args.data_dir, "vox/csv/train.csv"),
            config,
            n_train_snts=config.n_train_snts)

    # stage 8: Compute embeddings of audios in enrol and test dataset from model.
    id2embedding = {}
    logger.info("First loop for enroll and test dataset")
    compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
                              id2embedding)
    compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
                              id2embedding)

    if mean_var_norm_emb is not None:
        # Run a second pass to make the embedding normalization more stable.
        # Without the norm, a second pass would presumably just recompute the
        # same embeddings, since the model is in eval mode.
        logger.info("Second loop for enroll and test dataset")
        compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb,
                                  config, id2embedding)
        compute_dataset_embedding(test_loader, model, mean_var_norm_emb,
                                  config, id2embedding)
        mean_var_norm_emb.save(
            os.path.join(args.load_checkpoint, "mean_var_norm_emb"))

    # stage 9: compute the cohort embeddings used by score norm
    train_cohort = None
    if train_loader is not None:
        train_embeddings = {}
        # cohort embeddings do not get the mean and std norm
        compute_dataset_embedding(train_loader, model, None, config,
                                  train_embeddings)
        train_cohort = paddle.stack(list(train_embeddings.values()))

    # stage 10: compute the scores
    scores, labels = compute_verification_scores(id2embedding, train_cohort,
                                                 config)

    # stage 11: compute the EER and threshold
    # scores may be python floats or single-element tensors (after score
    # norm); convert them all into one float32 array
    scores = np.asarray([float(score) for score in scores], dtype=np.float32)
    EER, threshold = compute_eer(np.asarray(labels), scores)
    logger.info(
        f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # yapf: disable
    # The first positional argument of ArgumentParser is `prog`, so the module
    # docstring must be passed explicitly as the description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to evaluate the model on, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory of the model checkpoint to evaluate.")
    args = parser.parse_args()
    # yapf: enable
    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()
    print(config)
    main(args, config)
|