Update infer.py

Change the inference code to support the new, faster model for the text (punctuation) task.
pull/2421/head
Zhao Yuting 2 years ago committed by GitHub
parent b627666ce9
commit 57dcd0d17f

@@ -20,10 +20,13 @@ from typing import Optional
 from typing import Union
 
 import paddle
+import yaml
+from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
+from paddlespeech.text.models.ernie_linear import ErnieLinear
 
 __all__ = ['TextExecutor']
@@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
             self.model.eval()
 
+    #init new models
+    def _init_from_path_new(self,
+                            task: str='punc',
+                            model_type: str='ernie_linear_p7_wudao',
+                            lang: str='zh',
+                            cfg_path: Optional[os.PathLike]=None,
+                            ckpt_path: Optional[os.PathLike]=None,
+                            vocab_file: Optional[os.PathLike]=None):
+        if hasattr(self, 'model'):
+            logger.debug('Model had been initialized.')
+            return
+
+        self.task = task
+
+        if cfg_path is None or ckpt_path is None or vocab_file is None:
+            tag = '-'.join([model_type, task, lang])
+            self.task_resource.set_task_model(tag, version=None)
+            self.cfg_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['cfg_path'])
+            self.ckpt_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['ckpt_path'])
+            self.vocab_file = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['vocab_file'])
+        else:
+            self.cfg_path = os.path.abspath(cfg_path)
+            self.ckpt_path = os.path.abspath(ckpt_path)
+            self.vocab_file = os.path.abspath(vocab_file)
+
+        model_name = model_type[:model_type.rindex('_')]
+
+        if self.task == 'punc':
+            # punc list
+            self._punc_list = []
+            with open(self.vocab_file, 'r') as f:
+                for line in f:
+                    self._punc_list.append(line.strip())
+
+            # model
+            with open(self.cfg_path) as f:
+                config = CfgNode(yaml.safe_load(f))
+            self.model = ErnieLinear(**config["model"])
+
+            _, tokenizer_class = self.task_resource.get_model_class(model_name)
+            state_dict = paddle.load(self.ckpt_path)
+            self.model.set_state_dict(state_dict["main_params"])
+            self.model.eval()
+
+            #tokenizer: fast version: ernie-3.0-mini-zh  slow version: ernie-1.0
+            if 'fast' not in model_type:
+                self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
+            else:
+                self.tokenizer = tokenizer_class.from_pretrained(
+                    'ernie-3.0-mini-zh')
+        else:
+            raise NotImplementedError
+
     def _clean_text(self, text):
         text = text.lower()
         text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
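For reference, `model_type[:model_type.rindex('_')]` drops the trailing corpus tag so the resource lookup key matches the registered model class, and the tokenizer is then chosen by whether the model name contains `fast`. A minimal sketch of that selection, mirroring the lines above (the `*_fast` name is purely illustrative):

    model_type = 'ernie_linear_p7_wudao'
    model_name = model_type[:model_type.rindex('_')]   # -> 'ernie_linear_p7'

    # Slow models keep the original 'ernie-1.0' tokenizer; names containing
    # 'fast' switch to the smaller 'ernie-3.0-mini-zh' tokenizer.
    def pick_tokenizer_name(model_type: str) -> str:
        return 'ernie-3.0-mini-zh' if 'fast' in model_type else 'ernie-1.0'

    assert pick_tokenizer_name('ernie_linear_p7_wudao') == 'ernie-1.0'
    # hypothetical name of a fast model, used only to show the branch
    assert pick_tokenizer_name('ernie_linear_p3_wudao_fast') == 'ernie-3.0-mini-zh'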
@@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
         else:
             raise NotImplementedError
 
-    def postprocess(self) -> Union[str, os.PathLike]:
+    def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
         """
             Output postprocess and return human-readable results such as texts and audio files.
         """
@@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor):
                 input_ids[1:seq_len - 1])
             labels = preds[1:seq_len - 1].tolist()
             assert len(tokens) == len(labels)
 
+            if isNewTrainer:
+                self._punc_list = [0] + self._punc_list
             text = ''
             for t, l in zip(tokens, labels):
                 text += t
                 if l != 0:  # Non punc.
                     text += self._punc_list[l]
 
             return text
         else:
             raise NotImplementedError
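In the new checkpoints, label 0 is reserved for "no punctuation", so `postprocess(isNewTrainer=True)` prepends a placeholder to `self._punc_list` before labels are looked up. A minimal, self-contained sketch of that lookup (the punctuation vocabulary here is illustrative):

    # Illustrative vocabulary, as it would be read from the vocab file.
    punc_list = ['，', '。', '？']
    # New-trainer layout: shift by one so index 0 stays a "no punc" placeholder.
    punc_list = [0] + punc_list

    tokens = ['你', '好', '吗']
    labels = [0, 0, 3]            # label 3 -> punc_list[3] == '？'

    text = ''
    for t, l in zip(tokens, labels):
        text += t
        if l != 0:                # non-zero label: append the predicted punctuation
            text += punc_list[l]

    print(text)                   # 你好吗？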
@@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor):
         """
             Python API to call an executor.
         """
-        paddle.set_device(device)
-        self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
-        self.preprocess(text)
-        self.infer()
-        res = self.postprocess()  # Retrieve result of text task.
+        #Here is old version models
+        if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
+            paddle.set_device(device)
+            self._init_from_path(task, model, lang, config, ckpt_path,
+                                 punc_vocab)
+            self.preprocess(text)
+            self.infer()
+            res = self.postprocess()  # Retrieve result of text task.
+        #Add new way to infer
+        else:
+            paddle.set_device(device)
+            self._init_from_path_new(task, model, lang, config, ckpt_path,
+                                     punc_vocab)
+            self.preprocess(text)
+            self.infer()
+            res = self.postprocess(isNewTrainer=True)
 
         return res
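Taken together, a minimal usage sketch of the dispatch above (the `ernie_linear_p3_wudao_fast` model name is an assumption based on the `'fast'` check; the two explicitly listed names still go through the original path):

    from paddlespeech.cli.text.infer import TextExecutor

    text_executor = TextExecutor()

    # Old models: handled by _init_from_path() and postprocess().
    result_old = text_executor(
        text='今天的天气真不错啊你下午有空吗我想约你一起去吃饭',
        task='punc',
        model='ernie_linear_p3_wudao',
        lang='zh',
        device='cpu')

    # Any other registered model name (e.g. a *_fast variant) is handled by
    # _init_from_path_new() and postprocess(isNewTrainer=True).
    result_fast = text_executor(
        text='今天的天气真不错啊你下午有空吗我想约你一起去吃饭',
        task='punc',
        model='ernie_linear_p3_wudao_fast',  # assumed name of the new fast model
        lang='zh',
        device='cpu')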
