update engine, test=doc

pull/1992/head
lym0302 3 years ago
parent f07f57a3a8
commit d48c4d686a

@ -0,0 +1,70 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import warnings
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
warnings.filterwarnings("ignore")
import sys
# Module-level ASGI application. uvicorn is handed the import string
# "start_multi_progress_server:app" in the __main__ section, so `app` and its
# routes are built here at import time instead of inside a main function.
app = FastAPI(
    title="PaddleSpeech Serving API", description="Api", version="0.0.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])

# change yaml file here
config_file = "./conf/application.yaml"
config = get_config(config_file)

# init engine
if not init_engine_pool(config):
    print("Failed to init engine.")
    sys.exit(-1)

# get api_router
# Every engine_list entry is split on "_" and its first piece is passed to
# the router setup (e.g. "tts_online" -> "tts").
api_list = [engine.split("_")[0] for engine in config.engine_list]
if config.protocol == "websocket":
    api_router = setup_ws_router(api_list)
elif config.protocol == "http":
    api_router = setup_http_router(api_list)
else:
    # The raise ends module import; nothing placed after it would ever run.
    raise Exception("unsupported protocol")

# app needs to operate outside the main function
app.include_router(api_router)
if __name__ == "__main__":
    # Command line: the only option is the number of server workers.
    arg_parser = argparse.ArgumentParser(add_help=True)
    arg_parser.add_argument(
        "--workers", type=int, help="workers of server", default=1)
    cli_args = arg_parser.parse_args()

    # Host and port are read from the module-level `config`; the application
    # is passed as an import string rather than as an object.
    server_options = dict(
        host=config.host,
        port=config.port,
        debug=True,
        workers=cli_args.workers)
    uvicorn.run("start_multi_progress_server:app", **server_options)

@ -26,6 +26,7 @@ from ..util import cli_server_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.engine.engine_warmup import warm_up
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
@ -86,6 +87,11 @@ class ServerExecutor(BaseExecutor):
if not init_engine_pool(config):
return False
# warm up
for engine_and_type in config.engine_list:
if not warm_up(engine_and_type):
return False
return True
def execute(self, argv: List[str]) -> bool:

@ -30,7 +30,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['ASREngine']
__all__ = ['ASREngine', 'PaddleASRConnectionHandler']
class ASRServerExecutor(ASRExecutor):
@ -50,7 +50,7 @@ class ASRServerExecutor(ASRExecutor):
"""
Init model and other resources from a specific path.
"""
self.max_len = 50
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
if cfg_path is None or am_model is None or am_params is None:
@ -172,10 +172,23 @@ class ASREngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = ASRServerExecutor()
self.config = config
self.engine_type = "inference"
try:
if self.config.am_predictor_conf.device is not None:
self.device = self.config.am_predictor_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
self.executor._init_from_path(
model_type=self.config.model_type,
@ -190,22 +203,42 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.")
return True
class PaddleASRConnectionHandler(ASRServerExecutor):
def __init__(self, asr_engine):
    """Create the PaddleSpeech ASR server connection handler.

    The handler is used to process ASR server requests. It keeps a
    reference to the engine and to the engine's executor, and exposes
    the executor's resources as attributes of the handler itself.

    Args:
        asr_engine (ASREngine): The ASR engine
    """
    super().__init__()
    self.input = None
    self.output = None

    shared = asr_engine.executor
    self.asr_engine = asr_engine
    self.executor = shared

    # References (not copies) to objects held by the engine's executor.
    self.config = shared.config
    self.max_len = shared.max_len
    self.decoder = shared.decoder
    self.am_predictor = shared.am_predictor
    self.text_feature = shared.text_feature
    self.collate_fn_test = shared.collate_fn_test
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64.b64decode
"""
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
if self._check(
io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
self.asr_engine.config.force_yes):
logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type,
io.BytesIO(audio_data))
self.preprocess(self.asr_engine.config.model_type,
io.BytesIO(audio_data))
st = time.time()
self.executor.infer(self.config.model_type)
self.infer(self.asr_engine.config.model_type)
infer_time = time.time() - st
self.output = self.executor.postprocess() # Retrieve result of asr.
self.output = self.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine")
else:
logger.info("file check failed!")
@ -213,8 +246,3 @@ class ASREngine(BaseEngine):
logger.info("inference time: {}".format(infer_time))
logger.info("asr engine type: paddle inference")
def postprocess(self):
"""postprocess
"""
return self.output

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import sys
import time
import paddle
@ -20,7 +21,7 @@ from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine']
__all__ = ['ASREngine', 'PaddleASRConnectionHandler']
class ASRServerExecutor(ASRExecutor):
@ -48,20 +49,23 @@ class ASREngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = ASRServerExecutor()
self.config = config
self.engine_type = "python"
try:
if self.config.device:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate,
@ -72,6 +76,24 @@ class ASREngine(BaseEngine):
(self.device))
return True
class PaddleASRConnectionHandler(ASRServerExecutor):
def __init__(self, asr_engine):
    """Create the PaddleSpeech ASR server connection handler.

    The handler is used to process ASR server requests. It keeps a
    reference to the engine and to the engine's executor, and exposes
    the executor's resources as attributes of the handler itself.

    Args:
        asr_engine (ASREngine): The ASR engine
    """
    super().__init__()
    self.input = None
    self.output = None

    shared = asr_engine.executor
    self.asr_engine = asr_engine
    self.executor = shared

    # References (not copies) to objects held by the engine's executor.
    self.max_len = shared.max_len
    self.text_feature = shared.text_feature
    self.model = shared.model
    self.config = shared.config
def run(self, audio_data):
"""engine run
@ -79,17 +101,16 @@ class ASREngine(BaseEngine):
audio_data (bytes): base64.b64decode
"""
try:
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
if self._check(
io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
self.asr_engine.config.force_yes):
logger.info("start run asr engine")
self.executor.preprocess(self.config.model,
io.BytesIO(audio_data))
self.preprocess(self.asr_engine.config.model,
io.BytesIO(audio_data))
st = time.time()
self.executor.infer(self.config.model)
self.infer(self.asr_engine.config.model)
infer_time = time.time() - st
self.output = self.executor.postprocess(
) # Retrieve result of asr.
self.output = self.postprocess() # Retrieve result of asr.
else:
logger.info("file check failed!")
self.output = None
@ -98,8 +119,4 @@ class ASREngine(BaseEngine):
logger.info("asr engine type: python")
except Exception as e:
logger.info(e)
def postprocess(self):
"""postprocess
"""
return self.output
sys.exit(-1)

@ -14,6 +14,7 @@
import io
import os
import time
from collections import OrderedDict
from typing import Optional
import numpy as np
@ -27,7 +28,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['CLSEngine']
__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']
class CLSServerExecutor(CLSExecutor):
@ -119,14 +120,55 @@ class CLSEngine(BaseEngine):
"""
self.executor = CLSServerExecutor()
self.config = config
self.executor._init_from_path(
self.config.model_type, self.config.cfg_path,
self.config.model_path, self.config.params_path,
self.config.label_file, self.config.predictor_conf)
self.engine_type = "inference"
try:
if self.config.predictor_conf.device is not None:
self.device = self.config.predictor_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
try:
self.executor._init_from_path(
self.config.model_type, self.config.cfg_path,
self.config.model_path, self.config.params_path,
self.config.label_file, self.config.predictor_conf)
except Exception as e:
logger.error("Initialize CLS server engine Failed.")
logger.error(e)
return False
logger.info("Initialize CLS server engine successfully.")
return True
class PaddleCLSConnectionHandler(CLSServerExecutor):
def __init__(self, cls_engine):
    """Create the PaddleSpeech CLS server connection handler.

    The handler is used to process CLS server requests. It starts with
    empty input/output dictionaries and takes its configuration, label
    list and predictor from the engine's executor.

    Args:
        cls_engine (CLSEngine): The CLS engine
    """
    super().__init__()
    logger.info(
        "Create PaddleCLSConnectionHandler to process the cls request")

    # Fresh containers belonging to this handler instance.
    self._inputs = OrderedDict()
    self._outputs = OrderedDict()

    self.cls_engine = cls_engine
    self.executor = cls_engine.executor

    # References (not copies) to objects held by the engine's executor.
    for shared_name in ("_conf", "_label_list", "predictor"):
        setattr(self, shared_name, getattr(self.executor, shared_name))
def run(self, audio_data):
"""engine run
@ -134,9 +176,9 @@ class CLSEngine(BaseEngine):
audio_data (bytes): base64.b64decode
"""
self.executor.preprocess(io.BytesIO(audio_data))
self.preprocess(io.BytesIO(audio_data))
st = time.time()
self.executor.infer()
self.infer()
infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time))
@ -145,15 +187,15 @@ class CLSEngine(BaseEngine):
def postprocess(self, topk: int):
"""postprocess
"""
assert topk <= len(self.executor._label_list
), 'Value of topk is larger than number of labels.'
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = np.squeeze(self.executor._outputs['logits'], axis=0)
result = np.squeeze(self._outputs['logits'], axis=0)
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}
label, score = self.executor._label_list[idx], result[idx]
label, score = self._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)

