diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 2261e011..2314bd6d 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -14,6 +14,7 @@ import os from abc import ABC from abc import abstractmethod +from typing import List from typing import Union import paddle @@ -64,3 +65,17 @@ class BaseExecutor(ABC): Output postprocess and return human-readable results such as texts and audio files. """ pass + + @abstractmethod + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + pass + + @abstractmethod + def __call__(self, *arg, **kwargs): + """ + Python API to call an executor. + """ + pass diff --git a/paddlespeech/cli/s2t/infer.py b/paddlespeech/cli/s2t/infer.py index 912d1df0..b3507cb6 100644 --- a/paddlespeech/cli/s2t/infer.py +++ b/paddlespeech/cli/s2t/infer.py @@ -18,13 +18,14 @@ from typing import List from typing import Optional from typing import Union -import soundfile import paddle -from paddlespeech.cli.executor import BaseExecutor -from paddlespeech.cli.utils import cli_register -from paddlespeech.cli.utils import download_and_decompress -from paddlespeech.cli.utils import logger -from paddlespeech.cli.utils import MODEL_HOME +import soundfile + +from ..executor import BaseExecutor +from ..utils import cli_register +from ..utils import download_and_decompress +from ..utils import logger +from ..utils import MODEL_HOME from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.collator import SpeechCollator @@ -55,29 +56,6 @@ model_alias = { "wenetspeech": "paddlespeech.s2t.models.u2:U2Model", } -pretrain_model_alias = { - "ds2_online_zn": [ - "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz", - "", "" - ], - "ds2_offline_zn": [ - "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz", - "", "" - ], - "transformer_zn": [ - "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz", - "", "" - ], - "conformer_zn": [ - "https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz", - "", "" - ], - "wenetspeech_zn": [ - "https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz", - "conf/conformer.yaml", "exp/conformer/checkpoints/wenetspeech" - ], -} - @cli_register( name='paddlespeech.s2t', description='Speech to text infer command.') @@ -107,7 +85,6 @@ class S2TExecutor(BaseExecutor): self.parser.add_argument( '--input', type=str, - default="../Downloads/asr-demo-1.wav", help='Audio file to recognize.') self.parser.add_argument( '--device', @@ -155,7 +132,9 @@ class S2TExecutor(BaseExecutor): res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) + # Enter the path of model root os.chdir(res_path) + #Init body. parser_args = self.parser_args paddle.set_device(parser_args.device) @@ -206,7 +185,7 @@ class S2TExecutor(BaseExecutor): config = self.config audio_file = input #print("audio_file", audio_file) - logger.info("audio_file"+ audio_file) + logger.info("audio_file" + audio_file) self.sr = config.collator.target_sample_rate @@ -307,7 +286,11 @@ class S2TExecutor(BaseExecutor): return self.result_transcripts def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ self.parser_args = self.parser.parse_args(argv) + print(self.parser_args) model = self.parser_args.model lang = self.parser_args.lang @@ -317,17 +300,20 @@ class S2TExecutor(BaseExecutor): device = self.parser_args.device try: - self._init_from_path(model, lang, config, ckpt_path) - self.preprocess(audio_file) - self.infer() - res = self.postprocess() # Retrieve result of s2t. - logger.info(res) + res = self(model, lang, config, ckpt_path, audio_file, device) + print(res) return True except Exception as e: print(e) return False + def __call__(self, model, lang, config, ckpt_path, audio_file, device): + """ + Python API to call an executor. + """ + self._init_from_path(model, lang, config, ckpt_path) + self.preprocess(audio_file) + self.infer() + res = self.postprocess() # Retrieve result of s2t. -if __name__ == "__main__": - exe = S2TExecutor() - exe.execute('') + return res