update tts engine, test=doc

pull/1955/head
lym0302 3 years ago
parent 7e987a6bcd
commit 14f9000f79

@ -0,0 +1,70 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import warnings
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
warnings.filterwarnings("ignore")
import sys
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"])
# change yaml file here
config_file = "./conf/application.yaml"
config = get_config(config_file)
# init engine
if not init_engine_pool(config):
print("Failed to init engine.")
sys.exit(-1)
# get api_router
api_list = list(engine.split("_")[0] for engine in config.engine_list)
if config.protocol == "websocket":
api_router = setup_ws_router(api_list)
elif config.protocol == "http":
api_router = setup_http_router(api_list)
else:
raise Exception("unsupported protocol")
# app needs to operate outside the main function
app.include_router(api_router)
if __name__ == "__main__":
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
"--workers", type=int, help="workers of server", default=1)
args = parser.parse_args()
uvicorn.run(
"start_multi_progress_server:app",
host=config.host,
port=config.port,
debug=True,
workers=args.workers)
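For context, the fields this script reads from ./conf/application.yaml are config.host, config.port, config.protocol and config.engine_list; the values below are illustrative only and are not part of this commit:
# ./conf/application.yaml (illustrative values, not from this diff)
#   host: 0.0.0.0
#   port: 8090
#   protocol: http                  # or "websocket"
#   engine_list: ['tts_inference']  # "<task>_<engine type>"
#
# With such a config in place, the multi-worker server would be launched with, e.g.:
#   python start_multi_progress_server.py --workers 4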

@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_factory import EngineFactory
# global value
@ -24,6 +28,50 @@ def get_engine_pool() -> dict:
return ENGINE_POOL
def warm_up(engine_and_type: str, engine, warm_up_time: int=3) -> bool:
if "tts" in engine_and_type:
if engine.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
elif engine.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
else:
logger.error("tts engine only support lang: zh or en.")
sys.exit(-1)
if engine_and_type == "tts_python":
from paddlespeech.server.engine.tts.python.tts_engine import TTSHandler
elif engine_and_type == "tts_inference":
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSHandler
elif engine_and_type == "tts_online":
pass
elif engine_and_type == "tts_online-onnx":
pass
else:
logger.error("Please check tte engine type.")
try:
logger.info("Start to warm up tts engine.")
for i in range(warm_up_time):
tts_handler = TTSHandler(engine)
st = time.time()
tts_handler.infer(
text=sentence,
lang=engine.config.lang,
am=engine.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s"
)
except Exception as e:
logger.error("Failed to warm up on tts engine.")
logger.error(e)
return False
else:
pass
return True
def init_engine_pool(config) -> bool:
""" Init engine pool
"""
@ -38,4 +86,7 @@ def init_engine_pool(config) -> bool:
if not ENGINE_POOL[engine].init(config=config[engine_and_type]):
return False
if not warm_up(engine_and_type, ENGINE_POOL[engine]):
return False
return True
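A hedged usage sketch of how these helpers fit together (module paths and names are taken from this diff; the config path is just an example):
from paddlespeech.server.engine.engine_pool import get_engine_pool, init_engine_pool
from paddlespeech.server.utils.config import get_config

config = get_config("./conf/application.yaml")  # example path
if init_engine_pool(config):                    # builds ENGINE_POOL and warms up each tts engine
    engine_pool = get_engine_pool()
    tts_engine = engine_pool['tts']             # keyed by the task prefix of "tts_*"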

@ -85,3 +85,40 @@ pretrained_models = {
24000,
},
}
model_alias = {
# acoustic model
"speedyspeech":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
"speedyspeech_inference":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
"tacotron2":
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
"wavernn":
"paddlespeech.t2s.models.wavernn:WaveRNN",
"wavernn_inference":
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
}
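For reference, a short sketch of how these aliases are resolved by dynamic_import in _init_from_path further down in this diff; the model tag is an example:
from paddlespeech.s2t.utils.dynamic_import import dynamic_import

am = 'fastspeech2_csmsc'                  # example tag: {model_name}_{dataset}
am_name = am[:am.rindex('_')]             # -> 'fastspeech2'
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)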

@ -23,9 +23,11 @@ import paddle
import soundfile as sf
from scipy.io import wavfile
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
@ -35,13 +37,72 @@ from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'TTSHandler']
class TTSServerExecutor(TTSExecutor):
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super().__init__()
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.engine_type = "inference"
def init(self, config: dict) -> bool:
self.config = config
self.lang = self.config.lang
try:
if self.config.am_predictor_conf.device is not None:
self.device = self.config.am_predictor_conf.device
elif self.config.voc_predictor_conf.device is not None:
self.device = self.config.voc_predictor_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
try:
self._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
logger.info(
"Initialize TTS server engine successfully on device: %s." %
(self.device))
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
return True
def _init_from_path(
self,
@ -176,6 +237,31 @@ class TTSServerExecutor(TTSExecutor):
predictor_conf=self.voc_predictor_conf)
logger.info("Create Vocoder predictor successfully.")
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and return the pretrained resource path for the current task.
"""
support_models = list(self.pretrained_models.keys())
assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(self.pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
class TTSHandler():
def __init__(self, tts_engine):
"""Initialize TTS server engine
"""
super(TTSHandler, self).__init__()
self.tts_engine = tts_engine
@paddle.no_grad()
def infer(self,
text: str,
@ -189,11 +275,12 @@ class TTSServerExecutor(TTSExecutor):
am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
merge_sentences = False
frontend_st = time.time()
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
input_ids = self.tts_engine.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
@ -201,7 +288,7 @@ class TTSServerExecutor(TTSExecutor):
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
input_ids = self.tts_engine.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
@ -218,7 +305,7 @@ class TTSServerExecutor(TTSExecutor):
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
am_result = run_model(
self.am_predictor,
self.tts_engine.am_predictor,
[part_phone_ids.numpy(), part_tone_ids.numpy()])
mel = am_result[0]
@ -228,14 +315,14 @@ class TTSServerExecutor(TTSExecutor):
if am_dataset in {"aishell3", "vctk"}:
pass
else:
am_result = run_model(self.am_predictor,
am_result = run_model(self.tts_engine.am_predictor,
[part_phone_ids.numpy()])
mel = am_result[0]
self.am_time += (time.time() - am_st)
# voc
voc_st = time.time()
voc_result = run_model(self.voc_predictor, [mel])
voc_result = run_model(self.tts_engine.voc_predictor, [mel])
wav = voc_result[0]
wav = paddle.to_tensor(wav)
@ -245,85 +332,7 @@ class TTSServerExecutor(TTSExecutor):
else:
wav_all = paddle.concat([wav_all, wav])
self.voc_time += (time.time() - voc_st)
self._outputs['wav'] = wav_all
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
try:
if self.config.am_predictor_conf.device is not None:
self.device = self.config.am_predictor_conf.device
elif self.config.voc_predictor_conf.device is not None:
self.device = self.config.voc_predictor_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
logger.info("Initialize TTS server engine successfully.")
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
st = time.time()
self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s")
self.output = wav_all
def postprocess(self,
wav,
@ -375,8 +384,9 @@ class TTSEngine(BaseEngine):
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
except Exception as e:
logger.error("Failed to transform speed.")
logger.error(e)
# wav to base64
buf = io.BytesIO()
@ -429,52 +439,59 @@ class TTSEngine(BaseEngine):
wav_base64: The base64 format of the synthesized audio.
"""
lang = self.config.lang
lang = self.tts_engine.lang
try:
infer_st = time.time()
self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
self.infer(
text=sentence,
lang=lang,
am=self.tts_engine.config.am,
spk_id=spk_id)
infer_et = time.time()
infer_time = infer_et - infer_st
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts infer failed.")
logger.error(e)
sys.exit(-1)
try:
postprocess_st = time.time()
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
original_fs=self.executor.am_sample_rate,
wav=self.output.numpy(),
original_fs=self.tts_engine.am_sample_rate,
target_fs=sample_rate,
volume=volume,
speed=speed,
audio_path=save_path)
postprocess_et = time.time()
postprocess_time = postprocess_et - postprocess_st
duration = len(self.executor._outputs['wav']
.numpy()) / self.executor.am_sample_rate
duration = len(self.output.numpy()) / self.tts_engine.am_sample_rate
rtf = infer_time / duration
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts postprocess failed.")
logger.error(e)
sys.exit(-1)
logger.info("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc))
logger.info("AM model: {}".format(self.tts_engine.config.am))
logger.info("Vocoder model: {}".format(self.tts_engine.config.voc))
logger.info("Language: {}".format(lang))
logger.info("tts engine type: paddle inference")
logger.info("audio duration: {}".format(duration))
logger.info(
"frontend inference time: {}".format(self.executor.frontend_time))
logger.info("AM inference time: {}".format(self.executor.am_time))
logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
logger.info("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time))
logger.info(
"postprocess (change speed, volume, target sample rate) time: {}".

@ -13,32 +13,37 @@
# limitations under the License.
import base64
import io
import os
import time
from typing import Optional
import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from scipy.io import wavfile
from yacs.config import CfgNode
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
super().__init__()
pass
__all__ = ['TTSEngine', 'TTSHandler']
class TTSEngine(BaseEngine):
"""TTS server engine
"""TTS server engine for model setting
Args:
metaclass: Defaults to Singleton.
@ -48,10 +53,13 @@ class TTSEngine(BaseEngine):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.engine_type = "python"
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
try:
if self.config.device is not None:
@ -59,16 +67,17 @@ class TTSEngine(BaseEngine):
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
try:
self.executor._init_from_path(
self._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
@ -81,41 +90,278 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except BaseException:
logger.info(
"Initialize TTS server engine successfully on device: %s." %
(self.device))
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
def warm_up(self):
"""warm up
def _init_from_path(
self,
am: str='fastspeech2_csmsc',
am_config: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None,
lang: str='zh', ):
"""
Init model and other resources from a specific path.
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
st = time.time()
self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s")
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.')
return
# am
am_tag = am + '-' + lang
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(
am_res_path, self.pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
self.pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
logger.info(am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt)
self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
# for speedyspeech
self.tones_dict = None
if 'tones_dict' in self.pretrained_models[am_tag]:
self.tones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['tones_dict'])
if tones_dict:
self.tones_dict = tones_dict
# for multi speaker fastspeech2
self.speaker_dict = None
if 'speaker_dict' in self.pretrained_models[am_tag]:
self.speaker_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speaker_dict'])
if speaker_dict:
self.speaker_dict = speaker_dict
# voc
voc_tag = voc + '-' + lang
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_config = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_stat = os.path.abspath(voc_stat)
self.voc_res_path = os.path.dirname(
os.path.abspath(self.voc_config))
# Init body.
with open(self.am_config) as f:
self.am_config = CfgNode(yaml.safe_load(f))
with open(self.voc_config) as f:
self.voc_config = CfgNode(yaml.safe_load(f))
with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
logger.info(f"vocab_size: {vocab_size}")
tone_size = None
if self.tones_dict:
with open(self.tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
logger.info(f"tone_size: {tone_size}")
spk_num = None
if self.speaker_dict:
with open(self.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
logger.info(f"spk_num: {spk_num}")
# frontend
if lang == 'zh':
self.frontend = Frontend(
phone_vocab_path=self.phones_dict,
tone_vocab_path=self.tones_dict)
elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!")
# acoustic model
odim = self.am_config.n_mels
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_class = dynamic_import(am_name, self.model_alias)
am_inference_class = dynamic_import(am_name + '_inference',
self.model_alias)
if am_name == 'fastspeech2':
am = am_class(
idim=vocab_size,
odim=odim,
spk_num=spk_num,
**self.am_config["model"])
elif am_name == 'speedyspeech':
am = am_class(
vocab_size=vocab_size,
tone_size=tone_size,
**self.am_config["model"])
elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"])
am.set_state_dict(paddle.load(self.am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(self.am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval()
logger.info("acoustic model done!")
# vocoder
# model: {model_name}_{dataset}
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, self.model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference',
self.model_alias)
if voc_name != 'wavernn':
voc = voc_class(**self.voc_config["generator_params"])
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
voc.remove_weight_norm()
voc.eval()
else:
voc = voc_class(**self.voc_config["model"])
voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"])
voc.eval()
voc_mu, voc_std = np.load(self.voc_stat)
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std)
self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval()
logger.info("voc done!")
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and return the pretrained resource path for the current task.
"""
support_models = list(self.pretrained_models.keys())
assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(self.pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
class TTSHandler():
def __init__(self, tts_engine):
"""Initialize TTS server engine
"""
super(TTSHandler, self).__init__()
self.tts_engine = tts_engine
@paddle.no_grad()
def infer(self,
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
spk_id: int=0):
"""
Model inference and result stored in self.output.
"""
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
merge_sentences = False
frontend_st = time.time()
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.tts_engine.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.tts_engine.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
logger.info("lang should in {'zh', 'en'}!")
self.frontend_time = time.time() - frontend_st
self.am_time = 0
self.voc_time = 0
flags = 0
for i in range(len(phone_ids)):
am_st = time.time()
part_phone_ids = phone_ids[i]
# am
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = self.tts_engine.am_inference(part_phone_ids,
part_tone_ids)
# fastspeech2
else:
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.tts_engine.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.tts_engine.am_inference(part_phone_ids)
self.am_time += (time.time() - am_st)
# voc
voc_st = time.time()
wav = self.tts_engine.voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = paddle.concat([wav_all, wav])
self.voc_time += (time.time() - voc_st)
self.output = wav_all
def postprocess(self,
wav,
@ -167,8 +413,9 @@ class TTSEngine(BaseEngine):
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
except Exception as e:
logger.error("Failed to transform speed.")
logger.error(e)
# wav to base64
buf = io.BytesIO()
@ -221,29 +468,34 @@ class TTSEngine(BaseEngine):
wav_base64: The base64 format of the synthesized audio.
"""
lang = self.config.lang
lang = self.tts_engine.config.lang
try:
infer_st = time.time()
self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
self.infer(
text=sentence,
lang=self.tts_engine.config.lang,
am=self.tts_engine.config.am,
spk_id=spk_id)
infer_et = time.time()
infer_time = infer_et - infer_st
duration = len(self.executor._outputs['wav']
.numpy()) / self.executor.am_config.fs
duration = len(self.output.numpy()) / self.tts_engine.am_config.fs
rtf = infer_time / duration
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts infer failed.")
logger.error(e)
sys.exit(-1)
try:
postprocess_st = time.time()
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
original_fs=self.executor.am_config.fs,
wav=self.output.numpy(),
original_fs=self.tts_engine.am_config.fs,
target_fs=sample_rate,
volume=volume,
speed=speed,
@ -254,19 +506,21 @@ class TTSEngine(BaseEngine):
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts postprocess failed.")
logger.error(e)
sys.exit(-1)
logger.info("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc))
logger.info("AM model: {}".format(self.tts_engine.config.am))
logger.info("Vocoder model: {}".format(self.tts_engine.config.voc))
logger.info("Language: {}".format(lang))
logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration))
logger.info(
"frontend inference time: {}".format(self.executor.frontend_time))
logger.info("AM inference time: {}".format(self.executor.am_time))
logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
logger.info("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time))
logger.info(
"postprocess (change speed, volume, target sample rate) time: {}".
@ -274,6 +528,6 @@ class TTSEngine(BaseEngine):
logger.info("total generate audio time: {}".format(infer_time +
postprocess_time))
logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.device))
logger.info("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from typing import Union
@ -99,7 +100,16 @@ def tts(request_body: TTSRequest):
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
if tts_engine.engine_type == "python":
from paddlespeech.server.engine.tts.python.tts_engine import TTSHandler
elif tts_engine.engine_type == "inference":
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSHandler
else:
logger.error("Offline tts engine only support python or inference.")
sys.exit(-1)
tts_handler = TTSHandler(tts_engine)
lang, target_sample_rate, duration, wav_base64 = tts_handler.run(
text, spk_id, speed, volume, sample_rate, save_path)
response = {
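For completeness, a hedged client-side sketch against the HTTP protocol; the route path and the response field carrying the base64 audio follow the existing PaddleSpeech TTS RESTful schema and are assumptions rather than part of this diff:
import base64
import requests

payload = {"text": "您好,欢迎使用语音合成服务。", "spk_id": 0, "speed": 1.0,
           "volume": 1.0, "sample_rate": 0, "save_path": None}
resp = requests.post("http://127.0.0.1:8090/paddlespeech/tts", json=payload).json()
with open("synthesized.wav", "wb") as f:
    f.write(base64.b64decode(resp["result"]["audio"]))  # field name assumed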