@ -13,7 +13,7 @@
# limitations under the License.
import io
import time
from typing import List
from collections import OrderedDict
import paddle
@ -21,7 +21,7 @@ from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['CLSEngine']
__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']
class CLSServerExecutor(CLSExecutor):
@ -29,21 +29,6 @@ class CLSServerExecutor(CLSExecutor):
super().__init__()
pass
def get_topk_results(self, topk: int) -> List:
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = self._outputs['logits'].squeeze(0).numpy()
topk_idx = (-result).argsort()[:topk]
res = {}
topk_results = []
for idx in topk_idx:
label, score = self._label_list[idx], result[idx]
res['class'] = label
res['prob'] = score
topk_results.append(res)
return topk_results
class CLSEngine(BaseEngine):
"""CLS server engine
@ -64,42 +49,65 @@ class CLSEngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = CLSServerExecutor()
self.config = config
self.engine_type = "python"
try:
if self.config.device:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
try:
self.executor._init_from_path(
self.config.model, self.config.cfg_path, self.config.ckpt_path,
self.config.label_file)
except BaseException:
except Exception as e:
logger.error("Initialize CLS server engine Failed.")
logger.error(e)
return False
logger.info("Initialize CLS server engine successfully on device: %s." %
(self.device))
return True
class PaddleCLSConnectionHandler(CLSServerExecutor):
def __init__(self, cls_engine):
    """Create the PaddleSpeech CLS server connection handler.

    The handler is used to process CLS server requests. It starts with
    empty input/output dictionaries and takes its configuration, label
    list and model from the engine's executor.

    Args:
        cls_engine (CLSEngine): The CLS engine
    """
    super().__init__()
    logger.info(
        "Create PaddleCLSConnectionHandler to process the cls request")

    # Fresh containers belonging to this handler instance.
    self._inputs = OrderedDict()
    self._outputs = OrderedDict()

    self.cls_engine = cls_engine
    self.executor = cls_engine.executor

    # References (not copies) to objects held by the engine's executor.
    for shared_name in ("_conf", "_label_list", "model"):
        setattr(self, shared_name, getattr(self.executor, shared_name))
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64.b64decode
"""
self.executor.preprocess(io.BytesIO(audio_data))
self.preprocess(io.BytesIO(audio_data))
st = time.time()
self.executor.infer()
self.infer()
infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time))
@ -108,15 +116,15 @@ class CLSEngine(BaseEngine):
def postprocess(self, topk: int):
"""postprocess
"""
assert topk <= len(self.executor._label_list
), 'Value of topk is larger than number of labels.'
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = self.executor._outputs['logits'].squeeze(0).numpy()
result = self._outputs['logits'].squeeze(0).numpy()
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}
label, score = self.executor._label_list[idx], result[idx]
label, score = self._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)

