From cb76e664017f15b7963eca0e126e5429f0a58ba9 Mon Sep 17 00:00:00 2001
From: dahu1 <707133607@qq.com>
Date: Wed, 19 Oct 2022 15:54:08 +0800
Subject: [PATCH] 1. Do not hardcode the tokenizer config; 2. display text
 without garbled characters, test=asr
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../text/exps/ernie_linear/punc_restore.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/paddlespeech/text/exps/ernie_linear/punc_restore.py b/paddlespeech/text/exps/ernie_linear/punc_restore.py
index 2cb4d071..98804606 100644
--- a/paddlespeech/text/exps/ernie_linear/punc_restore.py
+++ b/paddlespeech/text/exps/ernie_linear/punc_restore.py
@@ -25,8 +25,6 @@ DefinedClassifier = {
     'ErnieLinear': ErnieLinear,
 }
 
-tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
-
 
 def _clean_text(text, punc_list):
     text = text.lower()
@@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
     return text
 
 
-def preprocess(text, punc_list):
+def preprocess(text, punc_list, tokenizer):
     clean_text = _clean_text(text, punc_list)
     assert len(clean_text) > 0, f'Invalid input string: {text}'
     tokenized_input = tokenizer(
@@ -51,7 +49,8 @@ def test(args):
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
     print("========Args========")
-    print(yaml.safe_dump(vars(args)))
+    print(yaml.safe_dump(vars(args), allow_unicode=True))
+    # print(args)
     print("========Config========")
     print(config)
 
@@ -61,10 +60,16 @@ def test(args):
             punc_list.append(line.strip())
 
     model = DefinedClassifier[config["model_type"]](**config["model"])
+    # print(model)
+
+    pretrained_token = config['data_params']['pretrained_token']
+    tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
+    # tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+
     state_dict = paddle.load(args.checkpoint)
     model.set_state_dict(state_dict["main_params"])
     model.eval()
-    _inputs = preprocess(args.text, punc_list)
+    _inputs = preprocess(args.text, punc_list, tokenizer)
     seq_len = _inputs['seq_len']
     input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
     seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)
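
For reference, the sketch below (not part of the patch) illustrates the two behaviors this commit changes: the tokenizer name is now read from the experiment config instead of being hardcoded to 'ernie-1.0' at module import time, and yaml.safe_dump is called with allow_unicode=True so non-ASCII text prints readably instead of as \uXXXX escapes. The YAML layout here is an assumption inferred from the keys the patch reads (data_params.pretrained_token), 'ernie-1.0' is only an example value, and paddlenlp must be installed.

# Sketch only, not part of the patch. The config layout mirrors the keys
# read in test(); 'ernie-1.0' is an example value, not a requirement.
import yaml
from paddlenlp.transformers import ErnieTokenizer

config = yaml.safe_load("""
data_params:
  pretrained_token: ernie-1.0
""")

# The tokenizer name now comes from the config, not a module-level constant,
# so importing punc_restore no longer loads a fixed tokenizer as a side effect.
tokenizer = ErnieTokenizer.from_pretrained(config['data_params']['pretrained_token'])

# Without allow_unicode=True, PyYAML escapes non-ASCII characters.
print(yaml.safe_dump({'text': '今天天气很好'}))                       # text: "\u4ECA\u5929..."
print(yaml.safe_dump({'text': '今天天气很好'}, allow_unicode=True))   # text: 今天天气很好

Constructing the tokenizer inside test() also keeps it paired with whatever pretrained_token the checkpoint was trained against, rather than silently combining every checkpoint with the ernie-1.0 vocabulary.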