# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys from abc import ABC from abc import abstractmethod from collections import OrderedDict from typing import Any from typing import Dict from typing import List from typing import Union import paddle from .log import logger class BaseExecutor(ABC): """ An abstract executor of paddlespeech tasks. """ def __init__(self): self._inputs = OrderedDict() self._outputs = OrderedDict() @abstractmethod def _get_pretrained_path(self, tag: str) -> os.PathLike: """ Download and returns pretrained resources path of current task. Args: tag (str): A tag of pretrained model. Returns: os.PathLike: The path on which resources of pretrained model locate. """ pass @abstractmethod def _init_from_path(self, *args, **kwargs): """ Init model and other resources from arguments. This method should be called by `__call__()`. """ pass @abstractmethod def preprocess(self, input: Any, *args, **kwargs): """ Input preprocess and return paddle.Tensor stored in self._inputs. Input content can be a text(tts), a file(asr, cls), a stream(not supported yet) or anything needed. Args: input (Any): Input text/file/stream or other content. """ pass @paddle.no_grad() @abstractmethod def infer(self, *args, **kwargs): """ Model inference and put results into self._outputs. This method get input tensors from self._inputs, and write output tensors into self._outputs. """ pass @abstractmethod def postprocess(self, *args, **kwargs) -> Union[str, os.PathLike]: """ Output postprocess and return results. This method get model output from self._outputs and convert it into human-readable results. Returns: Union[str, os.PathLike]: Human-readable results such as texts and audio files. """ pass @abstractmethod def execute(self, argv: List[str]) -> bool: """ Command line entry. This method can only be accessed by a command line such as `paddlespeech asr`. Args: argv (List[str]): Arguments from command line. Returns: int: Result of the command execution. `True` for a success and `False` for a failure. """ pass @abstractmethod def __call__(self, *arg, **kwargs): """ Python API to call an executor. """ pass def get_task_source(self, input_: Union[str, os.PathLike, None] ) -> Dict[str, Union[str, os.PathLike]]: """ Get task input source from command line input. Args: input_ (Union[str, os.PathLike, None]): Input from command line. Returns: Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs. """ if self._is_job_input(input_): ret = self._get_job_contents(input_) else: ret = OrderedDict() if input_ is None: # Take input from stdin for i, line in enumerate(sys.stdin): line = line.strip() if len(line.split(' ')) == 1: ret[str(i + 1)] = line elif len(line.split(' ')) == 2: id_, info = line.split(' ') ret[id_] = info else: # No valid input info from one line. continue else: ret[1] = input_ return ret def process_task_results(self, input_: Union[str, os.PathLike, None], results: Dict[str, os.PathLike], job_dump_result: bool=False): """ Handling task results and redirect stdout if needed. Args: input_ (Union[str, os.PathLike, None]): Input from command line. results (Dict[str, os.PathLike]): Task outputs. job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False. """ raw_text = self._format_task_results(results) print(raw_text, end='') if self._is_job_input(input_) and job_dump_result: try: job_output_file = os.path.abspath(input_) + '.done' sys.stdout = open(job_output_file, 'w') print(raw_text, end='') logger.info(f'Results had been saved to: {job_output_file}') finally: sys.stdout.close() def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool: """ Check if current input file is a job input or not. Args: input_ (Union[str, os.PathLike]): Input file of current task. Returns: bool: return `True` for job input, `False` otherwise. """ return input_ and os.path.isfile(input_) and input_.endswith('.job') def _get_job_contents( self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: """ Read a job input file and return its contents in a dictionary. Args: job_input (os.PathLike): The job input file. Returns: Dict[str, str]: Contents of job input. """ job_contents = OrderedDict() with open(job_input) as f: for line in f: line = line.strip() if not line: continue k, v = line.split(' ') job_contents[k] = v return job_contents def _format_task_results( self, results: Dict[str, Union[str, os.PathLike]]) -> str: """ Convert task results to raw text. Args: results (Dict[str, str]): A dictionary of task results. Returns: str: A string object contains task results. """ ret = '' for k, v in results.items(): ret += f'{k} {v}\n' return ret