diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index c9ec058cd..6ae038539 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -119,8 +119,8 @@ class ASRExecutor(BaseExecutor): lang: str='zh', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None, - device: str='cpu'): + ckpt_path: Optional[os.PathLike]=None + ): """ Init model and other resources from a specific path. """ @@ -142,7 +142,6 @@ class ASRExecutor(BaseExecutor): os.path.dirname(os.path.abspath(self.cfg_path))) #Init body. - paddle.set_device(device) self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" @@ -397,8 +396,8 @@ class ASRExecutor(BaseExecutor): """ audio_file = os.path.abspath(audio_file) self._check(audio_file, sample_rate) - self._init_from_path(model, lang, sample_rate, config, ckpt_path, - device) + paddle.set_device(device) + self._init_from_path(model, lang, sample_rate, config, ckpt_path) self.preprocess(model, audio_file) self.infer(model) res = self.postprocess() # Retrieve result of asr. diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index c132b3b87..00371371d 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -14,6 +14,7 @@ import os from abc import ABC from abc import abstractmethod +from typing import Any from typing import List from typing import Union @@ -32,50 +33,70 @@ class BaseExecutor(ABC): @abstractmethod def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Download and returns pretrained resources path of current task. + Download and returns pretrained resources path of current task. + + Args: + tag (str): A tag of pretrained model. + + Returns: + os.PathLike: The path on which resources of pretrained model locate. """ pass @abstractmethod def _init_from_path(self, *args, **kwargs): """ - Init model and other resources from a specific path. + Init model and other resources from arguments. This method should be called by `__call__()`. """ pass @abstractmethod - def preprocess(self, input: Union[str, os.PathLike]): + def preprocess(self, input: Any, *args, **kwargs): """ - Input preprocess and return paddle.Tensor stored in self.input. - Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + Input preprocess and return paddle.Tensor stored in self._inputs. + Input content can be a text(tts), a file(asr, cls), a stream(not supported yet) or anything needed. + + Args: + input (Any): Input text/file/stream or other content. """ pass @paddle.no_grad() @abstractmethod - def infer(self, device: str): + def infer(self, *args, **kwargs): """ - Model inference and result stored in self.output. + Model inference and put results into self._outputs. + This method get input tensors from self._inputs, and write output tensors into self._outputs. """ pass @abstractmethod - def postprocess(self) -> Union[str, os.PathLike]: + def postprocess(self, *args, **kwargs) -> Union[str, os.PathLike]: """ - Output postprocess and return human-readable results such as texts and audio files. + Output postprocess and return results. + This method get model output from self._outputs and convert it into human-readable results. + + Returns: + Union[str, os.PathLike]: Human-readable results such as texts and audio files. """ pass @abstractmethod def execute(self, argv: List[str]) -> bool: """ - Command line entry. + Command line entry. This method can only be accessed by a command line such as `paddlespeech asr`. + + Args: + argv (List[str]): Arguments from command line. + + Returns: + int: Result of the command execution. `True` for a success and `False` for a failure. """ pass @abstractmethod def __call__(self, *arg, **kwargs): """ - Python API to call an executor. + Python API to call an executor. """ pass