diff --git a/paddlespeech/server/engine/engine_pool.py b/paddlespeech/server/engine/engine_pool.py index 5300303f6..81ed3cb9a 100644 --- a/paddlespeech/server/engine/engine_pool.py +++ b/paddlespeech/server/engine/engine_pool.py @@ -30,8 +30,9 @@ def init_engine_pool(config) -> bool: global ENGINE_POOL for engine_and_type in config.engine_list: - engine = engine_and_type.split("_")[0] - engine_type = engine_and_type.split("_")[1] + engine, engine_type = engine_and_type.split("_") + assert (engine != 'asr') or ('text_python' in config.engine_list), \ + "When engine is 'asr', must be enabled 'text_python' to support punctuation recovery" ENGINE_POOL[engine] = EngineFactory.get_engine( engine_name=engine, engine_type=engine_type) diff --git a/paddlespeech/server/restful/asr_api.py b/paddlespeech/server/restful/asr_api.py index c7bc50ce4..bc6342ce1 100644 --- a/paddlespeech/server/restful/asr_api.py +++ b/paddlespeech/server/restful/asr_api.py @@ -26,6 +26,7 @@ from paddlespeech.server.restful.response import ErrorResponse from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.errors import failed_response from paddlespeech.server.utils.exception import ServerBaseException +from paddlespeech.server.engine.text.python.text_engine import PaddleTextConnectionHandler router = APIRouter() @@ -83,6 +84,11 @@ def asr(request_body: ASRRequest): connection_handler.run(audio_data) asr_results = connection_handler.postprocess() + if request_body.punc: + text_engine = engine_pool['text'] + connection_handler = PaddleTextConnectionHandler(text_engine) + asr_results = connection_handler.run(asr_results) + response = { "success": True, "code": 200,