diff --git a/speechserving/speechserving/engine/tts/python/tts_engine.py b/speechserving/speechserving/engine/tts/python/tts_engine.py index d790aa316..4f0e99066 100644 --- a/speechserving/speechserving/engine/tts/python/tts_engine.py +++ b/speechserving/speechserving/engine/tts/python/tts_engine.py @@ -22,6 +22,8 @@ from engine.base_engine import BaseEngine from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from utils.errors import ErrorCode +from utils.exception import ServerBaseException __all__ = ['TTSEngine'] @@ -128,16 +130,27 @@ class TTSEngine(BaseEngine): lang = self.conf_dict["lang"] - self.executor.infer( - text=sentence, lang=lang, am=self.conf_dict["am"], spk_id=spk_id) - - target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), - original_fs=self.executor.am_config.fs, - target_fs=sample_rate, - volume=volume, - speed=speed, - audio_path=save_path, - audio_format=audio_format) + try: + self.executor.infer( + text=sentence, + lang=lang, + am=self.conf_dict["am"], + spk_id=spk_id) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts infer failed.") + + try: + target_sample_rate, wav_base64 = self.postprocess( + wav=self.executor._outputs['wav'].numpy(), + original_fs=self.executor.am_config.fs, + target_fs=sample_rate, + volume=volume, + speed=speed, + audio_path=save_path, + audio_format=audio_format) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts postprocess failed.") return lang, target_sample_rate, wav_base64 diff --git a/speechserving/speechserving/restful/response.py b/speechserving/speechserving/restful/response.py index 684a37f93..db24f5310 100644 --- a/speechserving/speechserving/restful/response.py +++ b/speechserving/speechserving/restful/response.py @@ -68,7 +68,7 @@ class TTSResponse(BaseModel): response example { "success": true, - "code": 0, + "code": 200, "message": { "description": "success" }, diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index c78eaf63e..69930f242 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from engine.tts.python.tts_engine import TTSEngine from fastapi import APIRouter from .request import TTSRequest from .response import TTSResponse +from utils.errors import ErrorCode +from utils.errors import ErrorMsg +from utils.errors import failed_response +from utils.exception import ServerBaseException router = APIRouter() @@ -28,7 +34,7 @@ def help(): json: [description] """ json_body = { - "success": true, + "success": "True", "code": 0, "message": { "global": "success" @@ -62,19 +68,31 @@ def tts(request_body: TTSRequest): save_path = item_dict['save_path'] audio_format = item_dict['audio_format'] + # Check parameters + if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ + sample_rate not in [0, 16000, 8000] or \ + audio_format not in ["pcm", "wav"]: + return failed_response(ErrorCode.SERVER_PARAM_ERR) + # single tts_engine = TTSEngine() - #tts_engine.init() - lang, target_sample_rate, wav_base64 = tts_engine.run( - sentence, spk_id, speed, volume, sample_rate, save_path, audio_format) - #tts_engine.postprocess() + # run + try: + lang, target_sample_rate, wav_base64 = tts_engine.run( + sentence, spk_id, speed, volume, sample_rate, save_path, + audio_format) + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() json_body = { "success": True, - "code": 0, + "code": 200, "message": { - "description": "success" + "description": "success." }, "result": { "lang": lang, diff --git a/speechserving/speechserving/utils/errors.py b/speechserving/speechserving/utils/errors.py index e69de29bb..aa858cb08 100644 --- a/speechserving/speechserving/utils/errors.py +++ b/speechserving/speechserving/utils/errors.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from enum import IntEnum + +from fastapi import Response + + +class ErrorCode(IntEnum): + SERVER_OK = 200 # success. + + SERVER_PARAM_ERR = 400 # Input parameters are not valid. + SERVER_TASK_NOT_EXIST = 404 # Task is not exist. + + SERVER_INTERNAL_ERR = 500 # Internal error. + SERVER_NETWORK_ERR = 502 # Network exception. + SERVER_UNKOWN_ERR = 509 # Unknown error occurred. + + +ErrorMsg = { + ErrorCode.SERVER_OK: "success.", + ErrorCode.SERVER_PARAM_ERR: "Input parameters are not valid.", + ErrorCode.SERVER_TASK_NOT_EXIST: "Task is not exist.", + ErrorCode.SERVER_INTERNAL_ERR: "Internal error.", + ErrorCode.SERVER_NETWORK_ERR: "Network exception.", + ErrorCode.SERVER_UNKOWN_ERR: "Unknown error occurred." +} + + +def failed_response(code, msg=""): + """Interface call failure response + + Args: + code (int): error code number + msg (str, optional): Interface call failure information. Defaults to "". + + Returns: + Response (json): failure json information. + """ + + if not msg: + msg = ErrorMsg.get(code, "Unknown error occurred.") + + res = {"success": False, "code": int(code), "message": {"global": msg}} + + return Response(content=json.dumps(res), media_type="application/json") diff --git a/speechserving/speechserving/utils/exception.py b/speechserving/speechserving/utils/exception.py new file mode 100644 index 000000000..03a6deee2 --- /dev/null +++ b/speechserving/speechserving/utils/exception.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import traceback + +from utils.errors import ErrorMsg + + +class ServerBaseException(Exception): + """ Server Base exception + """ + + def __init__(self, error_code, msg=None): + #if msg: + #log.error(msg) + msg = msg if msg else ErrorMsg.get(error_code, "") + super(ServerBaseException, self).__init__(error_code, msg) + self.error_code = error_code + self.msg = msg + traceback.print_exc()