fix enable the punctuation for asr, test=asr(#3631)

pull/3949/head
潘志强 9 months ago
parent 2d7cf7f0e6
commit fb7e911caa

@ -30,8 +30,9 @@ def init_engine_pool(config) -> bool:
global ENGINE_POOL global ENGINE_POOL
for engine_and_type in config.engine_list: for engine_and_type in config.engine_list:
engine = engine_and_type.split("_")[0] engine, engine_type = engine_and_type.split("_")
engine_type = engine_and_type.split("_")[1] assert (engine != 'asr') or ('text_python' in config.engine_list), \
"When engine is 'asr', must be enabled 'text_python' to support punctuation recovery"
ENGINE_POOL[engine] = EngineFactory.get_engine( ENGINE_POOL[engine] = EngineFactory.get_engine(
engine_name=engine, engine_type=engine_type) engine_name=engine, engine_type=engine_type)

@ -26,6 +26,7 @@ from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.errors import failed_response from paddlespeech.server.utils.errors import failed_response
from paddlespeech.server.utils.exception import ServerBaseException from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.server.engine.text.python.text_engine import PaddleTextConnectionHandler
router = APIRouter() router = APIRouter()
@ -83,6 +84,11 @@ def asr(request_body: ASRRequest):
connection_handler.run(audio_data) connection_handler.run(audio_data)
asr_results = connection_handler.postprocess() asr_results = connection_handler.postprocess()
if request_body.punc:
text_engine = engine_pool['text']
connection_handler = PaddleTextConnectionHandler(text_engine)
asr_results = connection_handler.run(asr_results)
response = { response = {
"success": True, "success": True,
"code": 200, "code": 200,

Loading…
Cancel
Save