add score method, test=doc

pull/1646/head
xiongxinlei 3 years ago
parent cfc390e0b4
commit 48b8cc8937

@ -15,6 +15,7 @@ import argparse
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Dict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor):
"--task", "--task",
type=str, type=str,
default="spk", default="spk",
choices=["spk"], choices=["spk", "score"],
help="task type in vector domain") help="task type in vector domain")
self.parser.add_argument( self.parser.add_argument(
"--input", "--input",
@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor):
logger.info(f"task source: {task_source}") logger.info(f"task source: {task_source}")
# stage 3: process the audio one by one # stage 3: process the audio one by one
# we do action according the task type
task_result = OrderedDict() task_result = OrderedDict()
has_exceptions = False has_exceptions = False
for id_, input_ in task_source.items(): for id_, input_ in task_source.items():
try: try:
res = self(input_, model, sample_rate, config, ckpt_path, # extract the speaker audio embedding
device) if parser_args.task == "spk":
task_result[id_] = res logger.info("do vector spk task")
res = self(input_, model, sample_rate, config, ckpt_path,
device)
task_result[id_] = res
elif parser_args.task == "score":
logger.info("do vector score task")
logger.info(f"input content {input_}")
if len(input_.split()) != 2:
logger.error(
f"vector score task input {input_} wav num is not two,"
"that is {len(input_.split())}")
sys.exit(-1)
# get the enroll and test embedding
enroll_audio, test_audio = input_.split()
logger.info(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
)
enroll_embedding = self(enroll_audio, model, sample_rate,
config, ckpt_path, device)
test_embedding = self(test_audio, model, sample_rate,
config, ckpt_path, device)
# get the score
res = self.get_embeddings_score(enroll_embedding,
test_embedding)
task_result[id_] = res
except Exception as e: except Exception as e:
has_exceptions = True has_exceptions = True
task_result[id_] = f'{e.__class__.__name__}: {e}' task_result[id_] = f'{e.__class__.__name__}: {e}'
@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor):
else: else:
return True return True
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Refactor from the Executor._get_job_contents
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k = line.split(' ')[0]
v = ' '.join(line.split(' ')[1:])
job_contents[k] = v
return job_contents
def get_embeddings_score(self, enroll_embedding, test_embedding):
"""get the enroll embedding and test embedding score
Args:
enroll_embedding (numpy.array): shape: (emb_size), enroll audio embedding
test_embedding (numpy.array): shape: (emb_size), test audio embedding
Returns:
score: the score between enroll embedding and test embedding
"""
if not hasattr(self, "score_func"):
self.score_func = paddle.nn.CosineSimilarity(axis=0)
logger.info("create the cosine score function ")
score = self.score_func(
paddle.to_tensor(enroll_embedding),
paddle.to_tensor(test_embedding))
return score.item()
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
audio_file: os.PathLike, audio_file: os.PathLike,

Loading…
Cancel
Save