Merge pull request #1 from PaddlePaddle/server

Server
pull/1399/head
WilliamZhang06 3 years ago committed by GitHub
commit b7126f091c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,38 @@
# This is the parameter configuration file for TTS server.
##################################################################
# TTS SERVER SETTING #
##################################################################
host: '0.0.0.0'
port: 8692
##################################################################
# ACOUSTIC MODEL SETTING #
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
# 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
# 'fastspeech2_vctk']
##################################################################
am: 'fastspeech2_csmsc'
am_config:
am_ckpt:
am_stat:
phones_dict:
tones_dict:
speaker_dict:
spk_id: 0
##################################################################
# VOCODER SETTING #
# voc choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
# 'pwgan_vctk', 'mb_melgan_csmsc']
##################################################################
voc: 'pwgan_csmsc'
voc_config:
voc_ckpt:
voc_stat:
##################################################################
# OTHERS #
##################################################################
lang: 'zh'
device: paddle.get_device()

@ -0,0 +1,156 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import librosa
import numpy as np
import soundfile as sf
import yaml
from engine.base_engine import BaseEngine
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from utils.errors import ErrorCode
from utils.exception import ServerBaseException
__all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
self.parser.add_argument(
'--conf',
type=str,
default='./conf/tts/tts.yaml',
help='Configuration parameters.')
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
self.executor = TTSServerExecutor()
config_path = self.executor.parser.parse_args().conf
with open(config_path, 'rt') as f:
self.conf_dict = yaml.safe_load(f)
self.executor._init_from_path(
am=self.conf_dict["am"],
am_config=self.conf_dict["am_config"],
am_ckpt=self.conf_dict["am_ckpt"],
am_stat=self.conf_dict["am_stat"],
phones_dict=self.conf_dict["phones_dict"],
tones_dict=self.conf_dict["tones_dict"],
speaker_dict=self.conf_dict["speaker_dict"],
voc=self.conf_dict["voc"],
voc_config=self.conf_dict["voc_config"],
voc_ckpt=self.conf_dict["voc_ckpt"],
voc_stat=self.conf_dict["voc_stat"],
lang=self.conf_dict["lang"])
logger.info("Initialize TTS server engine successfully.")
def postprocess(self,
wav,
original_fs: int,
target_fs: int=16000,
volume: float=1.0,
speed: float=1.0,
audio_path: str=None,
audio_format: str="wav"):
"""Post-processing operations, including speech, volume, sample rate, save audio file
Args:
wav (numpy(float)): Synthesized audio sample points
original_fs (int): original audio sample rate
target_fs (int): target audio sample rate
volume (float): target volume
speed (float): target speed
"""
# transform sample_rate
if target_fs == 0 or target_fs > original_fs:
target_fs = original_fs
wav_tar_fs = wav
else:
wav_tar_fs = librosa.resample(
np.squeeze(wav), original_fs, target_fs)
# transform volume
wav_vol = wav_tar_fs * volume
# transform speed
# TODO
target_wav = wav_vol.reshape(-1, 1)
# save audio
if audio_path is not None:
sf.write(audio_path, target_wav, target_fs)
logger.info('Wave file has been generated: {}'.format(audio_path))
# wav to base64
base64_bytes = base64.b64encode(target_wav)
base64_string = base64_bytes.decode('utf-8')
wav_base64 = base64_string
return target_fs, wav_base64
def run(self,
sentence: str,
spk_id: int=0,
speed: float=1.0,
volume: float=1.0,
sample_rate: int=0,
save_path: str=None,
audio_format: str="wav"):
lang = self.conf_dict["lang"]
try:
self.executor.infer(
text=sentence,
lang=lang,
am=self.conf_dict["am"],
spk_id=spk_id)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
original_fs=self.executor.am_config.fs,
target_fs=sample_rate,
volume=volume,
speed=speed,
audio_path=save_path,
audio_format=audio_format)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
return lang, target_sample_rate, wav_base64

