diff --git a/paddlespeech/cli/README.md b/paddlespeech/cli/README.md
index 19c822040..e6e216c0b 100644
--- a/paddlespeech/cli/README.md
+++ b/paddlespeech/cli/README.md
@@ -42,3 +42,7 @@
   ```bash
   paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
   ```
+- Faster Punctuation Restoration
+  ```bash
+  paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
+  ```
diff --git a/paddlespeech/cli/README_cn.md b/paddlespeech/cli/README_cn.md
index 4b15d6c7b..6464c598c 100644
--- a/paddlespeech/cli/README_cn.md
+++ b/paddlespeech/cli/README_cn.md
@@ -43,3 +43,7 @@
   ```bash
   paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
   ```
+- 快速标点恢复
+  ```bash
+  paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
+  ```
diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py
index 24b8c9c25..ff822f674 100644
--- a/paddlespeech/cli/text/infer.py
+++ b/paddlespeech/cli/text/infer.py
@@ -20,10 +20,13 @@ from typing import Optional
 from typing import Union
 
 import paddle
+import yaml
+from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
+from paddlespeech.text.models.ernie_linear import ErnieLinear
 
 __all__ = ['TextExecutor']
 
@@ -139,6 +142,69 @@ class TextExecutor(BaseExecutor):
 
         self.model.eval()
 
+    def _init_from_path_new(self,
+                            task: str='punc',
+                            model_type: str='ernie_linear_p7_wudao',
+                            lang: str='zh',
+                            cfg_path: Optional[os.PathLike]=None,
+                            ckpt_path: Optional[os.PathLike]=None,
+                            vocab_file: Optional[os.PathLike]=None):
+        """
+        Init model and other resources for every model that is not one of the
+        legacy `ernie_linear_p7_wudao` / `ernie_linear_p3_wudao` checkpoints.
+        """
+        if hasattr(self, 'model'):
+            logger.debug('Model had been initialized.')
+            return
+
+        self.task = task
+
+        if cfg_path is None or ckpt_path is None or vocab_file is None:
+            tag = '-'.join([model_type, task, lang])
+            self.task_resource.set_task_model(tag, version=None)
+            self.cfg_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['cfg_path'])
+            self.ckpt_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['ckpt_path'])
+            self.vocab_file = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['vocab_file'])
+        else:
+            self.cfg_path = os.path.abspath(cfg_path)
+            self.ckpt_path = os.path.abspath(ckpt_path)
+            self.vocab_file = os.path.abspath(vocab_file)
+
+        model_name = model_type[:model_type.rindex('_')]
+
+        if self.task == 'punc':
+            # punc list
+            self._punc_list = []
+            with open(self.vocab_file, 'r', encoding='utf-8') as f:
+                for line in f:
+                    self._punc_list.append(line.strip())
+
+            # model
+            with open(self.cfg_path, encoding='utf-8') as f:
+                config = CfgNode(yaml.safe_load(f))
+            self.model = ErnieLinear(**config["model"])
+
+            _, tokenizer_class = self.task_resource.get_model_class(model_name)
+            state_dict = paddle.load(self.ckpt_path)
+            self.model.set_state_dict(state_dict["main_params"])
+            self.model.eval()
+
+            #tokenizer: fast version: ernie-3.0-mini-zh slow version:ernie-1.0
+            if 'fast' not in model_type:
+                self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
+            else:
+                self.tokenizer = tokenizer_class.from_pretrained(
+                    'ernie-3.0-mini-zh')
+
+        else:
+            raise NotImplementedError
+
     def _clean_text(self, text):
         text = text.lower()
         text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
@@ -179,7 +245,7 @@ class TextExecutor(BaseExecutor):
         else:
             raise NotImplementedError
 
-    def postprocess(self) -> Union[str, os.PathLike]:
+    def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
         """
             Output postprocess and return human-readable results such as texts and audio files.
         """
@@ -192,13 +258,15 @@ class TextExecutor(BaseExecutor):
                 input_ids[1:seq_len - 1])
             labels = preds[1:seq_len - 1].tolist()
             assert len(tokens) == len(labels)
-
+            # Prepend the label-0 placeholder on a local copy; mutating
+            # self._punc_list here would shift the table again on every call.
+            punc_list = ([0] + self._punc_list
+                         if isNewTrainer else self._punc_list)
             text = ''
             for t, l in zip(tokens, labels):
                 text += t
                 if l != 0:  # Non punc.
-                    text += self._punc_list[l]
-
+                    text += punc_list[l]
             return text
         else:
             raise NotImplementedError
@@ -255,10 +323,18 @@ class TextExecutor(BaseExecutor):
         """
             Python API to call an executor.
         """
         paddle.set_device(device)
-        self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
+        # Legacy checkpoints keep the original loader; every other model is
+        # loaded by _init_from_path_new and needs the shifted punc table.
+        if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
+            init_from_path = self._init_from_path
+            is_new_trainer = False
+        else:
+            init_from_path = self._init_from_path_new
+            is_new_trainer = True
+        init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
         self.preprocess(text)
         self.infer()
-        res = self.postprocess()  # Retrieve result of text task.
-
+        # Retrieve result of text task.
+        res = self.postprocess(isNewTrainer=is_new_trainer)
         return res
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index 9c76dd4b3..85187a8d1 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -51,6 +51,10 @@ model_alias = {
         "paddlespeech.text.models:ErnieLinear",
         "paddlenlp.transformers:ErnieTokenizer"
     ],
+    "ernie_linear_p3_wudao": [
+        "paddlespeech.text.models:ErnieLinear",
+        "paddlenlp.transformers:ErnieTokenizer"
+    ],
 
     # ---------------------------------
     # -------------- TTS --------------
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index f049879a3..b6ab7f01c 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -529,7 +529,7 @@ text_dynamic_pretrained_models = {
             'ckpt/model_state.pdparams',
             'vocab_file':
             'punc_vocab.txt',
-        },
+        }
     },
     "ernie_linear_p3_wudao-punc-zh": {
         '1.0': {
@@ -543,8 +543,22 @@ text_dynamic_pretrained_models = {
             'ckpt/model_state.pdparams',
             'vocab_file':
             'punc_vocab.txt',
-        },
+        }
     },
+    "ernie_linear_p3_wudao_fast-punc-zh": {
+        '1.0': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao_fast-punc-zh.tar.gz',
+            'md5':
+            'c93f9594119541a5dbd763381a751d08',
+            'cfg_path':
+            'ckpt/model_config.json',
+            'ckpt_path':
+            'ckpt/model_state.pdparams',
+            'vocab_file':
+            'punc_vocab.txt',
+        }
+    }
 }
 
 # ---------------------------------
diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh
index 15604961d..c6837c303 100755
--- a/tests/unit/cli/test_cli.sh
+++ b/tests/unit/cli/test_cli.sh
@@ -7,7 +7,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
 paddlespeech cls --input ./cat.wav --topk 10
 
 # Punctuation_restoration
 paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
+paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
 
 # Speech_recognition
 wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav