diff --git a/paddlespeech/text/exps/ernie_linear/punc_restore.py b/paddlespeech/text/exps/ernie_linear/punc_restore.py index 2cb4d071..98804606 100644 --- a/paddlespeech/text/exps/ernie_linear/punc_restore.py +++ b/paddlespeech/text/exps/ernie_linear/punc_restore.py @@ -25,8 +25,6 @@ DefinedClassifier = { 'ErnieLinear': ErnieLinear, } -tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') - def _clean_text(text, punc_list): text = text.lower() @@ -35,7 +33,7 @@ def _clean_text(text, punc_list): return text -def preprocess(text, punc_list): +def preprocess(text, punc_list, tokenizer): clean_text = _clean_text(text, punc_list) assert len(clean_text) > 0, f'Invalid input string: {text}' tokenized_input = tokenizer( @@ -51,7 +49,8 @@ def test(args): with open(args.config) as f: config = CfgNode(yaml.safe_load(f)) print("========Args========") - print(yaml.safe_dump(vars(args))) + print(yaml.safe_dump(vars(args), allow_unicode=True)) + # print(args) print("========Config========") print(config) @@ -61,10 +60,16 @@ def test(args): punc_list.append(line.strip()) model = DefinedClassifier[config["model_type"]](**config["model"]) + # print(model) + + pretrained_token = config['data_params']['pretrained_token'] + tokenizer = ErnieTokenizer.from_pretrained(pretrained_token) + # tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') + state_dict = paddle.load(args.checkpoint) model.set_state_dict(state_dict["main_params"]) model.eval() - _inputs = preprocess(args.text, punc_list) + _inputs = preprocess(args.text, punc_list, tokenizer) seq_len = _inputs['seq_len'] input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0) seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)