diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 175a9723e..f709383d9 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -15,6 +15,7 @@ import argparse import os import sys from collections import OrderedDict +from typing import Dict from typing import List from typing import Optional from typing import Union @@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor): "--task", type=str, default="spk", - choices=["spk"], + choices=["spk", "score"], help="task type in vector domain") self.parser.add_argument( "--input", @@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor): logger.info(f"task source: {task_source}") # stage 3: process the audio one by one + # we do action according the task type task_result = OrderedDict() has_exceptions = False for id_, input_ in task_source.items(): try: - res = self(input_, model, sample_rate, config, ckpt_path, - device) - task_result[id_] = res + # extract the speaker audio embedding + if parser_args.task == "spk": + logger.info("do vector spk task") + res = self(input_, model, sample_rate, config, ckpt_path, + device) + task_result[id_] = res + elif parser_args.task == "score": + logger.info("do vector score task") + logger.info(f"input content {input_}") + if len(input_.split()) != 2: + logger.error( + f"vector score task input {input_} wav num is not two," + "that is {len(input_.split())}") + sys.exit(-1) + + # get the enroll and test embedding + enroll_audio, test_audio = input_.split() + logger.info( + f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}" + ) + enroll_embedding = self(enroll_audio, model, sample_rate, + config, ckpt_path, device) + test_embedding = self(test_audio, model, sample_rate, + config, ckpt_path, device) + + # get the score + res = self.get_embeddings_score(enroll_embedding, + test_embedding) + task_result[id_] = res except Exception as e: has_exceptions = True task_result[id_] = f'{e.__class__.__name__}: {e}' @@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor): else: return True + def _get_job_contents( + self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: + """ + Read a job input file and return its contents in a dictionary. + Refactor from the Executor._get_job_contents + + Args: + job_input (os.PathLike): The job input file. + + Returns: + Dict[str, str]: Contents of job input. + """ + job_contents = OrderedDict() + with open(job_input) as f: + for line in f: + line = line.strip() + if not line: + continue + k = line.split(' ')[0] + v = ' '.join(line.split(' ')[1:]) + job_contents[k] = v + return job_contents + + def get_embeddings_score(self, enroll_embedding, test_embedding): + """get the enroll embedding and test embedding score + + Args: + enroll_embedding (numpy.array): shape: (emb_size), enroll audio embedding + test_embedding (numpy.array): shape: (emb_size), test audio embedding + + Returns: + score: the score between enroll embedding and test embedding + """ + if not hasattr(self, "score_func"): + self.score_func = paddle.nn.CosineSimilarity(axis=0) + logger.info("create the cosine score function ") + + score = self.score_func( + paddle.to_tensor(enroll_embedding), + paddle.to_tensor(test_embedding)) + + return score.item() + @stats_wrapper def __call__(self, audio_file: os.PathLike,