From fb7e911caa0bbffaf683fdac02b95326351e805e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BD=98=E5=BF=97=E5=BC=BA?= <85210755@qq.com> Date: Wed, 11 Dec 2024 18:10:36 +0800 Subject: [PATCH] fix enable the punctuation for asr, test=asr(#3631) --- paddlespeech/server/engine/engine_pool.py | 5 +++-- paddlespeech/server/restful/asr_api.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/paddlespeech/server/engine/engine_pool.py b/paddlespeech/server/engine/engine_pool.py index 5300303f6..81ed3cb9a 100644 --- a/paddlespeech/server/engine/engine_pool.py +++ b/paddlespeech/server/engine/engine_pool.py @@ -30,8 +30,9 @@ def init_engine_pool(config) -> bool: global ENGINE_POOL for engine_and_type in config.engine_list: - engine = engine_and_type.split("_")[0] - engine_type = engine_and_type.split("_")[1] + engine, engine_type = engine_and_type.split("_") + assert (engine != 'asr') or ('text_python' in config.engine_list), \ + "When engine is 'asr', must be enabled 'text_python' to support punctuation recovery" ENGINE_POOL[engine] = EngineFactory.get_engine( engine_name=engine, engine_type=engine_type) diff --git a/paddlespeech/server/restful/asr_api.py b/paddlespeech/server/restful/asr_api.py index c7bc50ce4..bc6342ce1 100644 --- a/paddlespeech/server/restful/asr_api.py +++ b/paddlespeech/server/restful/asr_api.py @@ -26,6 +26,7 @@ from paddlespeech.server.restful.response import ErrorResponse from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.errors import failed_response from paddlespeech.server.utils.exception import ServerBaseException +from paddlespeech.server.engine.text.python.text_engine import PaddleTextConnectionHandler router = APIRouter() @@ -83,6 +84,11 @@ def asr(request_body: ASRRequest): connection_handler.run(audio_data) asr_results = connection_handler.postprocess() + if request_body.punc: + text_engine = engine_pool['text'] + connection_handler = PaddleTextConnectionHandler(text_engine) + asr_results = connection_handler.run(asr_results) + response = { "success": True, "code": 200,