diff --git a/demos/speech_server/start_multi_progress_server.py b/demos/speech_server/start_multi_progress_server.py
new file mode 100644
index 000000000..5e86befb7
--- /dev/null
+++ b/demos/speech_server/start_multi_progress_server.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import sys
+import warnings
+
+import uvicorn
+from fastapi import FastAPI
+from starlette.middleware.cors import CORSMiddleware
+
+from paddlespeech.server.engine.engine_pool import init_engine_pool
+from paddlespeech.server.restful.api import setup_router as setup_http_router
+from paddlespeech.server.utils.config import get_config
+from paddlespeech.server.ws.api import setup_router as setup_ws_router
+
+warnings.filterwarnings("ignore")
+
+app = FastAPI(
+    title="PaddleSpeech Serving API", description="Api", version="0.0.1")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"])
+
+# change yaml file here
+config_file = "./conf/application.yaml"
+config = get_config(config_file)
+
+# init engine
+if not init_engine_pool(config):
+    print("Failed to init engine.")
+    sys.exit(-1)
+
+# get api_router
+api_list = list(engine.split("_")[0] for engine in config.engine_list)
+if config.protocol == "websocket":
+    api_router = setup_ws_router(api_list)
+elif config.protocol == "http":
+    api_router = setup_http_router(api_list)
+else:
+    raise Exception("unsupported protocol")
+
+# app needs to operate outside the main function
+app.include_router(api_router)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument(
+        "--workers", type=int, help="workers of server", default=1)
+    args = parser.parse_args()
+
+    uvicorn.run(
+        "start_multi_progress_server:app",
+        host=config.host,
+        port=config.port,
+        debug=True,
+        workers=args.workers)
diff --git a/paddlespeech/server/engine/engine_pool.py b/paddlespeech/server/engine/engine_pool.py
index 5300303f6..298cf0bf8 100644
--- a/paddlespeech/server/engine/engine_pool.py
+++ b/paddlespeech/server/engine/engine_pool.py
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
+import time
+
+from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_factory import EngineFactory
 
 # global value
@@ -24,6 +28,50 @@ def get_engine_pool() -> dict:
     return ENGINE_POOL
 
 
+def warm_up(engine_and_type: str, engine, warm_up_time: int=3) -> bool:
+    if "tts" in engine_and_type:
+        if engine.lang == 'zh':
+            sentence = "您好,欢迎使用语音合成服务。"
+        elif engine.lang == 'en':
+            sentence = "Hello and welcome to the speech synthesis service."
+        else:
+            logger.error("tts engine only supports lang: zh or en.")
+            sys.exit(-1)
+
+        if engine_and_type == "tts_python":
+            from paddlespeech.server.engine.tts.python.tts_engine import TTSHandler
+        elif engine_and_type == "tts_inference":
+            from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSHandler
+        elif engine_and_type == "tts_online":
+            # streaming tts engines are not warmed up here
+            return True
+        elif engine_and_type == "tts_online-onnx":
+            return True
+        else:
+            logger.error("Please check the tts engine type.")
+            return False
+
+        try:
+            logger.info("Start to warm up the tts engine.")
+            for i in range(warm_up_time):
+                tts_handler = TTSHandler(engine)
+                st = time.time()
+                tts_handler.infer(
+                    text=sentence,
+                    lang=engine.config.lang,
+                    am=engine.config.am,
+                    spk_id=0, )
+                logger.info(
+                    f"The response time of the {i} warm up: {time.time() - st} s"
+                )
+        except Exception as e:
+            logger.error("Failed to warm up the tts engine.")
+            logger.error(e)
+            return False
+
+    return True
+
+
 def init_engine_pool(config) -> bool:
     """ Init engine pool
     """
@@ -38,4 +86,7 @@ def init_engine_pool(config) -> bool:
         if not ENGINE_POOL[engine].init(config=config[engine_and_type]):
             return False
 
+        if not warm_up(engine_and_type, ENGINE_POOL[engine]):
+            return False
+
     return True
diff --git a/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py b/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py
index 9618a7a69..a3dc6fc2b 100644
--- a/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py
+++ b/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py
@@ -85,3 +85,40 @@ pretrained_models = {
         24000,
     },
 }
+
+model_alias = {
+    # acoustic model
+    "speedyspeech":
+    "paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
+    "speedyspeech_inference":
+    "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
+    "fastspeech2":
+    "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
+    "fastspeech2_inference":
+    "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
+    "tacotron2":
+    "paddlespeech.t2s.models.tacotron2:Tacotron2",
+    "tacotron2_inference":
+    "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
+    # voc
+    "pwgan":
+    "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
+    "pwgan_inference":
+    "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
+    "mb_melgan":
+    "paddlespeech.t2s.models.melgan:MelGANGenerator",
+    "mb_melgan_inference":
+    "paddlespeech.t2s.models.melgan:MelGANInference",
+    "style_melgan":
+    "paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
+    "style_melgan_inference":
+    "paddlespeech.t2s.models.melgan:StyleMelGANInference",
+    "hifigan":
+    "paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
+    "hifigan_inference":
+    "paddlespeech.t2s.models.hifigan:HiFiGANInference",
+    "wavernn":
+    "paddlespeech.t2s.models.wavernn:WaveRNN",
+    "wavernn_inference":
+    "paddlespeech.t2s.models.wavernn:WaveRNNInference",
+}
diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py
index f1ce8b76e..7bce47e1a 100644
--- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py
+++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py
@@ -23,9 +23,11 @@
+import sys
+
 import paddle
 import soundfile as sf
 from scipy.io import wavfile
+from .pretrained_models import model_alias
 from .pretrained_models import pretrained_models
 from paddlespeech.cli.log import logger
-from paddlespeech.cli.tts.infer import TTSExecutor
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.cli.utils import MODEL_HOME
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import change_speed
 from paddlespeech.server.utils.errors import ErrorCode
@@ -35,13 +37,72 @@
 from paddlespeech.server.utils.paddle_predictor import run_model
 from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 
-__all__ = ['TTSEngine']
+__all__ = ['TTSEngine', 'TTSHandler']
 
 
-class TTSServerExecutor(TTSExecutor):
+class TTSEngine(BaseEngine):
+    """TTS server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
+    """
+
     def __init__(self):
-        super().__init__()
+        """Initialize TTS server engine
+        """
+        super(TTSEngine, self).__init__()
+        self.model_alias = model_alias
         self.pretrained_models = pretrained_models
+        self.engine_type = "inference"
+
+    def init(self, config: dict) -> bool:
+        self.config = config
+        self.lang = self.config.lang
+
+        try:
+            if self.config.am_predictor_conf.device is not None:
+                self.device = self.config.am_predictor_conf.device
+            elif self.config.voc_predictor_conf.device is not None:
+                self.device = self.config.voc_predictor_conf.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
+            return False
+
+        try:
+            self._init_from_path(
+                am=self.config.am,
+                am_model=self.config.am_model,
+                am_params=self.config.am_params,
+                am_sample_rate=self.config.am_sample_rate,
+                phones_dict=self.config.phones_dict,
+                tones_dict=self.config.tones_dict,
+                speaker_dict=self.config.speaker_dict,
+                voc=self.config.voc,
+                voc_model=self.config.voc_model,
+                voc_params=self.config.voc_params,
+                voc_sample_rate=self.config.voc_sample_rate,
+                lang=self.config.lang,
+                am_predictor_conf=self.config.am_predictor_conf,
+                voc_predictor_conf=self.config.voc_predictor_conf, )
+            logger.info(
+                "Initialize TTS server engine successfully on device: %s." %
+                (self.device))
+        except Exception as e:
+            logger.error("Failed to get model related files.")
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
+            return False
+
+        return True
 
     def _init_from_path(
             self,
@@ -176,6 +237,31 @@
             predictor_conf=self.voc_predictor_conf)
         logger.info("Create Vocoder predictor successfully.")
 
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return the path of the pretrained resources for the current task.
+ """ + support_models = list(self.pretrained_models.keys()) + assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(self.pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + + +class TTSHandler(): + def __init__(self, tts_engine): + """Initialize TTS server engine + """ + super(TTSHandler, self).__init__() + self.tts_engine = tts_engine + @paddle.no_grad() def infer(self, text: str, @@ -189,11 +275,12 @@ class TTSServerExecutor(TTSExecutor): am_dataset = am[am.rindex('_') + 1:] get_tone_ids = False merge_sentences = False + frontend_st = time.time() if am_name == 'speedyspeech': get_tone_ids = True if lang == 'zh': - input_ids = self.frontend.get_input_ids( + input_ids = self.tts_engine.frontend.get_input_ids( text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) @@ -201,7 +288,7 @@ class TTSServerExecutor(TTSExecutor): if get_tone_ids: tone_ids = input_ids["tone_ids"] elif lang == 'en': - input_ids = self.frontend.get_input_ids( + input_ids = self.tts_engine.frontend.get_input_ids( text, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: @@ -218,7 +305,7 @@ class TTSServerExecutor(TTSExecutor): if am_name == 'speedyspeech': part_tone_ids = tone_ids[i] am_result = run_model( - self.am_predictor, + self.tts_engine.am_predictor, [part_phone_ids.numpy(), part_tone_ids.numpy()]) mel = am_result[0] @@ -228,14 +315,14 @@ class TTSServerExecutor(TTSExecutor): if am_dataset in {"aishell3", "vctk"}: pass else: - am_result = run_model(self.am_predictor, + am_result = run_model(self.tts_engine.am_predictor, [part_phone_ids.numpy()]) mel = am_result[0] self.am_time += (time.time() - am_st) # voc voc_st = time.time() - voc_result = run_model(self.voc_predictor, [mel]) + voc_result = run_model(self.tts_engine.voc_predictor, [mel]) wav = voc_result[0] wav = paddle.to_tensor(wav) @@ -245,85 +332,7 @@ class TTSServerExecutor(TTSExecutor): else: wav_all = paddle.concat([wav_all, wav]) self.voc_time += (time.time() - voc_st) - self._outputs['wav'] = wav_all - - -class TTSEngine(BaseEngine): - """TTS server engine - - Args: - metaclass: Defaults to Singleton. - """ - - def __init__(self): - """Initialize TTS server engine - """ - super(TTSEngine, self).__init__() - - def init(self, config: dict) -> bool: - self.executor = TTSServerExecutor() - self.config = config - - try: - if self.config.am_predictor_conf.device is not None: - self.device = self.config.am_predictor_conf.device - elif self.config.voc_predictor_conf.device is not None: - self.device = self.config.voc_predictor_conf.device - else: - self.device = paddle.get_device() - paddle.set_device(self.device) - except BaseException as e: - logger.error( - "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" - ) - logger.error("Initialize TTS server engine Failed on device: %s." 
-                         (self.device))
-            return False
-
-        self.executor._init_from_path(
-            am=self.config.am,
-            am_model=self.config.am_model,
-            am_params=self.config.am_params,
-            am_sample_rate=self.config.am_sample_rate,
-            phones_dict=self.config.phones_dict,
-            tones_dict=self.config.tones_dict,
-            speaker_dict=self.config.speaker_dict,
-            voc=self.config.voc,
-            voc_model=self.config.voc_model,
-            voc_params=self.config.voc_params,
-            voc_sample_rate=self.config.voc_sample_rate,
-            lang=self.config.lang,
-            am_predictor_conf=self.config.am_predictor_conf,
-            voc_predictor_conf=self.config.voc_predictor_conf, )
-
-        # warm up
-        try:
-            self.warm_up()
-            logger.info("Warm up successfully.")
-        except Exception as e:
-            logger.error("Failed to warm up on tts engine.")
-            return False
-
-        logger.info("Initialize TTS server engine successfully.")
-        return True
-
-    def warm_up(self):
-        """warm up
-        """
-        if self.config.lang == 'zh':
-            sentence = "您好,欢迎使用语音合成服务。"
-        if self.config.lang == 'en':
-            sentence = "Hello and welcome to the speech synthesis service."
-        logger.info("Start to warm up.")
-        for i in range(3):
-            st = time.time()
-            self.executor.infer(
-                text=sentence,
-                lang=self.config.lang,
-                am=self.config.am,
-                spk_id=0, )
-            logger.info(
-                f"The response time of the {i} warm up: {time.time() - st} s")
+        self.output = wav_all
 
     def postprocess(self,
                     wav,
@@ -375,8 +384,9 @@
                 ErrorCode.SERVER_INTERNAL_ERR,
                 "Failed to transform speed. Can not install soxbindings on your system. \
                     You need to set speed value 1.0.")
-        except BaseException:
+        except Exception as e:
             logger.error("Failed to transform speed.")
+            logger.error(e)
 
         # wav to base64
         buf = io.BytesIO()
@@ -429,52 +439,59 @@
             wav_base64: The base64 format of the synthesized audio.
""" - lang = self.config.lang + lang = self.tts_engine.lang try: infer_st = time.time() - self.executor.infer( - text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) + self.infer( + text=sentence, + lang=lang, + am=self.tts_engine.config.am, + spk_id=spk_id) infer_et = time.time() infer_time = infer_et - infer_st except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts infer failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts infer failed.") + logger.error(e) + sys.exit(-1) try: postprocess_st = time.time() target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), - original_fs=self.executor.am_sample_rate, + wav=self.output.numpy(), + original_fs=self.tts_engine.am_sample_rate, target_fs=sample_rate, volume=volume, speed=speed, audio_path=save_path) postprocess_et = time.time() postprocess_time = postprocess_et - postprocess_st - duration = len(self.executor._outputs['wav'] - .numpy()) / self.executor.am_sample_rate + duration = len(self.output.numpy()) / self.tts_engine.am_sample_rate rtf = infer_time / duration except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts postprocess failed.") + logger.error(e) + sys.exit(-1) - logger.info("AM model: {}".format(self.config.am)) - logger.info("Vocoder model: {}".format(self.config.voc)) + logger.info("AM model: {}".format(self.tts_engine.config.am)) + logger.info("Vocoder model: {}".format(self.tts_engine.config.voc)) logger.info("Language: {}".format(lang)) logger.info("tts engine type: paddle inference") logger.info("audio duration: {}".format(duration)) - logger.info( - "frontend inference time: {}".format(self.executor.frontend_time)) - logger.info("AM inference time: {}".format(self.executor.am_time)) - logger.info("Vocoder inference time: {}".format(self.executor.voc_time)) + logger.info("frontend inference time: {}".format(self.frontend_time)) + logger.info("AM inference time: {}".format(self.am_time)) + logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.info("total inference time: {}".format(infer_time)) logger.info( "postprocess (change speed, volume, target sample rate) time: {}". diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py index d0002baa4..3d67f6483 100644 --- a/paddlespeech/server/engine/tts/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/python/tts_engine.py @@ -13,32 +13,37 @@ # limitations under the License. 
 import base64
 import io
+import os
+import sys
 import time
+from typing import Optional
 
 import librosa
 import numpy as np
 import paddle
 import soundfile as sf
+import yaml
 from scipy.io import wavfile
+from yacs.config import CfgNode
 
+from .pretrained_models import model_alias
+from .pretrained_models import pretrained_models
 from paddlespeech.cli.log import logger
-from paddlespeech.cli.tts.infer import TTSExecutor
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.cli.utils import MODEL_HOME
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import change_speed
 from paddlespeech.server.utils.errors import ErrorCode
 from paddlespeech.server.utils.exception import ServerBaseException
+from paddlespeech.t2s.frontend import English
+from paddlespeech.t2s.frontend.zh_frontend import Frontend
+from paddlespeech.t2s.modules.normalizer import ZScore
 
-__all__ = ['TTSEngine']
-
-
-class TTSServerExecutor(TTSExecutor):
-    def __init__(self):
-        super().__init__()
-        pass
+__all__ = ['TTSEngine', 'TTSHandler']
 
 
 class TTSEngine(BaseEngine):
-    """TTS server engine
+    """TTS server engine for model setting
 
     Args:
         metaclass: Defaults to Singleton.
@@ -48,10 +53,13 @@ class TTSEngine(BaseEngine):
         """Initialize TTS server engine
         """
         super(TTSEngine, self).__init__()
+        self.model_alias = model_alias
+        self.pretrained_models = pretrained_models
+        self.engine_type = "python"
 
     def init(self, config: dict) -> bool:
-        self.executor = TTSServerExecutor()
         self.config = config
+        self.lang = self.config.lang
 
         try:
             if self.config.device is not None:
@@ -59,16 +67,17 @@ class TTSEngine(BaseEngine):
             else:
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-        except BaseException as e:
+        except Exception as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
             )
             logger.error("Initialize TTS server engine Failed on device: %s." %
                          (self.device))
+            logger.error(e)
             return False
 
         try:
-            self.executor._init_from_path(
+            self._init_from_path(
                 am=self.config.am,
                 am_config=self.config.am_config,
                 am_ckpt=self.config.am_ckpt,
@@ -81,41 +90,278 @@ class TTSEngine(BaseEngine):
                 voc_ckpt=self.config.voc_ckpt,
                 voc_stat=self.config.voc_stat,
                 lang=self.config.lang)
-        except BaseException:
+            logger.info(
+                "Initialize TTS server engine successfully on device: %s." %
+                (self.device))
+        except Exception as e:
             logger.error("Failed to get model related files.")
             logger.error("Initialize TTS server engine Failed on device: %s." %
                          (self.device))
+            logger.error(e)
             return False
 
-        # warm up
-        try:
-            self.warm_up()
-            logger.info("Warm up successfully.")
-        except Exception as e:
-            logger.error("Failed to warm up on tts engine.")
-            return False
-
-        logger.info("Initialize TTS server engine successfully on device: %s." %
-                    (self.device))
         return True
 
-    def warm_up(self):
-        """warm up
+    def _init_from_path(
+            self,
+            am: str='fastspeech2_csmsc',
+            am_config: Optional[os.PathLike]=None,
+            am_ckpt: Optional[os.PathLike]=None,
+            am_stat: Optional[os.PathLike]=None,
+            phones_dict: Optional[os.PathLike]=None,
+            tones_dict: Optional[os.PathLike]=None,
+            speaker_dict: Optional[os.PathLike]=None,
+            voc: str='pwgan_csmsc',
+            voc_config: Optional[os.PathLike]=None,
+            voc_ckpt: Optional[os.PathLike]=None,
+            voc_stat: Optional[os.PathLike]=None,
+            lang: str='zh', ):
+        """
+        Init model and other resources from a specific path.
""" - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." - logger.info("Start to warm up.") - for i in range(3): - st = time.time() - self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ) - logger.info( - f"The response time of the {i} warm up: {time.time() - st} s") + + if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): + logger.info('Models had been initialized.') + return + # am + am_tag = am + '-' + lang + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + am_res_path = self._get_pretrained_path(am_tag) + self.am_res_path = am_res_path + self.am_config = os.path.join( + am_res_path, self.pretrained_models[am_tag]['config']) + self.am_ckpt = os.path.join(am_res_path, + self.pretrained_models[am_tag]['ckpt']) + self.am_stat = os.path.join( + am_res_path, self.pretrained_models[am_tag]['speech_stats']) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + am_res_path, self.pretrained_models[am_tag]['phones_dict']) + logger.info(am_res_path) + logger.info(self.am_config) + logger.info(self.am_ckpt) + else: + self.am_config = os.path.abspath(am_config) + self.am_ckpt = os.path.abspath(am_ckpt) + self.am_stat = os.path.abspath(am_stat) + self.phones_dict = os.path.abspath(phones_dict) + self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) + + # for speedyspeech + self.tones_dict = None + if 'tones_dict' in self.pretrained_models[am_tag]: + self.tones_dict = os.path.join( + am_res_path, self.pretrained_models[am_tag]['tones_dict']) + if tones_dict: + self.tones_dict = tones_dict + + # for multi speaker fastspeech2 + self.speaker_dict = None + if 'speaker_dict' in self.pretrained_models[am_tag]: + self.speaker_dict = os.path.join( + am_res_path, self.pretrained_models[am_tag]['speaker_dict']) + if speaker_dict: + self.speaker_dict = speaker_dict + + # voc + voc_tag = voc + '-' + lang + if voc_ckpt is None or voc_config is None or voc_stat is None: + voc_res_path = self._get_pretrained_path(voc_tag) + self.voc_res_path = voc_res_path + self.voc_config = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_stat = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) + logger.info(voc_res_path) + logger.info(self.voc_config) + logger.info(self.voc_ckpt) + else: + self.voc_config = os.path.abspath(voc_config) + self.voc_ckpt = os.path.abspath(voc_ckpt) + self.voc_stat = os.path.abspath(voc_stat) + self.voc_res_path = os.path.dirname( + os.path.abspath(self.voc_config)) + + # Init body. 
+        with open(self.am_config) as f:
+            self.am_config = CfgNode(yaml.safe_load(f))
+        with open(self.voc_config) as f:
+            self.voc_config = CfgNode(yaml.safe_load(f))
+
+        with open(self.phones_dict, "r") as f:
+            phn_id = [line.strip().split() for line in f.readlines()]
+        vocab_size = len(phn_id)
+        logger.info(f"vocab_size: {vocab_size}")
+
+        tone_size = None
+        if self.tones_dict:
+            with open(self.tones_dict, "r") as f:
+                tone_id = [line.strip().split() for line in f.readlines()]
+            tone_size = len(tone_id)
+            logger.info(f"tone_size: {tone_size}")
+
+        spk_num = None
+        if self.speaker_dict:
+            with open(self.speaker_dict, 'rt') as f:
+                spk_id = [line.strip().split() for line in f.readlines()]
+            spk_num = len(spk_id)
+            logger.info(f"spk_num: {spk_num}")
+
+        # frontend
+        if lang == 'zh':
+            self.frontend = Frontend(
+                phone_vocab_path=self.phones_dict,
+                tone_vocab_path=self.tones_dict)
+
+        elif lang == 'en':
+            self.frontend = English(phone_vocab_path=self.phones_dict)
+        logger.info("frontend done!")
+
+        # acoustic model
+        odim = self.am_config.n_mels
+        # model: {model_name}_{dataset}
+        am_name = am[:am.rindex('_')]
+
+        am_class = dynamic_import(am_name, self.model_alias)
+        am_inference_class = dynamic_import(am_name + '_inference',
+                                            self.model_alias)
+
+        if am_name == 'fastspeech2':
+            am = am_class(
+                idim=vocab_size,
+                odim=odim,
+                spk_num=spk_num,
+                **self.am_config["model"])
+        elif am_name == 'speedyspeech':
+            am = am_class(
+                vocab_size=vocab_size,
+                tone_size=tone_size,
+                **self.am_config["model"])
+        elif am_name == 'tacotron2':
+            am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"])
+
+        am.set_state_dict(paddle.load(self.am_ckpt)["main_params"])
+        am.eval()
+        am_mu, am_std = np.load(self.am_stat)
+        am_mu = paddle.to_tensor(am_mu)
+        am_std = paddle.to_tensor(am_std)
+        am_normalizer = ZScore(am_mu, am_std)
+        self.am_inference = am_inference_class(am_normalizer, am)
+        self.am_inference.eval()
+        logger.info("acoustic model done!")
+
+        # vocoder
+        # model: {model_name}_{dataset}
+        voc_name = voc[:voc.rindex('_')]
+        voc_class = dynamic_import(voc_name, self.model_alias)
+        voc_inference_class = dynamic_import(voc_name + '_inference',
+                                             self.model_alias)
+        if voc_name != 'wavernn':
+            voc = voc_class(**self.voc_config["generator_params"])
+            voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
+            voc.remove_weight_norm()
+            voc.eval()
+        else:
+            voc = voc_class(**self.voc_config["model"])
+            voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"])
+            voc.eval()
+        voc_mu, voc_std = np.load(self.voc_stat)
+        voc_mu = paddle.to_tensor(voc_mu)
+        voc_std = paddle.to_tensor(voc_std)
+        voc_normalizer = ZScore(voc_mu, voc_std)
+        self.voc_inference = voc_inference_class(voc_normalizer, voc)
+        self.voc_inference.eval()
+        logger.info("voc done!")
+
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return the path of the pretrained resources for the current task.
+ """ + support_models = list(self.pretrained_models.keys()) + assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(self.pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + + +class TTSHandler(): + def __init__(self, tts_engine): + """Initialize TTS server engine + """ + super(TTSHandler, self).__init__() + self.tts_engine = tts_engine + + @paddle.no_grad() + def infer(self, + text: str, + lang: str='zh', + am: str='fastspeech2_csmsc', + spk_id: int=0): + """ + Model inference and result stored in self.output. + """ + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] + get_tone_ids = False + merge_sentences = False + frontend_st = time.time() + if am_name == 'speedyspeech': + get_tone_ids = True + if lang == 'zh': + input_ids = self.tts_engine.frontend.get_input_ids( + text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif lang == 'en': + input_ids = self.tts_engine.frontend.get_input_ids( + text, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + logger.info("lang should in {'zh', 'en'}!") + self.frontend_time = time.time() - frontend_st + + self.am_time = 0 + self.voc_time = 0 + flags = 0 + for i in range(len(phone_ids)): + am_st = time.time() + part_phone_ids = phone_ids[i] + # am + if am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + mel = self.tts_engine.am_inference(part_phone_ids, + part_tone_ids) + # fastspeech2 + else: + # multi speaker + if am_dataset in {"aishell3", "vctk"}: + mel = self.tts_engine.am_inference( + part_phone_ids, spk_id=paddle.to_tensor(spk_id)) + else: + mel = self.tts_engine.am_inference(part_phone_ids) + self.am_time += (time.time() - am_st) + # voc + voc_st = time.time() + wav = self.tts_engine.voc_inference(mel) + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + self.voc_time += (time.time() - voc_st) + self.output = wav_all def postprocess(self, wav, @@ -167,8 +413,9 @@ class TTSEngine(BaseEngine): ErrorCode.SERVER_INTERNAL_ERR, "Failed to transform speed. Can not install soxbindings on your system. \ You need to set speed value 1.0.") - except BaseException: + except Exception as e: logger.error("Failed to transform speed.") + logger.error(e) # wav to base64 buf = io.BytesIO() @@ -221,29 +468,34 @@ class TTSEngine(BaseEngine): wav_base64: The base64 format of the synthesized audio. 
""" - lang = self.config.lang + lang = self.tts_engine.config.lang try: infer_st = time.time() - self.executor.infer( - text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) + self.infer( + text=sentence, + lang=self.tts_engine.config.lang, + am=self.tts_engine.config.am, + spk_id=spk_id) infer_et = time.time() infer_time = infer_et - infer_st - duration = len(self.executor._outputs['wav'] - .numpy()) / self.executor.am_config.fs + duration = len(self.output.numpy()) / self.tts_engine.am_config.fs rtf = infer_time / duration except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts infer failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts infer failed.") + logger.error(e) + sys.exit(-1) try: postprocess_st = time.time() target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), - original_fs=self.executor.am_config.fs, + wav=self.output.numpy(), + original_fs=self.tts_engine.am_config.fs, target_fs=sample_rate, volume=volume, speed=speed, @@ -254,19 +506,21 @@ class TTSEngine(BaseEngine): except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts postprocess failed.") + logger.error(e) + sys.exit(-1) - logger.info("AM model: {}".format(self.config.am)) - logger.info("Vocoder model: {}".format(self.config.voc)) + logger.info("AM model: {}".format(self.tts_engine.config.am)) + logger.info("Vocoder model: {}".format(self.tts_engine.config.voc)) logger.info("Language: {}".format(lang)) logger.info("tts engine type: python") logger.info("audio duration: {}".format(duration)) - logger.info( - "frontend inference time: {}".format(self.executor.frontend_time)) - logger.info("AM inference time: {}".format(self.executor.am_time)) - logger.info("Vocoder inference time: {}".format(self.executor.voc_time)) + logger.info("frontend inference time: {}".format(self.frontend_time)) + logger.info("AM inference time: {}".format(self.am_time)) + logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.info("total inference time: {}".format(infer_time)) logger.info( "postprocess (change speed, volume, target sample rate) time: {}". @@ -274,6 +528,6 @@ class TTSEngine(BaseEngine): logger.info("total generate audio time: {}".format(infer_time + postprocess_time)) logger.info("RTF: {}".format(rtf)) - logger.info("device: {}".format(self.device)) + logger.info("device: {}".format(self.tts_engine.device)) return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py index 15d618d93..5d5863ade 100644 --- a/paddlespeech/server/restful/tts_api.py +++ b/paddlespeech/server/restful/tts_api.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import sys
 import traceback
 from typing import Union
 
@@ -99,7 +100,16 @@ def tts(request_body: TTSRequest):
         tts_engine = engine_pool['tts']
         logger.info("Get tts engine successfully.")
 
-        lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
+        if tts_engine.engine_type == "python":
+            from paddlespeech.server.engine.tts.python.tts_engine import TTSHandler
+        elif tts_engine.engine_type == "inference":
+            from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSHandler
+        else:
+            logger.error("Offline tts only supports the python or inference engine type.")
+            sys.exit(-1)
+        tts_handler = TTSHandler(tts_engine)
+
+        lang, target_sample_rate, duration, wav_base64 = tts_handler.run(
             text, spk_id, speed, volume, sample_rate, save_path)
 
         response = {
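
A quick way to exercise this change end to end is to start the demo with several workers and post a TTS request against one of them. The sketch below is illustrative only and not part of the diff: it assumes the server was launched with "python start_multi_progress_server.py --workers 4", that ./conf/application.yaml selects the "http" protocol on 127.0.0.1:8090, and that the restful router exposes the TTS endpoint at /paddlespeech/tts. The request fields mirror the (text, spk_id, speed, volume, sample_rate, save_path) arguments that TTSHandler.run() receives; the exact response layout depends on the TTSResponse schema in your checkout, so the "result"/"audio" keys below are assumptions.

import base64

import requests

# Hypothetical host, port, and route -- adjust to your application.yaml.
url = "http://127.0.0.1:8090/paddlespeech/tts"
request = {
    "text": "您好,欢迎使用语音合成服务。",
    "spk_id": 0,
    "speed": 1.0,
    "volume": 1.0,
    "sample_rate": 0,  # 0 keeps the acoustic model's native sample rate
    "save_path": None,
}
resp = requests.post(url, json=request)
resp.raise_for_status()
# The handler returns the synthesized audio base64-encoded (wav_base64);
# the field names here are an assumption about the response schema.
audio_b64 = resp.json()["result"]["audio"]
with open("demo.wav", "wb") as f:
    f.write(base64.b64decode(audio_b64))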