From cb76e664017f15b7963eca0e126e5429f0a58ba9 Mon Sep 17 00:00:00 2001
From: dahu1 <707133607@qq.com>
Date: Wed, 19 Oct 2022 15:54:08 +0800
Subject: [PATCH 1/3] 1. do not hardcode the token config, 2. display text
 without garbled characters, test=asr
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../text/exps/ernie_linear/punc_restore.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/paddlespeech/text/exps/ernie_linear/punc_restore.py b/paddlespeech/text/exps/ernie_linear/punc_restore.py
index 2cb4d071..98804606 100644
--- a/paddlespeech/text/exps/ernie_linear/punc_restore.py
+++ b/paddlespeech/text/exps/ernie_linear/punc_restore.py
@@ -25,8 +25,6 @@ DefinedClassifier = {
     'ErnieLinear': ErnieLinear,
 }
 
-tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
-
 
 def _clean_text(text, punc_list):
     text = text.lower()
@@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
     return text
 
 
-def preprocess(text, punc_list):
+def preprocess(text, punc_list, tokenizer):
     clean_text = _clean_text(text, punc_list)
     assert len(clean_text) > 0, f'Invalid input string: {text}'
     tokenized_input = tokenizer(
@@ -51,7 +49,8 @@ def test(args):
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
     print("========Args========")
-    print(yaml.safe_dump(vars(args)))
+    print(yaml.safe_dump(vars(args), allow_unicode=True))
+    # print(args)
     print("========Config========")
     print(config)
 
@@ -61,10 +60,16 @@ def test(args):
             punc_list.append(line.strip())
 
     model = DefinedClassifier[config["model_type"]](**config["model"])
+    # print(model)
+
+    pretrained_token = config['data_params']['pretrained_token']
+    tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
+    # tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+
     state_dict = paddle.load(args.checkpoint)
     model.set_state_dict(state_dict["main_params"])
     model.eval()
-    _inputs = preprocess(args.text, punc_list)
+    _inputs = preprocess(args.text, punc_list, tokenizer)
     seq_len = _inputs['seq_len']
     input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
     seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)

From da525d346f0a78fc1b6f11db408a5ce1a76c5610 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Thu, 20 Oct 2022 06:17:17 +0000
Subject: [PATCH 2/3] fix uvicorn's version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index e551d9fa..3353cdad 100644
--- a/setup.py
+++ b/setup.py
@@ -77,7 +77,7 @@ base = [
     "pybind11",
 ]
 
-server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]
+server = ["fastapi", "uvicorn<=0.18.3", "pattern_singleton", "websockets"]
 
 requirements = {
     "install":

From 63c80121e2c5691145a2bc8c49cf1a2b277c7067 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Thu, 20 Oct 2022 06:33:07 +0000
Subject: [PATCH 3/3] fix uvicorn's bug

---
 paddlespeech/server/bin/paddlespeech_server.py | 2 +-
 setup.py                                       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index 10a91d9b..1b1792bd 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -113,7 +113,7 @@ class ServerExecutor(BaseExecutor):
         """
         config = get_config(config_file)
         if self.init(config):
-            uvicorn.run(app, host=config.host, port=config.port, debug=True)
+            uvicorn.run(app, host=config.host, port=config.port)
 
 
 @cli_server_register(
diff --git a/setup.py b/setup.py
index 3353cdad..e551d9fa 100644
--- a/setup.py
+++ b/setup.py
@@ -77,7 +77,7 @@ base = [
     "pybind11",
 ]
 
-server = ["fastapi", "uvicorn<=0.18.3", "pattern_singleton", "websockets"]
+server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]
 
 requirements = {
     "install":
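Note on PATCH 1/3: the change replaces the module-level, hardcoded
ErnieTokenizer.from_pretrained('ernie-1.0') with a tokenizer built from the
experiment config, and dumps the CLI args with allow_unicode=True so non-ASCII
text prints readably. A minimal sketch of that pattern follows; the YAML file
path and the sample text are illustrative assumptions, and only the
data_params.pretrained_token key is taken from the patch itself.

```python
# Sketch of config-driven tokenizer loading (assumes a hypothetical
# conf/default.yaml containing data_params: {pretrained_token: ernie-1.0}).
import yaml
from paddlenlp.transformers import ErnieTokenizer

with open('conf/default.yaml') as f:  # hypothetical config file
    config = yaml.safe_load(f)

# Read the pretrained model name from the config instead of hardcoding it.
pretrained_token = config['data_params']['pretrained_token']
tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)

# allow_unicode=True keeps Chinese text readable instead of \u-escaped.
print(yaml.safe_dump({'text': '今天天气很好'}, allow_unicode=True))
```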