From 1818b058aad9650b4d8962451e7b14b9d3381463 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Fri, 18 Feb 2022 18:38:04 +0800
Subject: [PATCH] Support batch input in cls task.

---
 demos/audio_tagging/README.md    |  2 +-
 demos/audio_tagging/README_cn.md |  2 +-
 paddlespeech/cli/cls/infer.py    | 54 ++++++++++++++++++++++++++------
 paddlespeech/cli/executor.py     | 48 ++++++++++++++++++++++++++++
 4 files changed, 94 insertions(+), 12 deletions(-)

diff --git a/demos/audio_tagging/README.md b/demos/audio_tagging/README.md
index 9d4af0be..d679ea54 100644
--- a/demos/audio_tagging/README.md
+++ b/demos/audio_tagging/README.md
@@ -65,7 +65,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
     config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
     label_file=None,
     ckpt_path=None,
-    audio_file='./cat.wav',
+    input_file='./cat.wav',
     topk=10,
     device=paddle.get_device())
 print('CLS Result: \n{}'.format(result))
diff --git a/demos/audio_tagging/README_cn.md b/demos/audio_tagging/README_cn.md
index 79f87bf8..bd94dd58 100644
--- a/demos/audio_tagging/README_cn.md
+++ b/demos/audio_tagging/README_cn.md
@@ -65,7 +65,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
     config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
     label_file=None,
     ckpt_path=None,
-    audio_file='./cat.wav',
+    input_file='./cat.wav',
     topk=10,
     device=paddle.get_device())
 print('CLS Result: \n{}'.format(result))
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
index 5839ff30..b2278734 100644
--- a/paddlespeech/cli/cls/infer.py
+++ b/paddlespeech/cli/cls/infer.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import ast
 import os
+import sys
 from typing import List
 from typing import Optional
 from typing import Union
@@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
             type=str,
             default=paddle.get_device(),
             help='Choose device to execute model inference.')
+        self.parser.add_argument(
+            '--job_dump_result',
+            type=ast.literal_eval,
+            default=False,
+            help='Save job result into file.')
 
     def _get_pretrained_path(self, tag: str) -> os.PathLike:
         """
@@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
         ret = ''
         for idx in topk_idx:
             label, score = self._label_list[idx], result[idx]
-            ret += f'{label}: {score}\n'
+            ret += f'{label} {score}\n'
         return ret
 
     def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@@ -234,22 +241,36 @@ class CLSExecutor(BaseExecutor):
         label_file = parser_args.label_file
         cfg_path = parser_args.config
         ckpt_path = parser_args.ckpt_path
-        audio_file = parser_args.input
+        input_file = parser_args.input
         topk = parser_args.topk
         device = parser_args.device
+        job_dump_result = parser_args.job_dump_result
 
         try:
-            res = self(audio_file, model_type, cfg_path, ckpt_path, label_file,
-                       topk, device)
-            logger.info('CLS Result:\n{}'.format(res))
+            if job_dump_result:
+                assert self._is_job_input(
+                    input_file
+                ), 'Input file should be a job file(*.job) when `job_dump_result` is True.'
+                job_output_file = os.path.abspath(input_file) + '.done'
+                sys.stdout = open(job_output_file, 'w')
+
+            print(
+                self(input_file, model_type, cfg_path, ckpt_path, label_file,
+                     topk, device))
+
+            if job_dump_result:
+                logger.info(f'Results had been saved to: {job_output_file}')
+
             return True
         except Exception as e:
             logger.exception(e)
             return False
+        finally:
+            sys.stdout.close()
 
     @stats_wrapper
     def __call__(self,
-                 audio_file: os.PathLike,
+                 input_file: os.PathLike,
                  model: str='panns_cnn14',
                  config: Optional[os.PathLike]=None,
                  ckpt_path: Optional[os.PathLike]=None,
@@ -259,11 +280,24 @@ class CLSExecutor(BaseExecutor):
         """
             Python API to call an executor.
         """
-        audio_file = os.path.abspath(audio_file)
+        input_file = os.path.abspath(input_file)
         paddle.set_device(device)
         self._init_from_path(model, config, ckpt_path, label_file)
-        self.preprocess(audio_file)
-        self.infer()
-        res = self.postprocess(topk)  # Retrieve result of cls.
+
+        if self._is_job_input(input_file):  # *.job
+            job_outputs = {}
+            job_contents = self._job_preprocess(input_file)
+            for id_, file in job_contents.items():
+                try:
+                    self.preprocess(file)
+                    self.infer()
+                    job_outputs[id_] = self.postprocess(topk).strip()
+                except Exception as e:
+                    job_outputs[id_] = f'{e.__class__.__name__}: {e}'
+            res = self._job_postprecess(job_outputs)
+        else:
+            self.preprocess(input_file)
+            self.infer()
+            res = self.postprocess(topk)  # Retrieve result of cls.
 
         return res
diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py
index 00371371..dded6758 100644
--- a/paddlespeech/cli/executor.py
+++ b/paddlespeech/cli/executor.py
@@ -15,6 +15,7 @@
 import os
 from abc import ABC
 from abc import abstractmethod
 from typing import Any
+from typing import Dict
 from typing import List
 from typing import Union
@@ -100,3 +101,50 @@ class BaseExecutor(ABC):
             Python API to call an executor.
         """
         pass
+
+    def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
+        """
+        Check if current input file is a job input or not.
+
+        Args:
+            input_ (Union[str, os.PathLike]): Input file of current task.
+
+        Returns:
+            bool: return `True` for job input, `False` otherwise.
+        """
+        return os.path.isfile(input_) and input_.endswith('.job')
+
+    def _job_preprocess(self, job_input: os.PathLike) -> Dict[str, str]:
+        """
+        Read a job input file and return its contents in a dictionary.
+
+        Args:
+            job_input (os.PathLike): The job input file.
+
+        Returns:
+            Dict[str, str]: Contents of job input.
+        """
+        job_contents = {}
+        with open(job_input) as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                k, v = line.split(' ')
+                job_contents[k] = v
+        return job_contents
+
+    def _job_postprecess(self, job_outputs: Dict[str, str]) -> str:
+        """
+        Convert job results to string.
+
+        Args:
+            job_outputs (Dict[str, str]): A dictionary with job ids and results.
+
+        Returns:
+            str: A string object contains job outputs.
+        """
+        ret = ''
+        for k, v in job_outputs.items():
+            ret += f'{k} {v}\n'
+        return ret
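
Usage sketch for the new job-file input: a minimal example, where `cls.job`, `./cat.wav`, and `./dog.wav` are illustrative names and the import path simply follows paddlespeech/cli/cls/infer.py above. Each line of a *.job file is an "<id> <audio path>" pair, which is what _job_preprocess() splits on a single space:

    # Build a tiny job file and run batch classification over it.
    import paddle
    from paddlespeech.cli.cls.infer import CLSExecutor

    # One "<id> <audio path>" pair per line; the paths here are placeholders.
    with open('./cls.job', 'w') as f:
        f.write('utt_1 ./cat.wav\n')
        f.write('utt_2 ./dog.wav\n')

    cls_executor = CLSExecutor()
    result = cls_executor(
        input_file='./cls.job',  # an existing path ending in .job takes the batch branch
        model='panns_cnn14',
        topk=1,
        device=paddle.get_device())

    # One "<id> <label> <score>" line per entry, assembled by _job_postprecess().
    print(result)

On the command line the equivalent would be `paddlespeech cls --input ./cls.job --topk 1 --job_dump_result True` (assuming the `paddlespeech cls` entry point used in the audio_tagging demo), which redirects the same per-id lines into ./cls.job.done.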