|
|
|
@ -15,6 +15,7 @@ import argparse
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from typing import Dict
|
|
|
|
|
from typing import List
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from typing import Union
|
|
|
|
@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor):
|
|
|
|
|
"--task",
|
|
|
|
|
type=str,
|
|
|
|
|
default="spk",
|
|
|
|
|
choices=["spk"],
|
|
|
|
|
choices=["spk", "score"],
|
|
|
|
|
help="task type in vector domain")
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
"--input",
|
|
|
|
@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor):
|
|
|
|
|
logger.info(f"task source: {task_source}")
|
|
|
|
|
|
|
|
|
|
# stage 3: process the audio one by one
|
|
|
|
|
# we do action according the task type
|
|
|
|
|
task_result = OrderedDict()
|
|
|
|
|
has_exceptions = False
|
|
|
|
|
for id_, input_ in task_source.items():
|
|
|
|
|
try:
|
|
|
|
|
res = self(input_, model, sample_rate, config, ckpt_path,
|
|
|
|
|
device)
|
|
|
|
|
task_result[id_] = res
|
|
|
|
|
# extract the speaker audio embedding
|
|
|
|
|
if parser_args.task == "spk":
|
|
|
|
|
logger.info("do vector spk task")
|
|
|
|
|
res = self(input_, model, sample_rate, config, ckpt_path,
|
|
|
|
|
device)
|
|
|
|
|
task_result[id_] = res
|
|
|
|
|
elif parser_args.task == "score":
|
|
|
|
|
logger.info("do vector score task")
|
|
|
|
|
logger.info(f"input content {input_}")
|
|
|
|
|
if len(input_.split()) != 2:
|
|
|
|
|
logger.error(
|
|
|
|
|
f"vector score task input {input_} wav num is not two,"
|
|
|
|
|
"that is {len(input_.split())}")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
# get the enroll and test embedding
|
|
|
|
|
enroll_audio, test_audio = input_.split()
|
|
|
|
|
logger.info(
|
|
|
|
|
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
|
|
|
|
|
)
|
|
|
|
|
enroll_embedding = self(enroll_audio, model, sample_rate,
|
|
|
|
|
config, ckpt_path, device)
|
|
|
|
|
test_embedding = self(test_audio, model, sample_rate,
|
|
|
|
|
config, ckpt_path, device)
|
|
|
|
|
|
|
|
|
|
# get the score
|
|
|
|
|
res = self.get_embeddings_score(enroll_embedding,
|
|
|
|
|
test_embedding)
|
|
|
|
|
task_result[id_] = res
|
|
|
|
|
except Exception as e:
|
|
|
|
|
has_exceptions = True
|
|
|
|
|
task_result[id_] = f'{e.__class__.__name__}: {e}'
|
|
|
|
@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor):
|
|
|
|
|
else:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def _get_job_contents(
|
|
|
|
|
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
|
|
|
|
|
"""
|
|
|
|
|
Read a job input file and return its contents in a dictionary.
|
|
|
|
|
Refactor from the Executor._get_job_contents
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
job_input (os.PathLike): The job input file.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Dict[str, str]: Contents of job input.
|
|
|
|
|
"""
|
|
|
|
|
job_contents = OrderedDict()
|
|
|
|
|
with open(job_input) as f:
|
|
|
|
|
for line in f:
|
|
|
|
|
line = line.strip()
|
|
|
|
|
if not line:
|
|
|
|
|
continue
|
|
|
|
|
k = line.split(' ')[0]
|
|
|
|
|
v = ' '.join(line.split(' ')[1:])
|
|
|
|
|
job_contents[k] = v
|
|
|
|
|
return job_contents
|
|
|
|
|
|
|
|
|
|
def get_embeddings_score(self, enroll_embedding, test_embedding):
|
|
|
|
|
"""get the enroll embedding and test embedding score
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
enroll_embedding (numpy.array): shape: (emb_size), enroll audio embedding
|
|
|
|
|
test_embedding (numpy.array): shape: (emb_size), test audio embedding
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
score: the score between enroll embedding and test embedding
|
|
|
|
|
"""
|
|
|
|
|
if not hasattr(self, "score_func"):
|
|
|
|
|
self.score_func = paddle.nn.CosineSimilarity(axis=0)
|
|
|
|
|
logger.info("create the cosine score function ")
|
|
|
|
|
|
|
|
|
|
score = self.score_func(
|
|
|
|
|
paddle.to_tensor(enroll_embedding),
|
|
|
|
|
paddle.to_tensor(test_embedding))
|
|
|
|
|
|
|
|
|
|
return score.item()
|
|
|
|
|
|
|
|
|
|
@stats_wrapper
|
|
|
|
|
def __call__(self,
|
|
|
|
|
audio_file: os.PathLike,
|
|
|
|
|