101 lines
4.0 KiB
101 lines
4.0 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
import paddle
|
|
from sklearn.metrics import roc_curve
|
|
|
|
|
|
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
    """Compute the equal error rate (EER) and the matching score threshold.

    The ROC curve is computed with ``roc_curve`` and the operating point at
    which the false negative rate (``1 - tpr``) is closest to the false
    positive rate is selected. The false positive rate at that point is
    reported as the EER.

    Args:
        labels (np.ndarray): the trial labels, shape: [N], one-dimensional,
            where N is the number of samples.
        scores (np.ndarray): the trial scores, shape: [N], one-dimensional,
            where N is the number of samples.

    Returns:
        List[float]: the EER and the corresponding threshold, returned as the
            pair ``(eer, eer_threshold)``.
    """
    fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
    fnr = 1 - tpr
    # Locate the FNR/FPR crossing point once; the same index selects both the
    # error rate and its threshold (NaN entries are ignored by nanargmin).
    eer_index = np.nanargmin(np.absolute(fnr - fpr))
    eer = fpr[eer_index]
    eer_threshold = threshold[eer_index]
    return eer, eer_threshold
|
|
|
|
|
|
def compute_minDCF(positive_scores,
                   negative_scores,
                   c_miss=1.0,
                   c_fa=1.0,
                   p_target=0.01):
    """Compute the minimum detection cost function (minDCF).

    This is modified from SpeechBrain
    https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509

    The minDCF metric is normally used to evaluate speaker verification
    systems. It is the minimum of the following C_det function computed over
    the candidate thresholds:

        C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)

    where p_miss is the miss probability and p_fa is the probability of a
    false alarm. Candidate thresholds are every unique score plus the midpoint
    between each pair of neighbouring unique scores.

    Args:
        positive_scores (paddle.Tensor): The scores from entries of the same class.
        negative_scores (paddle.Tensor): The scores from entries of different classes.
        c_miss (float, optional): Cost assigned to a missing error (default 1.0).
        c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
        p_target (float, optional): Prior probability of having a target (default 0.01).

    Returns:
        The pair ``(min_dcf, threshold)`` as Python floats: the minimum
        detection cost and the threshold at which it is reached.
    """
    # Drop size-1 dimensions so the scores can be handled as 1-D vectors.
    if len(positive_scores.shape) > 1:
        positive_scores = positive_scores.squeeze()

    if len(negative_scores.shape) > 1:
        negative_scores = negative_scores.squeeze()

    # Computing candidate thresholds
    thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores]))
    thresholds = paddle.unique(thresholds)

    # Adding intermediate thresholds
    interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2
    thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds]))

    # Computing False Rejection Rate (miss detection)
    # Broadcasting [N_pos, 1] against [T] yields the [N_pos, T] comparison
    # matrix directly, without stacking one copy of the scores per threshold.
    pos_scores_threshold = positive_scores.unsqueeze(1) <= thresholds
    p_miss = (pos_scores_threshold.sum(0)
              ).astype("float32") / positive_scores.shape[0]
    del pos_scores_threshold  # release the large boolean matrix early

    # Computing False Acceptance Rate (false alarm)
    neg_scores_threshold = negative_scores.unsqueeze(1) > thresholds
    p_fa = (neg_scores_threshold.sum(0)
            ).astype("float32") / negative_scores.shape[0]
    del neg_scores_threshold  # release the large boolean matrix early

    c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
    c_min = paddle.min(c_det, axis=0)
    min_index = paddle.argmin(c_det, axis=0)
    return float(c_min), float(thresholds[min_index])
|