yes
pull/2548/head
WongLaw 2 years ago
commit cc05147163

@ -25,8 +25,6 @@ DefinedClassifier = {
'ErnieLinear': ErnieLinear, 'ErnieLinear': ErnieLinear,
} }
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
def _clean_text(text, punc_list): def _clean_text(text, punc_list):
text = text.lower() text = text.lower()
@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
return text return text
def preprocess(text, punc_list): def preprocess(text, punc_list, tokenizer):
clean_text = _clean_text(text, punc_list) clean_text = _clean_text(text, punc_list)
assert len(clean_text) > 0, f'Invalid input string: {text}' assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = tokenizer( tokenized_input = tokenizer(
@ -51,7 +49,8 @@ def test(args):
with open(args.config) as f: with open(args.config) as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
print("========Args========") print("========Args========")
print(yaml.safe_dump(vars(args))) print(yaml.safe_dump(vars(args), allow_unicode=True))
# print(args)
print("========Config========") print("========Config========")
print(config) print(config)
@ -61,10 +60,16 @@ def test(args):
punc_list.append(line.strip()) punc_list.append(line.strip())
model = DefinedClassifier[config["model_type"]](**config["model"]) model = DefinedClassifier[config["model_type"]](**config["model"])
# print(model)
pretrained_token = config['data_params']['pretrained_token']
tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
# tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
state_dict = paddle.load(args.checkpoint) state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"]) model.set_state_dict(state_dict["main_params"])
model.eval() model.eval()
_inputs = preprocess(args.text, punc_list) _inputs = preprocess(args.text, punc_list, tokenizer)
seq_len = _inputs['seq_len'] seq_len = _inputs['seq_len']
input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0) input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0) seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)

Loading…
Cancel
Save