Update infer.py

Change infer.py so that the text executor can load and run the new, faster text model.
pull/2421/head
Zhao Yuting 2 years ago committed by GitHub
parent b627666ce9
commit 57dcd0d17f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,10 +20,13 @@ from typing import Optional
from typing import Union
import paddle
import yaml
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from paddlespeech.text.models.ernie_linear import ErnieLinear
__all__ = ['TextExecutor']
@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
self.model.eval()
# Initialize the new-style models.
def _init_from_path_new(self,
                        task: str='punc',
                        model_type: str='ernie_linear_p7_wudao',
                        lang: str='zh',
                        cfg_path: Optional[os.PathLike]=None,
                        ckpt_path: Optional[os.PathLike]=None,
                        vocab_file: Optional[os.PathLike]=None):
    """
    Load the model, tokenizer and punctuation vocabulary for the new models.

    The call is a no-op when ``self.model`` already exists.

    Args:
        task: Task name. Only ``'punc'`` is implemented.
        model_type: Model identifier. The part before its last underscore
            is used to look up the tokenizer class, and the presence of
            the substring ``'fast'`` selects the pretrained tokenizer.
        lang: Language tag; used only to build the resource tag.
        cfg_path: Path to the YAML config file.
        ckpt_path: Path to the checkpoint file.
        vocab_file: Path to the punctuation vocabulary, one entry per line.

    If ANY of ``cfg_path``, ``ckpt_path`` or ``vocab_file`` is ``None``,
    all three are resolved through ``self.task_resource`` with the tag
    ``"<model_type>-<task>-<lang>"`` and the supplied values are ignored.

    Raises:
        NotImplementedError: If ``task`` is not ``'punc'``.
        ValueError: If ``model_type`` contains no underscore.
    """
    if hasattr(self, 'model'):
        logger.debug('Model had been initialized.')
        return

    self.task = task

    if cfg_path is None or ckpt_path is None or vocab_file is None:
        # At least one path is missing: fall back to the registered
        # resources for all three files.
        tag = '-'.join([model_type, task, lang])
        self.task_resource.set_task_model(tag, version=None)
        self.cfg_path = os.path.join(
            self.task_resource.res_dir,
            self.task_resource.res_dict['cfg_path'])
        self.ckpt_path = os.path.join(
            self.task_resource.res_dir,
            self.task_resource.res_dict['ckpt_path'])
        self.vocab_file = os.path.join(
            self.task_resource.res_dir,
            self.task_resource.res_dict['vocab_file'])
    else:
        self.cfg_path = os.path.abspath(cfg_path)
        self.ckpt_path = os.path.abspath(ckpt_path)
        self.vocab_file = os.path.abspath(vocab_file)

    # Strip the trailing "_<suffix>" from the model type; str.rindex raises
    # ValueError when there is no underscore at all.
    model_name = model_type[:model_type.rindex('_')]

    if self.task == 'punc':
        # punc list
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding. Presumably the vocabulary contains
        # non-ASCII punctuation (default lang is 'zh') -- TODO confirm.
        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            self._punc_list = [line.strip() for line in f]

        # model
        with open(self.cfg_path, encoding='utf-8') as f:
            config = CfgNode(yaml.safe_load(f))
        self.model = ErnieLinear(**config["model"])

        _, tokenizer_class = self.task_resource.get_model_class(model_name)
        state_dict = paddle.load(self.ckpt_path)
        self.model.set_state_dict(state_dict["main_params"])
        self.model.eval()

        # tokenizer: fast version: ernie-3.0-mini-zh, slow version: ernie-1.0
        pretrained_name = ('ernie-3.0-mini-zh'
                           if 'fast' in model_type else 'ernie-1.0')
        self.tokenizer = tokenizer_class.from_pretrained(pretrained_name)
    else:
        raise NotImplementedError
def _clean_text(self, text):
text = text.lower()
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
else:
raise NotImplementedError
def postprocess(self) -> Union[str, os.PathLike]:
def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor):
input_ids[1:seq_len - 1])
labels = preds[1:seq_len - 1].tolist()
assert len(tokens) == len(labels)
if isNewTrainer:
self._punc_list = [0] + self._punc_list
text = ''
for t, l in zip(tokens, labels):
text += t
if l != 0: # Non punc.
text += self._punc_list[l]
return text
else:
raise NotImplementedError
@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
#Here is old version models
if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
paddle.set_device(device)
self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
self._init_from_path(task, model, lang, config, ckpt_path,
punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess() # Retrieve result of text task.
#Add new way to infer
else:
paddle.set_device(device)
self._init_from_path_new(task, model, lang, config, ckpt_path,
punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess(isNewTrainer=True)
return res

Loading…
Cancel
Save