From 05288fe1c39b30f61bc9f61646dd97cf05a9e574 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Mon, 21 Feb 2022 00:03:37 +0800 Subject: [PATCH] Update batch input and stdin input. --- paddlespeech/cli/cls/infer.py | 63 +++++++++++---------------- paddlespeech/cli/executor.py | 81 ++++++++++++++++++++++++++++++----- 2 files changed, 95 insertions(+), 49 deletions(-) diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index b2278734..e5d4b546 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -14,7 +14,7 @@ import argparse import ast import os -import sys +from collections import OrderedDict from typing import List from typing import Optional from typing import Union @@ -79,7 +79,7 @@ class CLSExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog='paddlespeech.cls', add_help=True) self.parser.add_argument( - '--input', type=str, required=True, help='Audio file to classify.') + '--input', type=str, default=None, help='Audio file to classify.') self.parser.add_argument( '--model', type=str, @@ -221,7 +221,7 @@ class CLSExecutor(BaseExecutor): ret = '' for idx in topk_idx: label, score = self._label_list[idx], result[idx] - ret += f'{label} {score}\n' + ret += f'{label} {score} ' return ret def postprocess(self, topk: int) -> Union[str, os.PathLike]: @@ -241,36 +241,34 @@ class CLSExecutor(BaseExecutor): label_file = parser_args.label_file cfg_path = parser_args.config ckpt_path = parser_args.ckpt_path - input_file = parser_args.input topk = parser_args.topk device = parser_args.device job_dump_result = parser_args.job_dump_result - try: - if job_dump_result: - assert self._is_job_input( - input_file - ), 'Input file should be a job file(*.job) when `job_dump_result` is True.' - job_output_file = os.path.abspath(input_file) + '.done' - sys.stdout = open(job_output_file, 'w') + task_source = self.get_task_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False - print( - self(input_file, model_type, cfg_path, ckpt_path, label_file, - topk, device)) + for id_, input_ in task_source.items(): + try: + res = self(input_, model_type, cfg_path, ckpt_path, label_file, + topk, device) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' - if job_dump_result: - logger.info(f'Results had been saved to: {job_output_file}') + self.process_task_results(parser_args.input, task_results, + job_dump_result) - return True - except Exception as e: - logger.exception(e) + if has_exceptions: return False - finally: - sys.stdout.close() + else: + return True @stats_wrapper def __call__(self, - input_file: os.PathLike, + audio_file: os.PathLike, model: str='panns_cnn14', config: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None, @@ -280,24 +278,11 @@ class CLSExecutor(BaseExecutor): """ Python API to call an executor. """ - input_file = os.path.abspath(input_file) + audio_file = os.path.abspath(os.path.expanduser(audio_file)) paddle.set_device(device) self._init_from_path(model, config, ckpt_path, label_file) - - if self._is_job_input(input_file): # *.job - job_outputs = {} - job_contents = self._job_preprocess(input_file) - for id_, file in job_contents.items(): - try: - self.preprocess(file) - self.infer() - job_outputs[id_] = self.postprocess(topk).strip() - except Exception as e: - job_outputs[id_] = f'{e.__class__.__name__}: {e}' - res = self._job_postprecess(job_outputs) - else: - self.preprocess(input_file) - self.infer() - res = self.postprocess(topk) # Retrieve result of cls. + self.preprocess(audio_file) + self.infer() + res = self.postprocess(topk) # Retrieve result of cls. return res diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index dded6758..d81f8f9f 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys from abc import ABC from abc import abstractmethod +from collections import OrderedDict from typing import Any from typing import Dict from typing import List @@ -21,6 +23,8 @@ from typing import Union import paddle +from .log import logger + class BaseExecutor(ABC): """ @@ -28,8 +32,8 @@ class BaseExecutor(ABC): """ def __init__(self): - self._inputs = dict() - self._outputs = dict() + self._inputs = OrderedDict() + self._outputs = OrderedDict() @abstractmethod def _get_pretrained_path(self, tag: str) -> os.PathLike: @@ -102,6 +106,61 @@ class BaseExecutor(ABC): """ pass + def get_task_source(self, input_: Union[str, os.PathLike, None] + ) -> Dict[str, Union[str, os.PathLike]]: + """ + Get task input source from command line input. + + Args: + input_ (Union[str, os.PathLike, None]): Input from command line. + + Returns: + Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs. + """ + if self._is_job_input(input_): + ret = self._get_job_contents(input_) + else: + ret = OrderedDict() + + if input_ is None: # Take input from stdin + for i, line in enumerate(sys.stdin): + line = line.strip() + if len(line.split(' ')) == 1: + ret[str(i + 1)] = line + elif len(line.split(' ')) == 2: + id_, info = line.split(' ') + ret[id_] = info + else: # No valid input info from one line. + continue + else: + ret[1] = input_ + return ret + + def process_task_results(self, + input_: Union[str, os.PathLike, None], + results: Dict[str, os.PathLike], + job_dump_result: bool=False): + """ + Handling task results and redirect stdout if needed. + + Args: + input_ (Union[str, os.PathLike, None]): Input from command line. + results (Dict[str, os.PathLike]): Task outputs. + job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False. + """ + + raw_text = self._format_task_results(results) + print(raw_text, end='') + + if self._is_job_input(input_) and job_dump_result: + try: + job_output_file = os.path.abspath(input_) + '.done' + sys.stdout = open(job_output_file, 'w') + print(raw_text, end='') + logger.info(f'Results had been saved to: {job_output_file}') + finally: + sys.stdout.close() + def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool: """ Check if current input file is a job input or not. @@ -112,9 +171,10 @@ class BaseExecutor(ABC): Returns: bool: return `True` for job input, `False` otherwise. """ - return os.path.isfile(input_) and input_.endswith('.job') + return input_ and os.path.isfile(input_) and input_.endswith('.job') - def _job_preprocess(self, job_input: os.PathLike) -> Dict[str, str]: + def _get_job_contents( + self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: """ Read a job input file and return its contents in a dictionary. @@ -124,7 +184,7 @@ class BaseExecutor(ABC): Returns: Dict[str, str]: Contents of job input. """ - job_contents = {} + job_contents = OrderedDict() with open(job_input) as f: for line in f: line = line.strip() @@ -134,17 +194,18 @@ class BaseExecutor(ABC): job_contents[k] = v return job_contents - def _job_postprecess(self, job_outputs: Dict[str, str]) -> str: + def _format_task_results( + self, results: Dict[str, Union[str, os.PathLike]]) -> str: """ - Convert job results to string. + Convert task results to raw text. Args: - job_outputs (Dict[str, str]): A dictionary with job ids and results. + results (Dict[str, str]): A dictionary of task results. Returns: - str: A string object contains job outputs. + str: A string object contains task results. """ ret = '' - for k, v in job_outputs.items(): + for k, v in results.items(): ret += f'{k} {v}\n' return ret