@ -0,0 +1,75 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
    """Warm up an initialized engine by running a few dummy inferences.

    Only engines whose identifier contains "tts" are warmed up; for any
    other identifier the function does nothing and reports success.

    Args:
        engine_and_type (str): Engine identifier, one of "tts_python",
            "tts_inference", "tts_online" or "tts_online-onnx" for TTS.
        warm_up_time (int): Number of warm-up inferences. Defaults to 3.

    Returns:
        bool: True when warm-up succeeded or was not needed; False when the
            TTS language or engine type is not supported, or when a warm-up
            inference raised an exception.
    """
    engine_pool = get_engine_pool()

    # Engines other than TTS have no warm-up routine here.
    if "tts" not in engine_and_type:
        return True

    tts_engine = engine_pool['tts']
    flag_online = False
    if tts_engine.lang == 'zh':
        sentence = "您好,欢迎使用语音合成服务。"
    elif tts_engine.lang == 'en':
        sentence = "Hello and welcome to the speech synthesis service."
    else:
        logger.error("tts engine only support lang: zh or en.")
        # Report the failure through the return value: this module does not
        # import `sys`, so calling sys.exit() here would raise NameError.
        return False

    # The handler class is imported lazily so that only the module of the
    # selected engine type is loaded.
    if engine_and_type == "tts_python":
        from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
    elif engine_and_type == "tts_inference":
        from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
    elif engine_and_type == "tts_online":
        from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
        flag_online = True
    elif engine_and_type == "tts_online-onnx":
        from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
        flag_online = True
    else:
        logger.error("Please check tte engine type.")
        # Without a handler class there is nothing to run; stop here instead
        # of failing later on an unbound name.
        return False

    try:
        logger.info("Start to warm up tts engine.")
        for i in range(warm_up_time):
            connection_handler = PaddleTTSConnectionHandler(tts_engine)
            if flag_online:
                # Streaming engines: consume only the first yielded chunk,
                # log the first response time, then stop iterating.
                for _ in connection_handler.infer(
                        text=sentence,
                        lang=tts_engine.lang,
                        am=tts_engine.config.am):
                    logger.info(
                        f"The first response time of the {i} warm up: {connection_handler.first_response_time} s"
                    )
                    break
            else:
                st = time.time()
                connection_handler.infer(text=sentence)
                et = time.time()
                logger.info(
                    f"The response time of the {i} warm up: {et - st} s")
    except Exception as e:
        logger.error("Failed to warm up on tts engine.")
        logger.error(e)
        return False

    return True

