Merge branch 'PaddlePaddle:develop' into hongliang1014

pull/2531/head
David An (An Hongliang) 2 years ago committed by GitHub
commit 21cce0e0bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -113,7 +113,7 @@ class ServerExecutor(BaseExecutor):
""" """
config = get_config(config_file) config = get_config(config_file)
if self.init(config): if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True) uvicorn.run(app, host=config.host, port=config.port)
@cli_server_register( @cli_server_register(

@ -25,8 +25,6 @@ DefinedClassifier = {
'ErnieLinear': ErnieLinear, 'ErnieLinear': ErnieLinear,
} }
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
def _clean_text(text, punc_list): def _clean_text(text, punc_list):
text = text.lower() text = text.lower()
@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
return text return text
def preprocess(text, punc_list): def preprocess(text, punc_list, tokenizer):
clean_text = _clean_text(text, punc_list) clean_text = _clean_text(text, punc_list)
assert len(clean_text) > 0, f'Invalid input string: {text}' assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = tokenizer( tokenized_input = tokenizer(
@ -51,7 +49,8 @@ def test(args):
with open(args.config) as f: with open(args.config) as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
print("========Args========") print("========Args========")
print(yaml.safe_dump(vars(args))) print(yaml.safe_dump(vars(args), allow_unicode=True))
# print(args)
print("========Config========") print("========Config========")
print(config) print(config)
@ -61,10 +60,16 @@ def test(args):
punc_list.append(line.strip()) punc_list.append(line.strip())
model = DefinedClassifier[config["model_type"]](**config["model"]) model = DefinedClassifier[config["model_type"]](**config["model"])
# print(model)
pretrained_token = config['data_params']['pretrained_token']
tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
# tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
state_dict = paddle.load(args.checkpoint) state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"]) model.set_state_dict(state_dict["main_params"])
model.eval() model.eval()
_inputs = preprocess(args.text, punc_list) _inputs = preprocess(args.text, punc_list, tokenizer)
seq_len = _inputs['seq_len'] seq_len = _inputs['seq_len']
input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0) input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0) seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)

Loading…
Cancel
Save