Merge pull request #2554 from dahu1/develop

标点恢复代码更新,test=asr
pull/2558/head
TianYuan 2 years ago committed by GitHub
commit 2a60c3d854
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,8 +25,6 @@ DefinedClassifier = {
'ErnieLinear': ErnieLinear,
}
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
def _clean_text(text, punc_list):
text = text.lower()
@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
return text
def preprocess(text, punc_list):
def preprocess(text, punc_list, tokenizer):
clean_text = _clean_text(text, punc_list)
assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = tokenizer(
@ -51,7 +49,8 @@ def test(args):
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print(yaml.safe_dump(vars(args), allow_unicode=True))
# print(args)
print("========Config========")
print(config)
@ -61,10 +60,16 @@ def test(args):
punc_list.append(line.strip())
model = DefinedClassifier[config["model_type"]](**config["model"])
# print(model)
pretrained_token = config['data_params']['pretrained_token']
tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
# tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"])
model.eval()
_inputs = preprocess(args.text, punc_list)
_inputs = preprocess(args.text, punc_list, tokenizer)
seq_len = _inputs['seq_len']
input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)

Loading…
Cancel
Save