@ -31,18 +31,12 @@ from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample):
def __init__(self):
super().__init__()
self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
self.voc_upsample = voc_upsample
self.pretrained_models = pretrained_models
def _init_from_path(
@ -161,6 +155,115 @@ class TTSServerExecutor(TTSExecutor):
self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!")
class TTSEngine(BaseEngine):
    """TTS server engine (engine type "online-onnx").

    Holds the executor, the configuration and the chunking parameters that
    connection handlers read from the engine.
    """

    def __init__(self, name=None):
        """Initialize TTS server engine

        Args:
            name: Not used by this class.
        """
        super().__init__()

    def init(self, config: dict) -> bool:
        """Validate the config, select the device and load the models.

        Args:
            config (dict): Engine configuration; read via attribute access
                (am, voc, am_block, am_pad, voc_block, voc_pad,
                voc_upsample, sample rates, session configs, model paths,
                lang).

        Returns:
            bool: True on success; False when setting the device or
                initializing the executor raised an exception.

        Raises:
            AssertionError: If am/voc names, voc_block/voc_pad or the
                sample rates fail the checks below.
        """
        self.executor = TTSServerExecutor()
        self.config = config
        self.lang = self.config.lang
        self.engine_type = "online-onnx"

        self.am_block = self.config.am_block
        self.am_pad = self.config.am_pad
        self.voc_block = self.config.voc_block
        self.voc_pad = self.config.voc_pad
        self.am_upsample = 1
        self.voc_upsample = self.config.voc_upsample

        assert (
            self.config.am == "fastspeech2_csmsc_onnx" or
            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
        ) and (
            self.config.voc == "hifigan_csmsc_onnx" or
            self.config.voc == "mb_melgan_csmsc_onnx"
        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'

        assert (
            self.config.voc_block > 0 and self.config.voc_pad > 0
        ), "Please set correct voc_block and voc_pad, they should be more than 0."

        assert (
            self.config.voc_sample_rate == self.config.am_sample_rate
        ), "The sample rate of AM and Vocoder model are different, please check model."

        try:
            # Device priority: AM session config, then vocoder session
            # config, then whatever paddle currently reports.
            if self.config.am_sess_conf.device is not None:
                self.device = self.config.am_sess_conf.device
            elif self.config.voc_sess_conf.device is not None:
                self.device = self.config.voc_sess_conf.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except Exception as e:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            # self.device may not have been assigned if reading the config
            # raised, so fall back to None instead of failing in the handler.
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (getattr(self, "device", None)))
            logger.error(e)
            return False

        try:
            self.executor._init_from_path(
                am=self.config.am,
                am_ckpt=self.config.am_ckpt,
                am_stat=self.config.am_stat,
                phones_dict=self.config.phones_dict,
                tones_dict=self.config.tones_dict,
                speaker_dict=self.config.speaker_dict,
                am_sample_rate=self.config.am_sample_rate,
                am_sess_conf=self.config.am_sess_conf,
                voc=self.config.voc,
                voc_ckpt=self.config.voc_ckpt,
                voc_sample_rate=self.config.voc_sample_rate,
                voc_sess_conf=self.config.voc_sess_conf,
                lang=self.config.lang)
        except Exception as e:
            logger.error("Failed to get model related files.")
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.config.voc_sess_conf.device))
            # Log the exception with logger.error, like every other handler
            # in this class, instead of calling the logger object directly.
            logger.error(e)
            return False

        logger.info("Initialize TTS server engine successfully on device: %s." %
                    (self.config.voc_sess_conf.device))
        return True
