commit
880829fe89
@@ -0,0 +1,575 @@
# Copyright (c) 2022 PaddlePaddle and SpeechBrain Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A popular speaker recognition/diarization model (LDA and PLDA).

Relevant Papers
 - This implementation of PLDA is based on the following papers.

 - PLDA model training
    * Ye Jiang et al., "PLDA Modeling in I-Vector and Supervector Space for Speaker Verification," in Interspeech, 2012.
    * Patrick Kenny et al., "PLDA for speaker verification with utterances of arbitrary duration," in ICASSP, 2013.

 - PLDA scoring (fast scoring)
    * Daniel Garcia-Romero et al., "Analysis of i-vector length normalization in speaker recognition systems," in Interspeech, 2011.
    * Weiwei Lin et al., "Fast Scoring for PLDA with Uncertainty Propagation," in Odyssey, 2016.
    * Kong Aik Lee et al., "Multi-session PLDA Scoring of I-vector for Partially Open-Set Speaker Detection," in Interspeech, 2013.

Credits
This code is adapted from: https://git-lium.univ-lemans.fr/Larcher/sidekit
"""
import copy
import pickle

import numpy
from scipy import linalg

from paddlespeech.vector.cluster.diarization import EmbeddingMeta
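
# PLDA assumes that each embedding x decomposes as (see the papers above)
#
#   x = mu + F h + eps,    h ~ N(0, I),    eps ~ N(0, Sigma)
#
# where mu is the global mean, the columns of F span the speaker
# (eigenvoice) subspace and Sigma is a full residual covariance. This is
# the simplified model trained by the PLDA class below.
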
def ismember(list1, list2):
    """Return a boolean list telling, for each item of list1, whether it occurs in list2."""
    c = [item in list2 for item in list1]
    return c


class Ndx:
    """
    A class that encodes trial index information. It contains a list of
    model names, a list of test segment names, and a matrix indicating
    which combinations of model and test segment are trials of interest.

    Arguments
    ---------
    modelset : list
        List of unique models in an ndarray.
    segset : list
        List of unique test segments in an ndarray.
    trialmask : 2D ndarray of bool
        Rows correspond to the models and columns to the test segments. True, if the trial is of interest.
    """

    def __init__(self,
                 ndx_file_name="",
                 models=numpy.array([]),
                 testsegs=numpy.array([])):
        """
        Initialize an Ndx object by loading information from a file.

        Arguments
        ---------
        ndx_file_name : str
            Name of the file to load.
        """
        self.modelset = numpy.empty(0, dtype="|O")
        self.segset = numpy.empty(0, dtype="|O")
        self.trialmask = numpy.array([], dtype="bool")

        if ndx_file_name == "":
            # Pad the shorter of models/testsegs so both have the same size
            d = models.shape[0] - testsegs.shape[0]
            if d != 0:
                if d > 0:
                    last = str(testsegs[-1])
                    pad = numpy.array([last] * d)
                    testsegs = numpy.hstack((testsegs, pad))
                else:
                    d = abs(d)
                    last = str(models[-1])
                    pad = numpy.array([last] * d)
                    models = numpy.hstack((models, pad))

            modelset = numpy.unique(models)
            segset = numpy.unique(testsegs)

            trialmask = numpy.zeros(
                (modelset.shape[0], segset.shape[0]), dtype="bool")
            for m in range(modelset.shape[0]):
                segs = testsegs[numpy.array(ismember(models, modelset[m]))]
                trialmask[m, ] = ismember(segset, segs)  # noqa E231

            self.modelset = modelset
            self.segset = segset
            self.trialmask = trialmask
            assert self.validate(), "Wrong Ndx format"

        else:
            # Ndx.read is expected to load a previously saved Ndx
            # (it is not defined in this file)
            ndx = Ndx.read(ndx_file_name)
            self.modelset = ndx.modelset
            self.segset = ndx.segset
            self.trialmask = ndx.trialmask

    def save_ndx_object(self, output_file_name):
        """Pickle this Ndx to output_file_name."""
        with open(output_file_name, "wb") as output:
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)

    def filter(self, modlist, seglist, keep):
        """
        Removes some of the information in an Ndx. Useful for creating a
        gender-specific Ndx from a pooled-gender Ndx. Depending on the
        value of 'keep', the two input lists indicate the strings to
        retain or the strings to discard.

        Arguments
        ---------
        modlist : array
            A cell array of strings which will be compared with the modelset of this Ndx.
        seglist : array
            A cell array of strings which will be compared with the segset of this Ndx.
        keep : bool
            Indicating whether modlist and seglist are the models to keep or discard.
        """
        if keep:
            keepmods = modlist
            keepsegs = seglist
        else:
            # Keep everything except the given models/segments
            keepmods = numpy.setdiff1d(self.modelset, modlist)
            keepsegs = numpy.setdiff1d(self.segset, seglist)

        keepmodidx = numpy.array(ismember(self.modelset, keepmods))
        keepsegidx = numpy.array(ismember(self.segset, keepsegs))

        outndx = Ndx()
        outndx.modelset = self.modelset[keepmodidx]
        outndx.segset = self.segset[keepsegidx]
        tmp = self.trialmask[numpy.array(keepmodidx), :]
        outndx.trialmask = tmp[:, numpy.array(keepsegidx)]

        assert outndx.validate(), "Wrong Ndx format"

        if self.modelset.shape[0] > outndx.modelset.shape[0]:
            print("Number of models reduced from %d to %d" %
                  (self.modelset.shape[0], outndx.modelset.shape[0]))
        if self.segset.shape[0] > outndx.segset.shape[0]:
            print("Number of test segments reduced from %d to %d" %
                  (self.segset.shape[0], outndx.segset.shape[0]))
        return outndx

    def validate(self):
        """
        Checks that an object of type Ndx obeys certain rules that
        must always be true. Returns a boolean indicating whether the object is valid.
        """
        ok = isinstance(self.modelset, numpy.ndarray)
        ok &= isinstance(self.segset, numpy.ndarray)
        ok &= isinstance(self.trialmask, numpy.ndarray)

        ok &= self.modelset.ndim == 1
        ok &= self.segset.ndim == 1
        ok &= self.trialmask.ndim == 2

        ok &= self.trialmask.shape == (self.modelset.shape[0],
                                       self.segset.shape[0], )
        return ok


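# Illustrative sketch (not part of the library API): models and testsegs
# are parallel lists of trials, one (model, segment) pair per entry.
#
#   models = numpy.array(["spk1", "spk2", "spk1"], dtype="|O")
#   segs = numpy.array(["utt1", "utt2", "utt3"], dtype="|O")
#   ndx = Ndx(models=models, testsegs=segs)
#   # ndx.trialmask has shape (2, 3): rows follow the sorted unique
#   # modelset, and only (spk1, utt1), (spk2, utt2), (spk1, utt3) are True.
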
class Scores:
    """
    A class for storing trial scores. The modelset and segset
    fields are lists of model and test segment names, respectively.
    The element i,j of scoremat and scoremask corresponds to the
    trial involving model i and test segment j.

    Arguments
    ---------
    modelset : list
        List of unique models in an ndarray.
    segset : list
        List of unique test segments in an ndarray.
    scoremask : 2D ndarray of bool
        Indicates the trials of interest, i.e.,
        the entry i,j in scoremat should be ignored if scoremask[i,j] is False.
    scoremat : 2D ndarray
        Scores matrix.
    """

    def __init__(self, scores_file_name=""):
        """
        Initialize a Scores object by loading information from a file in HDF5 format.

        Arguments
        ---------
        scores_file_name : str
            Name of the file to load.
        """
        self.modelset = numpy.empty(0, dtype="|O")
        self.segset = numpy.empty(0, dtype="|O")
        self.scoremask = numpy.array([], dtype="bool")
        self.scoremat = numpy.array([])

        if scores_file_name == "":
            pass
        else:
            # Scores.read is expected to load a previously saved Scores
            # object (it is not defined in this file)
            tmp = Scores.read(scores_file_name)
            self.modelset = tmp.modelset
            self.segset = tmp.segset
            self.scoremask = tmp.scoremask
            self.scoremat = tmp.scoremat

    def __repr__(self):
        ch = "modelset:\n"
        ch += self.modelset.__repr__() + "\n"
        ch += "segset:\n"
        ch += self.segset.__repr__() + "\n"
        ch += "scoremask:\n"
        ch += self.scoremask.__repr__() + "\n"
        ch += "scoremat:\n"
        ch += self.scoremat.__repr__() + "\n"
        return ch


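# Illustrative sketch (not part of the library API): only the entries of
# scoremat where scoremask is True are meaningful, e.g.
#
#   sc = plda.scoring(enroll, test, ndx)      # -> Scores
#   valid_scores = sc.scoremat[sc.scoremask]  # 1D array of trial scores
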
def fa_model_loop(
        batch_start,
        mini_batch_indices,
        factor_analyser,
        stat0,
        stats,
        e_h,
        e_hh, ):
    """
    A function for PLDA estimation.

    Arguments
    ---------
    batch_start : int
        Index to start at in the list.
    mini_batch_indices : list
        Indices of the elements in the list (should start at zero).
    factor_analyser : instance of PLDA class
        PLDA class object.
    stat0 : tensor
        Matrix of zero-order statistics.
    stats : tensor
        Matrix of first-order statistics.
    e_h : tensor
        An accumulator matrix.
    e_hh : tensor
        An accumulator matrix.
    """
    rank = factor_analyser.F.shape[1]
    if factor_analyser.Sigma.ndim == 2:
        A = factor_analyser.F.T.dot(factor_analyser.F)
        inv_lambda_unique = dict()
        # Precompute the inverse posterior precision for each distinct
        # zero-order statistic (number of sessions)
        for sess in numpy.unique(stat0[:, 0]):
            inv_lambda_unique[sess] = linalg.inv(
                sess * A + numpy.eye(A.shape[0]))

    tmp = numpy.zeros(
        (factor_analyser.F.shape[1], factor_analyser.F.shape[1]),
        dtype=numpy.float64, )

    for idx in mini_batch_indices:
        if factor_analyser.Sigma.ndim == 1:
            inv_lambda = linalg.inv(
                numpy.eye(rank) + (factor_analyser.F.T * stat0[
                    idx + batch_start, :]).dot(factor_analyser.F))
        else:
            inv_lambda = inv_lambda_unique[stat0[idx + batch_start, 0]]

        # Posterior mean and correlation of the latent factor for this model
        aux = factor_analyser.F.T.dot(stats[idx + batch_start, :])
        numpy.dot(aux, inv_lambda, out=e_h[idx])
        e_hh[idx] = inv_lambda + numpy.outer(e_h[idx], e_h[idx], tmp)


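# E-step computed by fa_model_loop, in whitened space (a reading of the
# code, following Kenny et al., 2013): for a model with zero-order
# statistic n and first-order statistic f,
#
#   inv_lambda = (I + n * F^T F)^{-1}
#   e_h        = inv_lambda F^T f           (posterior mean of h)
#   e_hh       = inv_lambda + e_h e_h^T     (posterior correlation of h)
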
def _check_missing_model(enroll, test, ndx):
    # Remove missing models and test segments
    clean_ndx = ndx.filter(enroll.modelset, test.segset, True)

    # Align EmbeddingMeta to match the clean_ndx
    enroll.align_models(clean_ndx.modelset)
    test.align_segments(clean_ndx.segset)

    return clean_ndx


class PLDA:
    """
    A class to train a PLDA model from embeddings.

    The input is in paddlespeech.vector.cluster.diarization.EmbeddingMeta format.
    Trains a simplified PLDA model with no within-class covariance matrix
    but a full residual covariance matrix.

    Arguments
    ---------
    mean : tensor
        Mean of the vectors.
    F : tensor
        Eigenvoice matrix.
    Sigma : tensor
        Residual matrix.
    """

    def __init__(
            self,
            mean=None,
            F=None,
            Sigma=None,
            rank_f=100,
            nb_iter=10,
            scaling_factor=1.0, ):
        self.mean = None
        self.F = None
        self.Sigma = None
        self.rank_f = rank_f
        self.nb_iter = nb_iter
        self.scaling_factor = scaling_factor

        if mean is not None:
            self.mean = mean
        if F is not None:
            self.F = F
        if Sigma is not None:
            self.Sigma = Sigma

    def plda(
            self,
            emb_meta=None,
            output_file_name=None, ):
        """
        Trains a PLDA model with no within-class covariance matrix but a
        full residual covariance matrix. The rank of the between-class
        covariance matrix (rank_f), the number of EM iterations (nb_iter)
        and the statistics scaling factor (scaling_factor, between 0 and 1)
        are taken from the constructor.

        Arguments
        ---------
        emb_meta : paddlespeech.vector.cluster.diarization.EmbeddingMeta
            Contains vectors and meta-information to perform PLDA.
        output_file_name : str
            Name of the output file where to store the PLDA model
            (not used by the current implementation).
        """

        # Dimension of the vector (x-vectors stored in stats)
        vect_size = emb_meta.stats.shape[1]

        # Initialize mean and residual covariance from the training data
        self.mean = emb_meta.get_mean_stats()
        self.Sigma = emb_meta.get_total_covariance_stats()

        # Sum stat0 and stat1 for each speaker model
        model_shifted_stat, session_per_model = emb_meta.sum_stat_per_model()

        # Number of speakers (classes) in training set
        class_nb = model_shifted_stat.modelset.shape[0]

        # Multiply statistics by scaling_factor
        model_shifted_stat.stat0 *= self.scaling_factor
        model_shifted_stat.stats *= self.scaling_factor
        session_per_model *= self.scaling_factor

        # Covariance for stats
        sigma_obs = emb_meta.get_total_covariance_stats()
        evals, evecs = linalg.eigh(sigma_obs)

        # Initial F (eigenvoice matrix) from the rank_f top eigenvectors
        idx = numpy.argsort(evals)[::-1]
        evecs = evecs.real[:, idx[:self.rank_f]]
        self.F = evecs[:, :self.rank_f]

        # Estimate PLDA model by iterating the EM algorithm
        for it in range(self.nb_iter):

            # E-step

            # Copy stats as they will be whitened with a different Sigma for each iteration
            local_stat = copy.deepcopy(model_shifted_stat)

            # Whiten statistics (with the new mean and Sigma)
            local_stat.whiten_stats(self.mean, self.Sigma)

            # Whiten the EigenVoice matrix
            eigen_values, eigen_vectors = linalg.eigh(self.Sigma)
            ind = eigen_values.real.argsort()[::-1]
            eigen_values = eigen_values.real[ind]
            eigen_vectors = eigen_vectors.real[:, ind]
            sqr_inv_eval_sigma = 1 / numpy.sqrt(eigen_values.real)
            sqr_inv_sigma = numpy.dot(eigen_vectors,
                                      numpy.diag(sqr_inv_eval_sigma))
            self.F = sqr_inv_sigma.T.dot(self.F)

            # Replicate self.stat0
            index_map = numpy.zeros(vect_size, dtype=int)
            _stat0 = local_stat.stat0[:, index_map]

            e_h = numpy.zeros((class_nb, self.rank_f))
            e_hh = numpy.zeros((class_nb, self.rank_f, self.rank_f))

            # Loop on model ids
            fa_model_loop(
                batch_start=0,
                mini_batch_indices=numpy.arange(class_nb),
                factor_analyser=self,
                stat0=_stat0,
                stats=local_stat.stats,
                e_h=e_h,
                e_hh=e_hh, )

            # Accumulate for minimum-divergence step
            _R = numpy.sum(e_hh, axis=0) / session_per_model.shape[0]

            _C = e_h.T.dot(local_stat.stats).dot(linalg.inv(sqr_inv_sigma))
            _A = numpy.einsum("ijk,i->jk", e_hh, local_stat.stat0.squeeze())

            # M-step
            self.F = linalg.solve(_A, _C).T

            # Update the residual covariance
            self.Sigma = sigma_obs - self.F.dot(_C) / session_per_model.sum()

            # Minimum-divergence step
            self.F = self.F.dot(linalg.cholesky(_R))

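    # M-step implemented above (a reading of the code): with accumulators
    #
    #   _A = sum_s n_s E[h_s h_s^T]   and   _C = sum_s E[h_s] f_s^T (de-whitened),
    #
    # the eigenvoice update solves _A F^T = _C (self.F = solve(_A, _C).T),
    # the residual update is Sigma = sigma_obs - F _C / N, and the
    # minimum-divergence step rescales F by the Cholesky factor of the
    # average posterior correlation _R.
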
    def scoring(
            self,
            enroll,
            test,
            ndx,
            test_uncertainty=None,
            Vtrans=None,
            p_known=0.0,
            scaling_factor=1.0,
            check_missing=True, ):
        """
        Compute the PLDA scores between two sets of vectors. The list of
        trials to perform is given in an Ndx object. PLDA matrices have to be
        pre-computed. i-vectors/x-vectors are assumed to be whitened beforehand.

        Arguments
        ---------
        enroll : paddlespeech.vector.cluster.diarization.EmbeddingMeta
            An EmbeddingMeta in which stats are xvectors.
        test : paddlespeech.vector.cluster.diarization.EmbeddingMeta
            An EmbeddingMeta in which stats are xvectors.
        ndx : paddlespeech.vector.cluster.plda.Ndx
            An Ndx object defining the list of trials to perform.
        p_known : float
            Probability of having a known speaker for open-set
            identification case (=1 for the verification task and =0 for the
            closed-set case).
        check_missing : bool
            If True, check that all models and segments exist.
        """

        enroll_ctr = copy.deepcopy(enroll)
        test_ctr = copy.deepcopy(test)

        # Remove missing models and test segments
        if check_missing:
            clean_ndx = _check_missing_model(enroll_ctr, test_ctr, ndx)
        else:
            clean_ndx = ndx

        # Center the i-vectors around the PLDA mean
        enroll_ctr.center_stats(self.mean)
        test_ctr.center_stats(self.mean)

        # Compute constant component of the PLDA distribution
        invSigma = linalg.inv(self.Sigma)
        I_spk = numpy.eye(self.F.shape[1], dtype="float")

        K = self.F.T.dot(invSigma * scaling_factor).dot(self.F)
        K1 = linalg.inv(K + I_spk)
        K2 = linalg.inv(2 * K + I_spk)

        # Compute the Gaussian distribution constant
        alpha1 = numpy.linalg.slogdet(K1)[1]
        alpha2 = numpy.linalg.slogdet(K2)[1]
        plda_cst = alpha2 / 2.0 - alpha1

        # Compute intermediate matrices
        Sigma_ac = numpy.dot(self.F, self.F.T)
        Sigma_tot = Sigma_ac + self.Sigma
        Sigma_tot_inv = linalg.inv(Sigma_tot)

        Tmp = linalg.inv(Sigma_tot - Sigma_ac.dot(Sigma_tot_inv).dot(Sigma_ac))
        Phi = Sigma_tot_inv - Tmp
        Psi = Sigma_tot_inv.dot(Sigma_ac).dot(Tmp)

        # Compute the different parts of PLDA score
        model_part = 0.5 * numpy.einsum("ij, ji->i",
                                        enroll_ctr.stats.dot(Phi),
                                        enroll_ctr.stats.T)
        seg_part = 0.5 * numpy.einsum("ij, ji->i",
                                      test_ctr.stats.dot(Phi), test_ctr.stats.T)

        # Compute verification scores
        score = Scores()  # noqa F821
        score.modelset = clean_ndx.modelset
        score.segset = clean_ndx.segset
        score.scoremask = clean_ndx.trialmask

        score.scoremat = model_part[:, numpy.newaxis] + seg_part + plda_cst
        score.scoremat += enroll_ctr.stats.dot(Psi).dot(test_ctr.stats.T)
        score.scoremat *= scaling_factor

        # Case of open-set identification: compute the log-likelihood
        # by taking into account the probability of having a known impostor
        # or an out-of-set class
        if p_known != 0:
            N = score.scoremat.shape[0]
            open_set_scores = numpy.empty(score.scoremat.shape)
            tmp = numpy.exp(score.scoremat)
            for ii in range(N):
                # open-set term
                open_set_scores[ii, :] = score.scoremat[ii, :] - numpy.log(
                    p_known * tmp[~(numpy.arange(N) == ii)].sum(axis=0) / (
                        N - 1) + (1 - p_known))
            score.scoremat = open_set_scores

        return score


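# The score computed by scoring() is the standard two-covariance PLDA
# log-likelihood ratio (a reading of the code; cf. Garcia-Romero et al.,
# 2011): for enrollment vector x1 and test vector x2,
#
#   llr(x1, x2) = 0.5 x1^T Phi x1 + 0.5 x2^T Phi x2 + x1^T Psi x2 + const
#
# with Phi and Psi derived from Sigma_tot = F F^T + Sigma as in the
# intermediate-matrix block of scoring().
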
if __name__ == '__main__':
    import random

    dim, N, n_spkrs = 10, 100, 10
    train_xv = numpy.random.rand(N, dim)
    md = ['md' + str(random.randrange(1, n_spkrs, 1)) for i in range(N)]  # spk
    modelset = numpy.array(md, dtype="|O")
    sg = ['sg' + str(i) for i in range(N)]  # utt
    segset = numpy.array(sg, dtype="|O")
    stat0 = numpy.array([[1.0]] * N)  # zero-order stats (unused in this demo)
    xvectors_stat = EmbeddingMeta(
        modelset=modelset, segset=segset, stats=train_xv)
    # Training PLDA model: M ~ (mean, F, Sigma)
    plda = PLDA(rank_f=5)
    plda.plda(xvectors_stat)
    print(plda.mean.shape)  # (10,)
    print(plda.F.shape)  # (10, 5)
    print(plda.Sigma.shape)  # (10, 10)
    # Enrollment (20 utts)
    en_N = 20
    en_xv = numpy.random.rand(en_N, dim)
    en_sgs = ['en' + str(i) for i in range(en_N)]
    en_sets = numpy.array(en_sgs, dtype="|O")
    en_stat = EmbeddingMeta(modelset=en_sets, segset=en_sets, stats=en_xv)
    # Test (30 utts)
    te_N = 30
    te_xv = numpy.random.rand(te_N, dim)
    te_sgs = ['te' + str(i) for i in range(te_N)]
    te_sets = numpy.array(te_sgs, dtype="|O")
    te_stat = EmbeddingMeta(modelset=te_sets, segset=te_sets, stats=te_xv)
    ndx = Ndx(models=en_sets, testsegs=te_sets)  # trials
    # PLDA Scoring
    scores_plda = plda.scoring(en_stat, te_stat, ndx)
    print(scores_plda.scoremat.shape)  # (20, 30)
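
    # Illustrative extension of the demo (the embeddings above are random,
    # so the mapping is meaningless): pick the best-scoring enrolled model
    # for each of the first three test segments.
    best = scores_plda.scoremat.argmax(axis=0)
    for j in range(3):
        print(scores_plda.segset[j], '->', scores_plda.modelset[best[j]])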