@@ -299,114 +299,6 @@ def _check_missing_model(enroll, test, ndx):
    return clean_ndx


def fast_PLDA_scoring(
        enroll,
        test,
        ndx,
        mu,
        F,
        Sigma,
        test_uncertainty=None,
        Vtrans=None,
        p_known=0.0,
        scaling_factor=1.0,
        check_missing=True, ):
    """
    Compute the PLDA scores between two sets of vectors. The list of
    trials to perform is given in an Ndx object. PLDA matrices have to be
    pre-computed. i-vectors/x-vectors are assumed to be whitened beforehand.

    Arguments
    ---------
    enroll : speechbrain.utils.Xvector_PLDA_sp.StatObject_SB
        A StatServer in which stat1 contains the x-vectors.
    test : speechbrain.utils.Xvector_PLDA_sp.StatObject_SB
        A StatServer in which stat1 contains the x-vectors.
    ndx : speechbrain.utils.Xvector_PLDA_sp.Ndx
        An Ndx object defining the list of trials to perform.
    mu : double
        The mean vector of the PLDA gaussian.
    F : tensor
        The between-class covariance matrix of the PLDA.
    Sigma : tensor
        The residual covariance matrix.
    p_known : float
        Probability of having a known speaker in the open-set
        identification case (=1 for the verification task and =0 for the
        closed-set case).
    check_missing : bool
        If True, check that all models and segments exist.
    """

    enroll_ctr = copy.deepcopy(enroll)
    test_ctr = copy.deepcopy(test)

    # Remove missing models and test segments
    if check_missing:
        clean_ndx = _check_missing_model(enroll_ctr, test_ctr, ndx)
    else:
        clean_ndx = ndx

    # Center the i-vectors around the PLDA mean
    enroll_ctr.center_stats(mu)
    test_ctr.center_stats(mu)

    # Compute constant component of the PLDA distribution
    invSigma = linalg.inv(Sigma)
    I_spk = numpy.eye(F.shape[1], dtype="float")

    K = F.T.dot(invSigma * scaling_factor).dot(F)
    K1 = linalg.inv(K + I_spk)
    K2 = linalg.inv(2 * K + I_spk)
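
    # Interpretation (standard PLDA derivation): K is the precision that one
    # observation contributes to the latent speaker factor, so
    # K1 = (K + I)^-1 and K2 = (2K + I)^-1 are the posterior covariances of
    # that factor given one observation and given two observations (the
    # same-speaker hypothesis), respectively.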

    # Compute the Gaussian distribution constant
    alpha1 = numpy.linalg.slogdet(K1)[1]
    alpha2 = numpy.linalg.slogdet(K2)[1]
    plda_cst = alpha2 / 2.0 - alpha1
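
    # slogdet returns (sign, log|det|); only the log-determinants are used.
    # plda_cst collects the constant that arises when the latent speaker
    # factor is integrated out under each hypothesis.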

    # Compute intermediate matrices
    Sigma_ac = numpy.dot(F, F.T)
    Sigma_tot = Sigma_ac + Sigma
    Sigma_tot_inv = linalg.inv(Sigma_tot)

    Tmp = linalg.inv(Sigma_tot - Sigma_ac.dot(Sigma_tot_inv).dot(Sigma_ac))
    Phi = Sigma_tot_inv - Tmp
    Psi = Sigma_tot_inv.dot(Sigma_ac).dot(Tmp)
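
    # Phi and Psi follow from the block inverse of the joint covariance of an
    # (enrolment, test) pair under the same-speaker hypothesis,
    #     [[Sigma_tot, Sigma_ac],
    #      [Sigma_ac,  Sigma_tot]],
    # where Tmp is the inverse of the corresponding Schur complement.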

    # Compute the different parts of PLDA score
    model_part = 0.5 * numpy.einsum("ij, ji->i",
                                    enroll_ctr.stats.dot(Phi),
                                    enroll_ctr.stats.T)
    seg_part = 0.5 * numpy.einsum("ij, ji->i",
                                  test_ctr.stats.dot(Phi), test_ctr.stats.T)

    # Compute verification scores
    score = Scores()  # noqa F821
    score.modelset = clean_ndx.modelset
    score.segset = clean_ndx.segset
    score.scoremask = clean_ndx.trialmask

    score.scoremat = model_part[:, numpy.newaxis] + seg_part + plda_cst
    score.scoremat += enroll_ctr.stats.dot(Psi).dot(test_ctr.stats.T)
    score.scoremat *= scaling_factor
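
    # Open-set correction (applied when p_known != 0): for each enrolled
    # model, the implicit denominator of the LLR is replaced by a mixture of
    # the other N - 1 enrolled models (weight p_known, uniformly averaged)
    # and the out-of-set alternative (weight 1 - p_known), evaluated in the
    # log domain.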

    # In the open-set identification case, compute the log-likelihood
    # by taking into account the probability of having a known impostor
    # or an out-of-set class.
    if p_known != 0:
        N = score.scoremat.shape[0]
        open_set_scores = numpy.empty(score.scoremat.shape)
        tmp = numpy.exp(score.scoremat)
        for ii in range(N):
            # open-set term
            open_set_scores[ii, :] = score.scoremat[ii, :] - numpy.log(
                p_known * tmp[~(numpy.arange(N) == ii)].sum(axis=0) /
                (N - 1) + (1 - p_known))
        score.scoremat = open_set_scores

    return score


class PLDA:
    """
    A class to train a PLDA model from embeddings.
@@ -547,6 +439,105 @@ class PLDA:

            # Minimum Divergence step
            self.F = self.F.dot(linalg.cholesky(_R))

    def scoring(
            self,
            enroll,
            test,
            ndx,
            test_uncertainty=None,
            Vtrans=None,
            p_known=0.0,
            scaling_factor=1.0,
            check_missing=True, ):
"""
|
|
|
|
|
Compute the PLDA scores between to sets of vectors. The list of
|
|
|
|
|
trials to perform is given in an Ndx object. PLDA matrices have to be
|
|
|
|
|
pre-computed. i-vectors/x-vectors are supposed to be whitened before.
|
|
|
|
|
|
|
|
|
|
Arguments
|
|
|
|
|
---------
|
|
|
|
|
enroll : paddlespeech.vector.cluster.diarization.EmbeddingMeta
|
|
|
|
|
A EmbeddingMeta in which stats are xvectors.
|
|
|
|
|
test : paddlespeech.vector.cluster.diarization.EmbeddingMeta
|
|
|
|
|
A EmbeddingMeta in which stats are xvectors.
|
|
|
|
|
ndx : paddlespeech.vector.cluster.plda.Ndx
|
|
|
|
|
An Ndx object defining the list of trials to perform.
|
|
|
|
|
p_known : float
|
|
|
|
|
Probability of having a known speaker for open-set
|
|
|
|
|
identification case (=1 for the verification task and =0 for the
|
|
|
|
|
closed-set case).
|
|
|
|
|
check_missing : bool
|
|
|
|
|
If True, check that all models and segments exist.
|
|
|
|
|
"""

        enroll_ctr = copy.deepcopy(enroll)
        test_ctr = copy.deepcopy(test)

        # Remove missing models and test segments
        if check_missing:
            clean_ndx = _check_missing_model(enroll_ctr, test_ctr, ndx)
        else:
            clean_ndx = ndx

        # Center the i-vectors around the PLDA mean
        enroll_ctr.center_stats(self.mean)
        test_ctr.center_stats(self.mean)

        # Compute constant component of the PLDA distribution
        invSigma = linalg.inv(self.Sigma)
        I_spk = numpy.eye(self.F.shape[1], dtype="float")

        K = self.F.T.dot(invSigma * scaling_factor).dot(self.F)
        K1 = linalg.inv(K + I_spk)
        K2 = linalg.inv(2 * K + I_spk)

        # Compute the Gaussian distribution constant
        alpha1 = numpy.linalg.slogdet(K1)[1]
        alpha2 = numpy.linalg.slogdet(K2)[1]
        plda_cst = alpha2 / 2.0 - alpha1

        # Compute intermediate matrices
        Sigma_ac = numpy.dot(self.F, self.F.T)
        Sigma_tot = Sigma_ac + self.Sigma
        Sigma_tot_inv = linalg.inv(Sigma_tot)

        Tmp = linalg.inv(Sigma_tot - Sigma_ac.dot(Sigma_tot_inv).dot(Sigma_ac))
        Phi = Sigma_tot_inv - Tmp
        Psi = Sigma_tot_inv.dot(Sigma_ac).dot(Tmp)

        # Compute the different parts of PLDA score
        model_part = 0.5 * numpy.einsum("ij, ji->i",
                                        enroll_ctr.stats.dot(Phi),
                                        enroll_ctr.stats.T)
        seg_part = 0.5 * numpy.einsum("ij, ji->i",
                                      test_ctr.stats.dot(Phi),
                                      test_ctr.stats.T)

        # Compute verification scores
        score = Scores()  # noqa F821
        score.modelset = clean_ndx.modelset
        score.segset = clean_ndx.segset
        score.scoremask = clean_ndx.trialmask

        score.scoremat = model_part[:, numpy.newaxis] + seg_part + plda_cst
        score.scoremat += enroll_ctr.stats.dot(Psi).dot(test_ctr.stats.T)
        score.scoremat *= scaling_factor
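
        # At this point scoremat holds one LLR per (model, segment) pair:
        # rows follow clean_ndx.modelset, columns follow clean_ndx.segset,
        # and scoremask marks the entries that correspond to actual trials.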

        # In the open-set identification case, compute the log-likelihood
        # by taking into account the probability of having a known impostor
        # or an out-of-set class.
        if p_known != 0:
            N = score.scoremat.shape[0]
            open_set_scores = numpy.empty(score.scoremat.shape)
            tmp = numpy.exp(score.scoremat)
            for ii in range(N):
                # open-set term
                open_set_scores[ii, :] = score.scoremat[ii, :] - numpy.log(
                    p_known * tmp[~(numpy.arange(N) == ii)].sum(axis=0) /
                    (N - 1) + (1 - p_known))
            score.scoremat = open_set_scores

        return score


if __name__ == '__main__':
    import random

@@ -580,6 +571,5 @@ if __name__ == '__main__':
    te_stat = EmbeddingMeta(modelset=te_sets, segset=te_sets, stats=te_xv)
    ndx = Ndx(models=en_sets, testsegs=te_sets)
    # PLDA Scoring
    scores_plda = fast_PLDA_scoring(en_stat, te_stat, ndx, plda.mean, plda.F,
                                    plda.Sigma)
    scores_plda = plda.scoring(en_stat, te_stat, ndx)
    print(scores_plda.scoremat.shape)  # (20, 30)
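
# A minimal usage sketch of the scoring API exercised above, assuming a PLDA
# model has already been trained (i.e. `plda.mean`, `plda.F` and `plda.Sigma`
# are populated); the set sizes and the embedding dimension are illustrative:
#
#     import numpy
#     en_xv = numpy.random.randn(20, 256)  # 20 enrolment x-vectors
#     te_xv = numpy.random.randn(30, 256)  # 30 test x-vectors
#     en_sets = numpy.array(["en_%d" % i for i in range(20)])
#     te_sets = numpy.array(["te_%d" % i for i in range(30)])
#     en_stat = EmbeddingMeta(modelset=en_sets, segset=en_sets, stats=en_xv)
#     te_stat = EmbeddingMeta(modelset=te_sets, segset=te_sets, stats=te_xv)
#     ndx = Ndx(models=en_sets, testsegs=te_sets)
#     scores = plda.scoring(en_stat, te_stat, ndx)
#     # scores.scoremat -> shape (20, 30): rows are models, columns segments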