class PaddleTTSConnectionHandler:
def __init__(self, tts_engine):
    """Create the PaddleSpeech TTS server connection handler.

    The handler is used to process TTS server requests. It keeps
    references to the engine, its executor and its config, and takes the
    chunking parameters from the engine.

    Args:
        tts_engine (TTSEngine): The TTS engine
    """
    super().__init__()
    logger.info(
        "Create PaddleTTSConnectionHandler to process the tts request")

    self.tts_engine = tts_engine
    self.executor = tts_engine.executor
    self.config = tts_engine.config

    # Block / pad / upsample settings are read from the engine.
    for setting in ("am_block", "am_pad", "voc_block", "voc_pad",
                    "am_upsample", "voc_upsample"):
        setattr(self, setting, getattr(tts_engine, setting))
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
@ -189,12 +292,6 @@ class TTSServerExecutor(TTSExecutor):
Model inference and result stored in self.output.
"""
am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_upsample
# first_flag 用于标记首包
first_flag = 1
get_tone_ids = False
@ -203,7 +300,7 @@ class TTSServerExecutor(TTSExecutor):
# front
frontend_st = time.time()
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
@ -211,7 +308,7 @@ class TTSServerExecutor(TTSExecutor):
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
@ -226,7 +323,7 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_csmsc
if am == "fastspeech2_csmsc_onnx":
# am
mel = self.am_sess.run(
mel = self.executor.am_sess.run(
output_names=None, input_feed={'text': part_phone_ids})
mel = mel[0]
if first_flag == 1:
@ -234,14 +331,16 @@ class TTSServerExecutor(TTSExecutor):
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
"voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_sess.run(
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': mel_chunk})
sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
self.voc_block, self.voc_pad,
self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
@ -253,7 +352,7 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc_onnx":
# am
orig_hs = self.am_encoder_infer_sess.run(
orig_hs = self.executor.am_encoder_infer_sess.run(
None, input_feed={'text': part_phone_ids})
orig_hs = orig_hs[0]
@ -267,9 +366,9 @@ class TTSServerExecutor(TTSExecutor):
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = self.am_decoder_sess.run(
am_decoder_output = self.executor.am_decoder_sess.run(
None, input_feed={'xs': hs})
am_postnet_output = self.am_postnet_sess.run(
am_postnet_output = self.executor.am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
@ -278,9 +377,11 @@ class TTSServerExecutor(TTSExecutor):
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
sub_mel = denorm(normalized_mel, self.executor.am_mu,
self.executor.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i,
self.am_block, self.am_pad,
self.am_upsample)
if i == 0:
mel_streaming = sub_mel
@ -297,11 +398,11 @@ class TTSServerExecutor(TTSExecutor):
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
sub_wav = self.voc_sess.run(
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': voc_chunk})
sub_wav = self.depadding(sub_wav[0], voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
sub_wav = self.depadding(
sub_wav[0], voc_chunk_num, voc_chunk_id,
self.voc_block, self.voc_pad, self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
@ -311,9 +412,11 @@ class TTSServerExecutor(TTSExecutor):
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
start = max(
0, voc_chunk_id * self.voc_block - self.voc_pad)
end = min(
(voc_chunk_id + 1) * self.voc_block + self.voc_pad,
mel_len)
else:
logger.error(
@ -322,111 +425,6 @@ class TTSServerExecutor(TTSExecutor):
self.final_response_time = time.time() - frontend_st
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super().__init__()
def init(self, config: dict) -> bool:
self.config = config
assert (
self.config.am == "fastspeech2_csmsc_onnx" or
self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
) and (
self.config.voc == "hifigan_csmsc_onnx" or
self.config.voc == "mb_melgan_csmsc_onnx"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
self.config.voc_block > 0 and self.config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
assert (
self.config.voc_sample_rate == self.config.am_sample_rate
), "The sample rate of AM and Vocoder model are different, please check model."
self.executor = TTSServerExecutor(
self.config.am_block, self.config.am_pad, self.config.voc_block,
self.config.voc_pad, self.config.voc_upsample)
try:
if self.config.am_sess_conf.device is not None:
self.device = self.config.am_sess_conf.device
elif self.config.voc_sess_conf.device is not None:
self.device = self.config.voc_sess_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
try:
self.executor._init_from_path(
am=self.config.am,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
am_sample_rate=self.config.am_sample_rate,
am_sess_conf=self.config.am_sess_conf,
voc=self.config.voc,
voc_ckpt=self.config.voc_ckpt,
voc_sample_rate=self.config.voc_sample_rate,
voc_sess_conf=self.config.voc_sess_conf,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.config.voc_sess_conf.device))
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.config.voc_sess_conf.device))
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
@ -459,7 +457,7 @@ class TTSEngine(BaseEngine):
"""
wav_list = []
for wav in self.executor.infer(
for wav in self.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
@ -477,11 +475,9 @@ class TTSEngine(BaseEngine):
duration = len(wav_all) / self.config.voc_sample_rate
logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(f"first response time: {self.first_response_time} s")
logger.info(f"final response time: {self.final_response_time} s")
logger.info(f"RTF: {self.final_response_time / duration}")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
)

@ -34,16 +34,12 @@ from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
def __init__(self, am_block, am_pad, voc_block, voc_pad):
def __init__(self):
super().__init__()
self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
self.pretrained_models = pretrained_models
def get_model_info(self,
@ -205,6 +201,106 @@ class TTSServerExecutor(TTSExecutor):
self.voc_inference.eval()
print("voc done!")
class TTSEngine(BaseEngine):
    """TTS server engine (engine type "online").

    Holds the executor, the configuration and the chunking parameters that
    connection handlers read from the engine.
    """

    def __init__(self, name=None):
        """Initialize TTS server engine

        Args:
            name: Not used by this class.
        """
        super().__init__()

    def init(self, config: dict) -> bool:
        """Validate the config, select the device and load the models.

        Args:
            config (dict): Engine configuration; read via attribute access
                (am, voc, am_block, am_pad, voc_block, voc_pad, device,
                model/config/stat paths, lang).

        Returns:
            bool: True on success; False when setting the device or
                initializing the executor raised an exception.

        Raises:
            AssertionError: If am/voc names or voc_block/voc_pad fail the
                checks below.
        """
        self.executor = TTSServerExecutor()
        self.config = config
        self.lang = self.config.lang
        self.engine_type = "online"

        assert (
            config.am == "fastspeech2_csmsc" or
            config.am == "fastspeech2_cnndecoder_csmsc"
        ) and (
            config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'

        assert (
            config.voc_block > 0 and config.voc_pad > 0
        ), "Please set correct voc_block and voc_pad, they should be more than 0."

        try:
            # Use the configured device when given, otherwise whatever
            # paddle currently reports.
            if self.config.device is not None:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except Exception as e:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            # self.device may not have been assigned if reading the config
            # raised, so fall back to None instead of failing in the handler.
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (getattr(self, "device", None)))
            logger.error(e)
            return False

        try:
            self.executor._init_from_path(
                am=self.config.am,
                am_config=self.config.am_config,
                am_ckpt=self.config.am_ckpt,
                am_stat=self.config.am_stat,
                phones_dict=self.config.phones_dict,
                tones_dict=self.config.tones_dict,
                speaker_dict=self.config.speaker_dict,
                voc=self.config.voc,
                voc_config=self.config.voc_config,
                voc_ckpt=self.config.voc_ckpt,
                voc_stat=self.config.voc_stat,
                lang=self.config.lang)
        except Exception as e:
            logger.error("Failed to get model related files.")
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.device))
            logger.error(e)
            return False

        self.am_block = self.config.am_block
        self.am_pad = self.config.am_pad
        self.voc_block = self.config.voc_block
        self.voc_pad = self.config.voc_pad
        self.am_upsample = 1
        # Read after _init_from_path because the value comes from the
        # executor's voc_config.
        self.voc_upsample = self.executor.voc_config.n_shift

        logger.info("Initialize TTS server engine successfully on device: %s." %
                    (self.device))
        return True
