Reviewer note: the line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy; verify whitespace before applying.
Amended relative to that copy: postprocess() now builds a local punctuation
lookup instead of reassigning self._punc_list, so repeated calls on the same
executor no longer prepend one more placeholder per call. The headers of the
last two hunks were updated to match, and the blob hash on the `index` line no
longer corresponds to the amended post-image.

diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py
index 24b8c9c25..ff822f674 100644
--- a/paddlespeech/cli/text/infer.py
+++ b/paddlespeech/cli/text/infer.py
@@ -20,10 +20,13 @@ from typing import Optional
 from typing import Union
 
 import paddle
+import yaml
+from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
+from paddlespeech.text.models.ernie_linear import ErnieLinear
 
 __all__ = ['TextExecutor']
 
@@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
 
         self.model.eval()
 
+    #init new models
+    def _init_from_path_new(self,
+                            task: str='punc',
+                            model_type: str='ernie_linear_p7_wudao',
+                            lang: str='zh',
+                            cfg_path: Optional[os.PathLike]=None,
+                            ckpt_path: Optional[os.PathLike]=None,
+                            vocab_file: Optional[os.PathLike]=None):
+        if hasattr(self, 'model'):
+            logger.debug('Model had been initialized.')
+            return
+
+        self.task = task
+
+        if cfg_path is None or ckpt_path is None or vocab_file is None:
+            tag = '-'.join([model_type, task, lang])
+            self.task_resource.set_task_model(tag, version=None)
+            self.cfg_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['cfg_path'])
+            self.ckpt_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['ckpt_path'])
+            self.vocab_file = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['vocab_file'])
+        else:
+            self.cfg_path = os.path.abspath(cfg_path)
+            self.ckpt_path = os.path.abspath(ckpt_path)
+            self.vocab_file = os.path.abspath(vocab_file)
+
+        model_name = model_type[:model_type.rindex('_')]
+
+        if self.task == 'punc':
+            # punc list
+            self._punc_list = []
+            with open(self.vocab_file, 'r') as f:
+                for line in f:
+                    self._punc_list.append(line.strip())
+
+            # model
+            with open(self.cfg_path) as f:
+                config = CfgNode(yaml.safe_load(f))
+            self.model = ErnieLinear(**config["model"])
+
+            _, tokenizer_class = self.task_resource.get_model_class(model_name)
+            state_dict = paddle.load(self.ckpt_path)
+            self.model.set_state_dict(state_dict["main_params"])
+            self.model.eval()
+
+            #tokenizer: fast version: ernie-3.0-mini-zh slow version:ernie-1.0
+            if 'fast' not in model_type:
+                self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
+            else:
+                self.tokenizer = tokenizer_class.from_pretrained(
+                    'ernie-3.0-mini-zh')
+
+        else:
+            raise NotImplementedError
+
     def _clean_text(self, text):
         text = text.lower()
         text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
@@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
         else:
             raise NotImplementedError
 
-    def postprocess(self) -> Union[str, os.PathLike]:
+    def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
         """
             Output postprocess and return human-readable results such as texts and audio files.
         """
@@ -192,13 +255,15 @@ class TextExecutor(BaseExecutor):
                 input_ids[1:seq_len - 1])
             labels = preds[1:seq_len - 1].tolist()
             assert len(tokens) == len(labels)
-
+            # Prepend a placeholder for label 0 in a local lookup only, so
+            # repeated calls never grow self._punc_list.
+            punc_list = ([0] + self._punc_list
+                         if isNewTrainer else self._punc_list)
             text = ''
             for t, l in zip(tokens, labels):
                 text += t
                 if l != 0:  # Non punc.
-                    text += self._punc_list[l]
-
+                    text += punc_list[l]
             return text
         else:
             raise NotImplementedError
@@ -255,10 +320,20 @@ class TextExecutor(BaseExecutor):
         """
             Python API to call an executor.
         """
-        paddle.set_device(device)
-        self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
-        self.preprocess(text)
-        self.infer()
-        res = self.postprocess()  # Retrieve result of text task.
-
+        #Here is old version models
+        if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
+            paddle.set_device(device)
+            self._init_from_path(task, model, lang, config, ckpt_path,
+                                 punc_vocab)
+            self.preprocess(text)
+            self.infer()
+            res = self.postprocess()  # Retrieve result of text task.
+        #Add new way to infer
+        else:
+            paddle.set_device(device)
+            self._init_from_path_new(task, model, lang, config, ckpt_path,
+                                     punc_vocab)
+            self.preprocess(text)
+            self.infer()
+            res = self.postprocess(isNewTrainer=True)
         return res