You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/paddlespeech/vector/exps/ecapa_tdnn/test.py

207 lines
8.5 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode
from paddleaudio.datasets import VoxCeleb
from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
# Module-level logger, created through the project's Log wrapper and named
# after this module; used by main() for all progress and result messages.
logger = Log(__name__).getlog()
def _build_eval_loader(subset, args, config):
    """Build a shuffled mel-spectrogram DataLoader for one VoxCeleb subset.

    Args:
        subset: Name of the VoxCeleb subset to load ('enroll' or 'test').
        args: Parsed command-line namespace; ``data_dir`` is read.
        config: Configuration node; ``n_mels``, ``window_size``,
            ``hop_length``, ``batch_size`` and ``num_workers`` are read.

    Returns:
        A ``DataLoader`` whose batches are collated with
        ``batch_feature_normalize`` (mean normalization only).
    """
    dataset = VoxCeleb(
        subset=subset,
        target_dir=args.data_dir,
        feat_type='melspectrogram',
        random_chunk=False,
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_length)
    # Shuffle to make the running embedding normalization more robust.
    sampler = BatchSampler(
        dataset, batch_size=config.batch_size, shuffle=True)
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        collate_fn=lambda x: batch_feature_normalize(
            x, mean_norm=True, std_norm=False),
        num_workers=config.num_workers,
        return_list=True)


def _read_trials(trial_file):
    """Parse a verification trial list into labels and utterance ids.

    Every non-blank line must hold three single-space separated fields:
    ``<label> <enroll_path> <test_path>``. A path is turned into an utterance
    id by keeping the text before its first '.' and replacing '/' with '-'.

    Args:
        trial_file: Path of the trial list to read.

    Returns:
        A tuple ``(labels, enroll_ids, test_ids)`` of equally long lists;
        ``labels`` holds ints, the other two hold strings.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields
            or its label is not an integer.
    """
    labels = []
    enroll_ids = []
    test_ids = []
    with open(trial_file, 'r') as f:
        # Iterate the file lazily instead of materializing every line.
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            label, enroll_id, test_id = line.split(' ')
            labels.append(int(label))
            enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
            test_ids.append(test_id.split('.')[0].replace('/', '-'))
    return labels, enroll_ids, test_ids


def main(args, config):
    """Evaluate a trained ECAPA-TDNN speaker model on the VoxCeleb trials.

    Loads the last checkpoint saved during training, computes an embedding
    for every enroll and test utterance (optionally applying a running global
    mean/std normalization), scores each trial with cosine similarity and
    logs the resulting equal error rate (EER) and score threshold.

    Args:
        args: Parsed command-line namespace. ``device``, ``data_dir`` and
            ``load_checkpoint`` are read; ``load_checkpoint`` is overwritten
            with the resolved absolute path of the epoch directory.
        config: Configuration node. ``seed``, ``model``, ``num_speakers``,
            ``epochs``, ``save_interval``, ``n_mels``, ``window_size``,
            ``hop_length``, ``batch_size``, ``num_workers`` and
            ``global_embedding_norm`` are read; ``embedding_mean_norm`` and
            ``embedding_std_norm`` are read only when
            ``global_embedding_norm`` is truthy.
    """
    # stage0: select the device (cpu or gpu) the model runs on
    paddle.set_device(args.device)
    # fix the random seed before anything random (e.g. shuffling) is built
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=config.num_speakers)

    # stage3: load the pre-trained model
    # The last saved epoch is the largest multiple of save_interval that does
    # not exceed the number of training epochs.
    last_save_epoch = (
        config.epochs // config.save_interval) * config.save_interval
    args.load_checkpoint = os.path.join(args.load_checkpoint,
                                        "epoch_" + str(last_save_epoch))
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load model checkpoint to sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage4: construct the enroll and test dataloader
    enrol_loader = _build_eval_loader('enroll', args, config)
    test_loader = _build_eval_loader('test', args, config)

    # stage5: we must set the model to eval mode
    model.eval()

    # stage6: global embedding norm to improve the performance
    use_global_norm = config.global_embedding_norm
    logger.info(f"global embedding norm: {use_global_norm}")
    global_embedding_mean = None
    global_embedding_std = None
    batch_count = 0
    if use_global_norm:
        # Only touch these config keys when the feature is enabled, so a
        # config without them still works with the feature switched off.
        mean_norm_flag = config.embedding_mean_norm
        std_norm_flag = config.embedding_std_norm

    # stage7: compute embeddings of audios in enrol and test dataset.
    id2embedding = {}
    # Run two passes to make the embedding normalization more stable: entries
    # written in the first pass are replaced by the second pass for the same
    # utterance ids, when the running statistics have seen more batches.
    for i in range(2):
        for dl in [enrol_loader, test_loader]:
            logger.info(
                f'Loop {i+1}: Computing embeddings on {dl.dataset.subset} dataset'
            )
            with paddle.no_grad():
                for batch in tqdm(dl):
                    # stage 7-1: extract the audio embedding
                    ids = batch['ids']
                    feats = batch['feats']
                    lengths = batch['lengths']
                    # (N, emb_size, 1) -> (N, emb_size)
                    embeddings = model.backbone(feats,
                                                lengths).squeeze(-1).numpy()

                    # stage 7-2: global embedding normalization
                    if use_global_norm:
                        batch_count += 1
                        # A disabled statistic falls back to the identity
                        # value (0 for the mean, 1 for the std).
                        current_mean = embeddings.mean(
                            axis=0) if mean_norm_flag else 0
                        current_std = embeddings.std(
                            axis=0) if std_norm_flag else 1

                        if global_embedding_mean is None and global_embedding_std is None:
                            # First batch: initialize the running statistics.
                            global_embedding_mean = current_mean
                            global_embedding_std = current_std
                        else:
                            # Running average over batches: each new batch
                            # contributes with weight 1 / batch_count.
                            weight = 1 / batch_count
                            global_embedding_mean = (
                                1 - weight
                            ) * global_embedding_mean + weight * current_mean
                            global_embedding_std = (
                                1 - weight
                            ) * global_embedding_std + weight * current_std

                        # Apply global embedding normalization.
                        embeddings = (embeddings - global_embedding_mean
                                      ) / global_embedding_std

                    # Update embedding dict (utterance id -> embedding).
                    id2embedding.update(dict(zip(ids, embeddings)))

    # stage 8: compute cosine scores for every trial pair.
    logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
    labels, enroll_ids, test_ids = _read_trials(VoxCeleb.veri_test_file)

    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
    # Stack the per-utterance embeddings in trial order: (N, emb_size)
    enrol_embeddings = paddle.to_tensor(
        np.asarray(
            [id2embedding[uttid] for uttid in enroll_ids], dtype='float32'))
    test_embeddings = paddle.to_tensor(
        np.asarray(
            [id2embedding[uttid] for uttid in test_ids], dtype='float32'))
    scores = cos_sim_func(enrol_embeddings, test_embeddings)

    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
    logger.info(
        f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
    )
if __name__ == "__main__":
    # Script entry point: parse the command line, load the YAML configuration
    # and run the evaluation.
    # yapf: disable
    # The module docstring is the parser description; passing it positionally
    # would set ``prog`` (the program name shown in usage) instead.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to train model, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load model checkpoint to contiune trainning.")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    # new_allowed=True lets the merged file introduce keys that the empty
    # base node does not declare.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Freeze the config so it cannot be modified during evaluation.
    config.freeze()
    print(config)

    main(args, config)