@ -15,11 +15,12 @@ import argparse
import uvicorn import uvicorn
import yaml import yaml
from engine.asr.python.asr_engine import ASREngine
from engine.tts.python.tts_engine import TTSEngine
from fastapi import FastAPI from fastapi import FastAPI
from restful.api import router as api_router from restful.api import router as api_router
from utils.log import logger from paddlespeech.cli.log import logger
app = FastAPI( app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1") title="PaddleSpeech Serving API", description="Api", version="0.0.1")
@ -31,7 +32,8 @@ def init(args):
app.include_router(api_router) app.include_router(api_router)
# engine single # engine single
ASR_ENGINE = ASREngine("asr")
TTS_ENGINE = TTSEngine()
# todo others # todo others
@ -56,7 +58,8 @@ if __name__ == "__main__":
"--config_file", "--config_file",
action="store", action="store",
help="yaml file of the app", help="yaml file of the app",
default="./conf/application.yaml") default="./conf/tts/tts.yaml")
parser.add_argument( parser.add_argument(
"--log_file", "--log_file",
action="store", action="store",

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
from fastapi import APIRouter from fastapi import APIRouter
from .asr_api import router as asr_router
from .tts_api import router as tts_router from .tts_api import router as tts_router
#from .asr_api import router as asr_router
router = APIRouter() router = APIRouter()
router.include_router(asr_router) #router.include_router(asr_router)
router.include_router(tts_router) router.include_router(tts_router)

@ -16,7 +16,8 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ['ASRRequest, TTSRequest'] __all__ = ['ASRRequest', 'TTSRequest']
#****************************************************************************************/ #****************************************************************************************/
@ -44,13 +45,26 @@ class ASRRequest(BaseModel):
#************************************ TTS request ***************************************/ #************************************ TTS request ***************************************/
#****************************************************************************************/ #****************************************************************************************/
class TTSRequest(BaseModel): class TTSRequest(BaseModel):
""" """TTS request
request body example request body example
{ {
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", "text": "你好,欢迎使用百度飞桨语音合成服务。",
"audio_format": "wav", "spk_id": 0,
"sample_rate": 16000, "speed": 1.0,
"lang ": "zh_cn", "volume": 1.0,
"ptt ":false "sample_rate": 0,
"tts_audio_path": "./tts.wav",
"audio_format": "wav"
} }
""" """
text: str
spk_id: int = 0
speed: float = 1.0
volume: float = 1.0
sample_rate: int = 0
save_path: str = None
audio_format: str = "wav"

@ -16,7 +16,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ['ASRResponse'] __all__ = ['ASRResponse', 'TTSResponse']
class Message(BaseModel): class Message(BaseModel):
@ -53,3 +53,36 @@ class ASRResponse(BaseModel):
#****************************************************************************************/ #****************************************************************************************/
#************************************ TTS response **************************************/ #************************************ TTS response **************************************/
#****************************************************************************************/ #****************************************************************************************/
class TTSResult(BaseModel):
lang: str = "zh"
sample_rate: int
spk_id: int = 0
speed: float = 1.0
volume: float = 1.0
save_path: str = None
audio: str
class TTSResponse(BaseModel):
"""
response example
{
"success": true,
"code": 200,
"message": {
"description": "success"
},
"result": {
"lang": "zh",
"sample_rate": 24000,
"speed": 1.0,
"volume": 1.0,
"audio": "LTI1OTIuNjI1OTUwMzQsOTk2OS41NDk4...",
"save_path": "./tts.wav"
}
}
"""
success: bool
code: int
message: Message
result: TTSResult

@ -11,8 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import traceback
from engine.tts.python.tts_engine import TTSEngine
from fastapi import APIRouter from fastapi import APIRouter
from .request import TTSRequest
from .response import TTSResponse
from utils.errors import ErrorCode
from utils.errors import ErrorMsg
from utils.errors import failed_response
from utils.exception import ServerBaseException
router = APIRouter() router = APIRouter()
@ -24,6 +33,76 @@ def help():
Returns: Returns:
json: [description] json: [description]
""" """
return {'hello': 'world'} json_body = {
"success": "True",
"code": 0,
"message": {
"global": "success"
},
"result": {
"description": "tts server",
"text": "sentence to be synthesized",
"audio": "the base64 of audio"
}
}
return json_body
@router.post("/paddlespeech/tts", response_model=TTSResponse)
def tts(request_body: TTSRequest):
"""tts api
Args:
request_body (TTSRequest): [description]
Returns:
json: [description]
"""
# json to dict
item_dict = request_body.dict()
sentence = item_dict['text']
spk_id = item_dict['spk_id']
speed = item_dict['speed']
volume = item_dict['volume']
sample_rate = item_dict['sample_rate']
save_path = item_dict['save_path']
audio_format = item_dict['audio_format']
# Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \
sample_rate not in [0, 16000, 8000] or \
audio_format not in ["pcm", "wav"]:
return failed_response(ErrorCode.SERVER_PARAM_ERR)
# single
tts_engine = TTSEngine()
# run
try:
lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path,
audio_format)
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
json_body = {
"success": True,
"code": 200,
"message": {
"description": "success."
},
"result": {
"lang": lang,
"spk_id": spk_id,
"speed": speed,
"volume": volume,
"sample_rate": target_sample_rate,
"save_path": save_path,
"audio": wav_base64
}
}
return json_body

@ -0,0 +1,57 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import IntEnum
from fastapi import Response
class ErrorCode(IntEnum):
SERVER_OK = 200 # success.
SERVER_PARAM_ERR = 400 # Input parameters are not valid.
SERVER_TASK_NOT_EXIST = 404 # Task is not exist.
SERVER_INTERNAL_ERR = 500 # Internal error.
SERVER_NETWORK_ERR = 502 # Network exception.
SERVER_UNKOWN_ERR = 509 # Unknown error occurred.
ErrorMsg = {
ErrorCode.SERVER_OK: "success.",
ErrorCode.SERVER_PARAM_ERR: "Input parameters are not valid.",
ErrorCode.SERVER_TASK_NOT_EXIST: "Task is not exist.",
ErrorCode.SERVER_INTERNAL_ERR: "Internal error.",
ErrorCode.SERVER_NETWORK_ERR: "Network exception.",
ErrorCode.SERVER_UNKOWN_ERR: "Unknown error occurred."
}
def failed_response(code, msg=""):
"""Interface call failure response
Args:
code (int): error code number
msg (str, optional): Interface call failure information. Defaults to "".
Returns:
Response (json): failure json information.
"""
if not msg:
msg = ErrorMsg.get(code, "Unknown error occurred.")
res = {"success": False, "code": int(code), "message": {"global": msg}}
return Response(content=json.dumps(res), media_type="application/json")

@ -0,0 +1,30 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from utils.errors import ErrorMsg
class ServerBaseException(Exception):
""" Server Base exception
"""
def __init__(self, error_code, msg=None):
#if msg:
#log.error(msg)
msg = msg if msg else ErrorMsg.get(error_code, "")
super(ServerBaseException, self).__init__(error_code, msg)
self.error_code = error_code
self.msg = msg
traceback.print_exc()
Loading…
Cancel
Save