add error code, test=server (#1391)

pull/1415/head
liangym 3 years ago committed by GitHub
parent 3796a40e01
commit 1ee4577d3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,6 +22,8 @@ from engine.base_engine import BaseEngine
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from utils.errors import ErrorCode
from utils.exception import ServerBaseException
__all__ = ['TTSEngine']
@ -128,16 +130,27 @@ class TTSEngine(BaseEngine):
lang = self.conf_dict["lang"]
self.executor.infer(
text=sentence, lang=lang, am=self.conf_dict["am"], spk_id=spk_id)
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
original_fs=self.executor.am_config.fs,
target_fs=sample_rate,
volume=volume,
speed=speed,
audio_path=save_path,
audio_format=audio_format)
try:
self.executor.infer(
text=sentence,
lang=lang,
am=self.conf_dict["am"],
spk_id=spk_id)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
original_fs=self.executor.am_config.fs,
target_fs=sample_rate,
volume=volume,
speed=speed,
audio_path=save_path,
audio_format=audio_format)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
return lang, target_sample_rate, wav_base64

@ -68,7 +68,7 @@ class TTSResponse(BaseModel):
response example
{
"success": true,
"code": 0,
"code": 200,
"message": {
"description": "success"
},

@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from engine.tts.python.tts_engine import TTSEngine
from fastapi import APIRouter
from .request import TTSRequest
from .response import TTSResponse
from utils.errors import ErrorCode
from utils.errors import ErrorMsg
from utils.errors import failed_response
from utils.exception import ServerBaseException
router = APIRouter()
@ -28,7 +34,7 @@ def help():
json: [description]
"""
json_body = {
"success": true,
"success": "True",
"code": 0,
"message": {
"global": "success"
@ -62,19 +68,31 @@ def tts(request_body: TTSRequest):
save_path = item_dict['save_path']
audio_format = item_dict['audio_format']
# Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \
sample_rate not in [0, 16000, 8000] or \
audio_format not in ["pcm", "wav"]:
return failed_response(ErrorCode.SERVER_PARAM_ERR)
# single
tts_engine = TTSEngine()
#tts_engine.init()
lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path, audio_format)
#tts_engine.postprocess()
# run
try:
lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path,
audio_format)
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
json_body = {
"success": True,
"code": 0,
"code": 200,
"message": {
"description": "success"
"description": "success."
},
"result": {
"lang": lang,

@ -0,0 +1,57 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import IntEnum
from fastapi import Response
class ErrorCode(IntEnum):
SERVER_OK = 200 # success.
SERVER_PARAM_ERR = 400 # Input parameters are not valid.
SERVER_TASK_NOT_EXIST = 404 # Task is not exist.
SERVER_INTERNAL_ERR = 500 # Internal error.
SERVER_NETWORK_ERR = 502 # Network exception.
SERVER_UNKOWN_ERR = 509 # Unknown error occurred.
ErrorMsg = {
ErrorCode.SERVER_OK: "success.",
ErrorCode.SERVER_PARAM_ERR: "Input parameters are not valid.",
ErrorCode.SERVER_TASK_NOT_EXIST: "Task is not exist.",
ErrorCode.SERVER_INTERNAL_ERR: "Internal error.",
ErrorCode.SERVER_NETWORK_ERR: "Network exception.",
ErrorCode.SERVER_UNKOWN_ERR: "Unknown error occurred."
}
def failed_response(code, msg=""):
"""Interface call failure response
Args:
code (int): error code number
msg (str, optional): Interface call failure information. Defaults to "".
Returns:
Response (json): failure json information.
"""
if not msg:
msg = ErrorMsg.get(code, "Unknown error occurred.")
res = {"success": False, "code": int(code), "message": {"global": msg}}
return Response(content=json.dumps(res), media_type="application/json")

@ -0,0 +1,30 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from utils.errors import ErrorMsg
class ServerBaseException(Exception):
""" Server Base exception
"""
def __init__(self, error_code, msg=None):
#if msg:
#log.error(msg)
msg = msg if msg else ErrorMsg.get(error_code, "")
super(ServerBaseException, self).__init__(error_code, msg)
self.error_code = error_code
self.msg = msg
traceback.print_exc()
Loading…
Cancel
Save