|
|
|
@ -25,8 +25,6 @@ DefinedClassifier = {
|
|
|
|
|
'ErnieLinear': ErnieLinear,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _clean_text(text, punc_list):
|
|
|
|
|
text = text.lower()
|
|
|
|
@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(text, punc_list):
|
|
|
|
|
def preprocess(text, punc_list, tokenizer):
|
|
|
|
|
clean_text = _clean_text(text, punc_list)
|
|
|
|
|
assert len(clean_text) > 0, f'Invalid input string: {text}'
|
|
|
|
|
tokenized_input = tokenizer(
|
|
|
|
@ -51,7 +49,8 @@ def test(args):
|
|
|
|
|
with open(args.config) as f:
|
|
|
|
|
config = CfgNode(yaml.safe_load(f))
|
|
|
|
|
print("========Args========")
|
|
|
|
|
print(yaml.safe_dump(vars(args)))
|
|
|
|
|
print(yaml.safe_dump(vars(args), allow_unicode=True))
|
|
|
|
|
# print(args)
|
|
|
|
|
print("========Config========")
|
|
|
|
|
print(config)
|
|
|
|
|
|
|
|
|
@ -61,10 +60,16 @@ def test(args):
|
|
|
|
|
punc_list.append(line.strip())
|
|
|
|
|
|
|
|
|
|
model = DefinedClassifier[config["model_type"]](**config["model"])
|
|
|
|
|
# print(model)
|
|
|
|
|
|
|
|
|
|
pretrained_token = config['data_params']['pretrained_token']
|
|
|
|
|
tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
|
|
|
|
|
# tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
|
|
|
|
|
|
|
|
|
|
state_dict = paddle.load(args.checkpoint)
|
|
|
|
|
model.set_state_dict(state_dict["main_params"])
|
|
|
|
|
model.eval()
|
|
|
|
|
_inputs = preprocess(args.text, punc_list)
|
|
|
|
|
_inputs = preprocess(args.text, punc_list, tokenizer)
|
|
|
|
|
seq_len = _inputs['seq_len']
|
|
|
|
|
input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
|
|
|
|
|
seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)
|
|
|
|
|