|
|
|
@ -20,10 +20,13 @@ from typing import Optional
|
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
import yaml
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
|
from ..log import logger
|
|
|
|
|
from ..utils import stats_wrapper
|
|
|
|
|
from paddlespeech.text.models.ernie_linear import ErnieLinear
|
|
|
|
|
|
|
|
|
|
__all__ = ['TextExecutor']
|
|
|
|
|
|
|
|
|
@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
# Initialization for the newer models, i.e. those not loaded through _init_from_path.
|
|
|
|
|
def _init_from_path_new(self,
                        task: str='punc',
                        model_type: str='ernie_linear_p7_wudao',
                        lang: str='zh',
                        cfg_path: Optional[os.PathLike]=None,
                        ckpt_path: Optional[os.PathLike]=None,
                        vocab_file: Optional[os.PathLike]=None):
    """
    Load the model, tokenizer and punctuation list for the given task.

    Returns immediately (after a debug log) when ``self.model`` already
    exists. Otherwise sets ``self.task``, ``self.cfg_path``,
    ``self.ckpt_path``, ``self.vocab_file``, ``self._punc_list``,
    ``self.model`` and ``self.tokenizer``.

    Args:
        task: Task name; only ``'punc'`` is implemented.
        model_type: Model identifier. The part before its last underscore
            is used to look up the tokenizer class, and the presence of
            the substring ``'fast'`` selects the pretrained tokenizer.
        lang: Language code, used only to build the resource tag.
        cfg_path: Path to the YAML model config.
        ckpt_path: Path to the checkpoint file.
        vocab_file: Path to the punctuation vocabulary (one entry per line).

    Raises:
        ValueError: If ``model_type`` contains no underscore
            (raised by ``str.rindex``).
        NotImplementedError: If ``task`` is not ``'punc'``.
    """
    if hasattr(self, 'model'):
        logger.debug('Model had been initialized.')
        return

    self.task = task

    # All three paths must be supplied for them to be honoured; if any one
    # is missing, every path is taken from the task resource instead.
    if cfg_path is None or ckpt_path is None or vocab_file is None:
        tag = '-'.join([model_type, task, lang])
        self.task_resource.set_task_model(tag, version=None)
        self.cfg_path = os.path.join(
            self.task_resource.res_dir,
            self.task_resource.res_dict['cfg_path'])
        self.ckpt_path = os.path.join(
            self.task_resource.res_dir,
            self.task_resource.res_dict['ckpt_path'])
        self.vocab_file = os.path.join(
            self.task_resource.res_dir,
            self.task_resource.res_dict['vocab_file'])
    else:
        self.cfg_path = os.path.abspath(cfg_path)
        self.ckpt_path = os.path.abspath(ckpt_path)
        self.vocab_file = os.path.abspath(vocab_file)

    # Strip the trailing "_<suffix>" from the model type.
    model_name = model_type[:model_type.rindex('_')]

    if self.task == 'punc':
        # punc list
        # Read with an explicit encoding so that decoding does not depend
        # on the platform's locale default.
        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            self._punc_list = [line.strip() for line in f]

        # model
        with open(self.cfg_path, encoding='utf-8') as f:
            config = CfgNode(yaml.safe_load(f))
        self.model = ErnieLinear(**config["model"])

        _, tokenizer_class = self.task_resource.get_model_class(model_name)
        # The weights are read from the "main_params" entry of the checkpoint.
        state_dict = paddle.load(self.ckpt_path)
        self.model.set_state_dict(state_dict["main_params"])
        self.model.eval()

        # tokenizer: fast version: ernie-3.0-mini-zh, slow version: ernie-1.0
        if 'fast' not in model_type:
            self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
        else:
            self.tokenizer = tokenizer_class.from_pretrained(
                'ernie-3.0-mini-zh')

    else:
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def _clean_text(self, text):
|
|
|
|
|
text = text.lower()
|
|
|
|
|
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
|
|
|
|
@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def postprocess(self) -> Union[str, os.PathLike]:
|
|
|
|
|
def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
|
|
|
|
|
"""
|
|
|
|
|
Output postprocess and return human-readable results such as texts and audio files.
|
|
|
|
|
"""
|
|
|
|
@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor):
|
|
|
|
|
input_ids[1:seq_len - 1])
|
|
|
|
|
labels = preds[1:seq_len - 1].tolist()
|
|
|
|
|
assert len(tokens) == len(labels)
|
|
|
|
|
|
|
|
|
|
if isNewTrainer:
|
|
|
|
|
self._punc_list = [0] + self._punc_list
|
|
|
|
|
text = ''
|
|
|
|
|
for t, l in zip(tokens, labels):
|
|
|
|
|
text += t
|
|
|
|
|
if l != 0: # Non punc.
|
|
|
|
|
text += self._punc_list[l]
|
|
|
|
|
|
|
|
|
|
return text
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
Python API to call an executor.
|
|
|
|
|
"""
|
|
|
|
|
paddle.set_device(device)
|
|
|
|
|
self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
|
|
|
|
|
self.preprocess(text)
|
|
|
|
|
self.infer()
|
|
|
|
|
res = self.postprocess() # Retrieve result of text task.
|
|
|
|
|
|
|
|
|
|
#Here is old version models
|
|
|
|
|
if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
|
|
|
|
|
paddle.set_device(device)
|
|
|
|
|
self._init_from_path(task, model, lang, config, ckpt_path,
|
|
|
|
|
punc_vocab)
|
|
|
|
|
self.preprocess(text)
|
|
|
|
|
self.infer()
|
|
|
|
|
res = self.postprocess() # Retrieve result of text task.
|
|
|
|
|
#Add new way to infer
|
|
|
|
|
else:
|
|
|
|
|
paddle.set_device(device)
|
|
|
|
|
self._init_from_path_new(task, model, lang, config, ckpt_path,
|
|
|
|
|
punc_vocab)
|
|
|
|
|
self.preprocess(text)
|
|
|
|
|
self.infer()
|
|
|
|
|
res = self.postprocess(isNewTrainer=True)
|
|
|
|
|
return res
|
|
|
|
|