add score method, test=doc

pull/1646/head
xiongxinlei 3 years ago
parent cfc390e0b4
commit 48b8cc8937

@ -15,6 +15,7 @@ import argparse
import os
import sys
from collections import OrderedDict
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor):
"--task",
type=str,
default="spk",
choices=["spk"],
choices=["spk", "score"],
help="task type in vector domain")
self.parser.add_argument(
"--input",
@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor):
logger.info(f"task source: {task_source}")
# stage 3: process the audio one by one
# we do action according the task type
task_result = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
# extract the speaker audio embedding
if parser_args.task == "spk":
logger.info("do vector spk task")
res = self(input_, model, sample_rate, config, ckpt_path,
device)
task_result[id_] = res
elif parser_args.task == "score":
logger.info("do vector score task")
logger.info(f"input content {input_}")
if len(input_.split()) != 2:
logger.error(
f"vector score task input {input_} wav num is not two,"
"that is {len(input_.split())}")
sys.exit(-1)
# get the enroll and test embedding
enroll_audio, test_audio = input_.split()
logger.info(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
)
enroll_embedding = self(enroll_audio, model, sample_rate,
config, ckpt_path, device)
test_embedding = self(test_audio, model, sample_rate,
config, ckpt_path, device)
# get the score
res = self.get_embeddings_score(enroll_embedding,
test_embedding)
task_result[id_] = res
except Exception as e:
has_exceptions = True
task_result[id_] = f'{e.__class__.__name__}: {e}'
@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor):
else:
return True
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Refactor from the Executor._get_job_contents
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k = line.split(' ')[0]
v = ' '.join(line.split(' ')[1:])
job_contents[k] = v
return job_contents
def get_embeddings_score(self, enroll_embedding, test_embedding):
"""get the enroll embedding and test embedding score
Args:
enroll_embedding (numpy.array): shape: (emb_size), enroll audio embedding
test_embedding (numpy.array): shape: (emb_size), test audio embedding
Returns:
score: the score between enroll embedding and test embedding
"""
if not hasattr(self, "score_func"):
self.score_func = paddle.nn.CosineSimilarity(axis=0)
logger.info("create the cosine score function ")
score = self.score_func(
paddle.to_tensor(enroll_embedding),
paddle.to_tensor(test_embedding))
return score.item()
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,

Loading…
Cancel
Save