class PaddleTTSConnectionHandler:
def __init__(self, tts_engine):
    """Bind a new connection handler to an initialised TTS engine.

    The handler keeps a reference to the engine and mirrors the engine's
    executor, config and streaming chunk settings onto itself.

    Args:
        tts_engine (TTSEngine): The TTS engine
    """
    super().__init__()
    logger.info(
        "Create PaddleTTSConnectionHandler to process the tts request")

    self.tts_engine = tts_engine
    # Copied in this fixed order so a missing engine attribute fails at the
    # same point it always did.
    mirrored = ("executor", "config", "am_block", "am_pad", "voc_block",
                "voc_pad", "am_upsample", "voc_upsample")
    for attr_name in mirrored:
        setattr(self, attr_name, getattr(self.tts_engine, attr_name))
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
@ -233,12 +329,6 @@ class TTSServerExecutor(TTSExecutor):
Model inference and result stored in self.output.
"""
am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_config.n_shift
# first_flag 用于标记首包
first_flag = 1
@ -246,7 +336,7 @@ class TTSServerExecutor(TTSExecutor):
merge_sentences = False
frontend_st = time.time()
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
@ -254,7 +344,7 @@ class TTSServerExecutor(TTSExecutor):
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
@ -269,19 +359,21 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_csmsc
if am == "fastspeech2_csmsc":
# am
mel = self.am_inference(part_phone_ids)
mel = self.executor.am_inference(part_phone_ids)
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
"voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
sub_wav = self.executor.voc_inference(mel_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
self.voc_block, self.voc_pad,
self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
@ -293,7 +385,8 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc":
# am
orig_hs = self.am_inference.encoder_infer(part_phone_ids)
orig_hs = self.executor.am_inference.encoder_infer(
part_phone_ids)
# streaming voc chunk info
mel_len = orig_hs.shape[1]
@ -305,13 +398,15 @@ class TTSServerExecutor(TTSExecutor):
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
before_outs = self.am_inference.decoder(hs)
after_outs = before_outs + self.am_inference.postnet(
before_outs = self.executor.am_inference.decoder(hs)
after_outs = before_outs + self.executor.am_inference.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
normalized_mel = after_outs[0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
sub_mel = denorm(normalized_mel, self.executor.am_mu,
self.executor.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i,
self.am_block, self.am_pad,
self.am_upsample)
if i == 0:
mel_streaming = sub_mel
@ -328,11 +423,11 @@ class TTSServerExecutor(TTSExecutor):
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
voc_chunk = paddle.to_tensor(voc_chunk)
sub_wav = self.voc_inference(voc_chunk)
sub_wav = self.executor.voc_inference(voc_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
sub_wav = self.depadding(
sub_wav, voc_chunk_num, voc_chunk_id,
self.voc_block, self.voc_pad, self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
@ -342,9 +437,11 @@ class TTSServerExecutor(TTSExecutor):
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
start = max(
0, voc_chunk_id * self.voc_block - self.voc_pad)
end = min(
(voc_chunk_id + 1) * self.voc_block + self.voc_pad,
mel_len)
else:
logger.error(
@ -353,100 +450,6 @@ class TTSServerExecutor(TTSExecutor):
self.final_response_time = time.time() - frontend_st
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super().__init__()
def init(self, config: dict) -> bool:
self.config = config
assert (
config.am == "fastspeech2_csmsc" or
config.am == "fastspeech2_cnndecoder_csmsc"
) and (
config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
config.voc_block > 0 and config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
try:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
self.executor = TTSServerExecutor(config.am_block, config.am_pad,
config.voc_block, config.voc_pad)
try:
self.executor._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_config=self.config.voc_config,
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
@ -480,7 +483,7 @@ class TTSEngine(BaseEngine):
wav_list = []
for wav in self.executor.infer(
for wav in self.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
@ -496,13 +499,12 @@ class TTSEngine(BaseEngine):
wav_all = np.concatenate(wav_list, axis=0)
duration = len(wav_all) / self.executor.am_config.fs
logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(f"first response time: {self.first_response_time} s")
logger.info(f"final response time: {self.final_response_time} s")
logger.info(f"RTF: {self.final_response_time / duration}")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
)
f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
)

@ -14,6 +14,7 @@
import base64
import io
import os
import sys
import time
from typing import Optional
@ -35,7 +36,7 @@ from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
@ -245,7 +246,7 @@ class TTSServerExecutor(TTSExecutor):
else:
wav_all = paddle.concat([wav_all, wav])
self.voc_time += (time.time() - voc_st)
self._outputs['wav'] = wav_all
self._outputs["wav"] = wav_all
class TTSEngine(BaseEngine):
@ -263,6 +264,8 @@ class TTSEngine(BaseEngine):
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
self.engine_type = "inference"
try:
if self.config.am_predictor_conf.device is not None:
@ -272,58 +275,59 @@ class TTSEngine(BaseEngine):
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
except Exception as e:
logger.error("Failed to warm up on tts engine.")
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
logger.info("Initialize TTS server engine successfully.")
return True
def warm_up(self):
"""warm up
class PaddleTTSConnectionHandler(TTSServerExecutor):
def __init__(self, tts_engine):
"""The PaddleSpeech TTS Server Connection Handler
This connection process every tts server request
Args:
tts_engine (TTSEngine): The TTS engine
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
st = time.time()
self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s")
super().__init__()
logger.info(
"Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine
self.executor = self.tts_engine.executor
self.config = self.tts_engine.config
self.frontend = self.executor.frontend
self.am_predictor = self.executor.am_predictor
self.voc_predictor = self.executor.voc_predictor
def postprocess(self,
wav,
@ -375,8 +379,11 @@ class TTSEngine(BaseEngine):
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("Failed to transform speed.")
logger.error(e)
sys.exit(-1)
# wav to base64
buf = io.BytesIO()
@ -433,7 +440,7 @@ class TTSEngine(BaseEngine):
try:
infer_st = time.time()
self.executor.infer(
self.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
infer_et = time.time()
infer_time = infer_et - infer_st
@ -441,13 +448,16 @@ class TTSEngine(BaseEngine):
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts infer failed.")
logger.error(e)
sys.exit(-1)
try:
postprocess_st = time.time()
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
wav=self._outputs["wav"].numpy(),
original_fs=self.executor.am_sample_rate,
target_fs=sample_rate,
volume=volume,
@ -455,26 +465,28 @@ class TTSEngine(BaseEngine):
audio_path=save_path)
postprocess_et = time.time()
postprocess_time = postprocess_et - postprocess_st
duration = len(self.executor._outputs['wav']
.numpy()) / self.executor.am_sample_rate
duration = len(
self._outputs["wav"].numpy()) / self.executor.am_sample_rate
rtf = infer_time / duration
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts postprocess failed.")
logger.error(e)
sys.exit(-1)
logger.info("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc))
logger.info("Language: {}".format(lang))
logger.info("tts engine type: paddle inference")
logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration))
logger.info(
"frontend inference time: {}".format(self.executor.frontend_time))
logger.info("AM inference time: {}".format(self.executor.am_time))
logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
logger.info("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time))
logger.info(
"postprocess (change speed, volume, target sample rate) time: {}".
@ -482,5 +494,6 @@ class TTSEngine(BaseEngine):
logger.info("total generate audio time: {}".format(infer_time +
postprocess_time))
logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64

@ -13,6 +13,7 @@
# limitations under the License.
import base64
import io
import sys
import time
import librosa
@ -28,7 +29,7 @@ from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
@ -52,6 +53,8 @@ class TTSEngine(BaseEngine):
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
self.engine_type = "python"
try:
if self.config.device is not None:
@ -59,12 +62,13 @@ class TTSEngine(BaseEngine):
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
try:
@ -81,41 +85,35 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except BaseException:
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
logger.error(e)
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
def warm_up(self):
"""warm up
class PaddleTTSConnectionHandler(TTSServerExecutor):
def __init__(self, tts_engine):
"""The PaddleSpeech TTS Server Connection Handler
This connection process every tts server request
Args:
tts_engine (TTSEngine): The TTS engine
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
st = time.time()
self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s")
super().__init__()
logger.info(
"Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine
self.executor = self.tts_engine.executor
self.config = self.tts_engine.config
self.frontend = self.executor.frontend
self.am_inference = self.executor.am_inference
self.voc_inference = self.executor.voc_inference
def postprocess(self,
wav,
@ -167,8 +165,11 @@ class TTSEngine(BaseEngine):
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("Failed to transform speed.")
logger.error(e)
sys.exit(-1)
# wav to base64
buf = io.BytesIO()
@ -225,24 +226,27 @@ class TTSEngine(BaseEngine):
try:
infer_st = time.time()
self.executor.infer(
self.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
infer_et = time.time()
infer_time = infer_et - infer_st
duration = len(self.executor._outputs['wav']
.numpy()) / self.executor.am_config.fs
duration = len(
self._outputs["wav"].numpy()) / self.executor.am_config.fs
rtf = infer_time / duration
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts infer failed.")
logger.error(e)
sys.exit(-1)
try:
postprocess_st = time.time()
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
wav=self._outputs["wav"].numpy(),
original_fs=self.executor.am_config.fs,
target_fs=sample_rate,
volume=volume,
@ -254,8 +258,11 @@ class TTSEngine(BaseEngine):
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts postprocess failed.")
logger.error(e)
sys.exit(-1)
logger.info("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc))
@ -263,10 +270,9 @@ class TTSEngine(BaseEngine):
logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration))
logger.info(
"frontend inference time: {}".format(self.executor.frontend_time))
logger.info("AM inference time: {}".format(self.executor.am_time))
logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
logger.info("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time))
logger.info(
"postprocess (change speed, volume, target sample rate) time: {}".
@ -274,6 +280,6 @@ class TTSEngine(BaseEngine):
logger.info("total generate audio time: {}".format(infer_time +
postprocess_time))
logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.device))
logger.info("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64

@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import sys
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import ASRRequest
from paddlespeech.server.restful.response import ASRResponse
@ -68,8 +70,18 @@ def asr(request_body: ASRRequest):
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_engine.run(audio_data)
asr_results = asr_engine.postprocess()
if asr_engine.engine_type == "python":
from paddlespeech.server.engine.asr.python.asr_engine import PaddleASRConnectionHandler
elif asr_engine.engine_type == "inference":
from paddlespeech.server.engine.asr.paddleinference.asr_engine import PaddleASRConnectionHandler
else:
logger.error("Offline asr engine only support python or inference.")
sys.exit(-1)
connection_handler = PaddleASRConnectionHandler(asr_engine)
connection_handler.run(audio_data)
asr_results = connection_handler.postprocess()
response = {
"success": True,

@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import sys
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import CLSRequest
from paddlespeech.server.restful.response import CLSResponse
@ -68,8 +70,18 @@ def cls(request_body: CLSRequest):
engine_pool = get_engine_pool()
cls_engine = engine_pool['cls']
cls_engine.run(audio_data)
cls_results = cls_engine.postprocess(request_body.topk)
if cls_engine.engine_type == "python":
from paddlespeech.server.engine.cls.python.cls_engine import PaddleCLSConnectionHandler
elif cls_engine.engine_type == "inference":
from paddlespeech.server.engine.cls.paddleinference.cls_engine import PaddleCLSConnectionHandler
else:
logger.error("Offline cls engine only support python or inference.")
sys.exit(-1)
connection_handler = PaddleCLSConnectionHandler(cls_engine)
connection_handler.run(audio_data)
cls_results = connection_handler.postprocess(request_body.topk)
response = {
"success": True,
@ -85,8 +97,11 @@ def cls(request_body: CLSRequest):
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
logger.error(e)
sys.exit(-1)
except Exception as e:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
logger.error(e)
traceback.print_exc()
return response

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from typing import Union
@ -99,7 +100,16 @@ def tts(request_body: TTSRequest):
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
if tts_engine.engine_type == "python":
from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
elif tts_engine.engine_type == "inference":
from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
else:
logger.error("Offline tts engine only support python or inference.")
sys.exit(-1)
connection_handler = PaddleTTSConnectionHandler(tts_engine)
lang, target_sample_rate, duration, wav_base64 = connection_handler.run(
text, spk_id, speed, volume, sample_rate, save_path)
response = {
@ -136,4 +146,14 @@ async def stream_tts(request_body: TTSRequest):
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
return StreamingResponse(tts_engine.run(sentence=text))
if tts_engine.engine_type == "online":
from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
elif tts_engine.engine_type == "online-onnx":
from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
else:
logger.error("Online tts engine only support online or online-onnx.")
sys.exit(-1)
connection_handler = PaddleTTSConnectionHandler(tts_engine)
return StreamingResponse(connection_handler.run(sentence=text))

@ -40,6 +40,16 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
connection_handler = None
if tts_engine.engine_type == "online":
from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
elif tts_engine.engine_type == "online-onnx":
from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
else:
logger.error("Online tts engine only support online or online-onnx.")
sys.exit(-1)
try:
while True:
# careful here, changed the source code from starlette.websockets
@ -57,10 +67,13 @@ async def websocket_endpoint(websocket: WebSocket):
"signal": "server ready",
"session": session
}
connection_handler = PaddleTTSConnectionHandler(tts_engine)
await websocket.send_json(resp)
# end request
elif message['signal'] == 'end':
connection_handler = None
resp = {
"status": 0,
"signal": "connection will be closed",
@ -75,10 +88,11 @@ async def websocket_endpoint(websocket: WebSocket):
# speech synthesis request
elif 'text' in message:
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
sentence = connection_handler.preprocess(
text_bese64=text_bese64)
# run
wav_generator = tts_engine.run(sentence)
wav_generator = connection_handler.run(sentence)
while True:
try:

Loading…
Cancel
Save