diff --git a/paddlespeech/server/conf/tts_online_application.yaml b/paddlespeech/server/conf/tts_online_application.yaml
index a80b3ecec..10abf0d43 100644
--- a/paddlespeech/server/conf/tts_online_application.yaml
+++ b/paddlespeech/server/conf/tts_online_application.yaml
@@ -7,7 +7,7 @@ host: 127.0.0.1
 port: 8092
 
 # The task format in the engine_list is: <speech task>_<engine type>
-# task choices = ['asr_online', 'tts_online']
+# task choices = ['tts_online', 'tts_online-onnx']
 # protocol = ['websocket', 'http'] (only one can be selected).
 protocol: 'http'
 engine_list: ['tts_online']
@@ -20,8 +20,8 @@ engine_list: ['tts_online']
 ################################### TTS #########################################
 ################### speech task: tts; engine_type: online #######################
 tts_online:
-    # am (acoustic model) choices=['fastspeech2_csmsc']
-    am: 'fastspeech2_csmsc'
+    # am (acoustic model) choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc']
+    am: 'fastspeech2_cnndecoder_csmsc'
     am_config:
     am_ckpt:
     am_stat:
@@ -30,7 +30,7 @@ tts_online:
     speaker_dict:
     spk_id: 0
 
-    # voc (vocoder) choices=['mb_melgan_csmsc']
+    # voc (vocoder) choices=['mb_melgan_csmsc', 'hifigan_csmsc']
     voc: 'mb_melgan_csmsc'
     voc_config:
     voc_ckpt:
@@ -38,9 +38,51 @@ tts_online:
 
     # others
     lang: 'zh'
-    device: # set 'gpu:id' or 'cpu'
+    device: 'cpu' # set 'gpu:id' or 'cpu'
     am_block: 42
     am_pad: 12
     voc_block: 14
     voc_pad: 14
+
+
+#################################################################################
+#                               ENGINE CONFIG                                   #
+#################################################################################
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: online-onnx #######################
+tts_online-onnx:
+    # am (acoustic model) choices=['fastspeech2_csmsc_onnx', 'fastspeech2_cnndecoder_csmsc_onnx']
+    am: 'fastspeech2_cnndecoder_csmsc_onnx'
+    # am_ckpt is a list, if am is fastspeech2_cnndecoder_csmsc_onnx, am_ckpt = [encoder model, decoder model, postnet model];
+    # if am is fastspeech2_csmsc_onnx, am_ckpt = [ckpt model];
+    am_ckpt: # list
+    am_stat:
+    phones_dict:
+    tones_dict:
+    speaker_dict:
+    spk_id: 0
+    am_sample_rate: 24000
+    am_sess_conf:
+        device: "cpu" # set 'gpu:id' or 'cpu'
+        use_trt: False
+        cpu_threads: 1
+
+    # voc (vocoder) choices=['mb_melgan_csmsc_onnx', 'hifigan_csmsc_onnx']
+    voc: 'mb_melgan_csmsc_onnx'
+    voc_ckpt:
+    voc_sample_rate: 24000
+    voc_sess_conf:
+        device: "cpu" # set 'gpu:id' or 'cpu'
+        use_trt: False
+        cpu_threads: 1
+
+    # others
+    lang: 'zh'
+    am_block: 42
+    am_pad: 12
+    voc_block: 14
+    voc_pad: 14
+    voc_upsample: 300
+
diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py
index e147a29a6..9ebf137df 100644
--- a/paddlespeech/server/engine/engine_factory.py
+++ b/paddlespeech/server/engine/engine_factory.py
@@ -35,7 +35,10 @@ class EngineFactory(object):
             from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
             return TTSEngine()
         elif engine_name == 'tts' and engine_type == 'online':
-            from paddlespeech.server.engine.tts.online.tts_engine import TTSEngine
+            from paddlespeech.server.engine.tts.online.python.tts_engine import TTSEngine
+            return TTSEngine()
+        elif engine_name == 'tts' and engine_type == 'online-onnx':
+            from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine
             return TTSEngine()
         elif engine_name == 'cls' and engine_type == 'inference':
             from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine
diff --git a/paddlespeech/server/engine/tts/online/onnx/__init__.py b/paddlespeech/server/engine/tts/online/onnx/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/server/engine/tts/online/onnx/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py
new file mode 100644
index 000000000..b30c635ae
--- /dev/null
+++ b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py
@@ -0,0 +1,582 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import math
+import os
+import time
+from typing import Optional
+
+import numpy as np
+import paddle
+
+from paddlespeech.cli.log import logger
+from paddlespeech.cli.tts.infer import TTSExecutor
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.cli.utils import MODEL_HOME
+from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.server.utils.audio_process import float2pcm
+from paddlespeech.server.utils.onnx_infer import get_sess
+from paddlespeech.server.utils.util import denorm
+from paddlespeech.server.utils.util import get_chunks
+from paddlespeech.t2s.frontend import English
+from paddlespeech.t2s.frontend.zh_frontend import Frontend
+
+__all__ = ['TTSEngine']
+
+# supported online models
+pretrained_models = {
+    # fastspeech2
+    "fastspeech2_csmsc_onnx-zh": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
+        'md5':
+        'fd3ad38d83273ad51f0ea4f4abf3ab4e',
+        'ckpt': ['fastspeech2_csmsc.onnx'],
+        'phones_dict':
+        'phone_id_map.txt',
+        'sample_rate':
+        24000,
+    },
+    "fastspeech2_cnndecoder_csmsc_onnx-zh": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip',
+        'md5':
+        '5f70e1a6bcd29d72d54e7931aa86f266',
+        'ckpt': [
+            'fastspeech2_csmsc_am_encoder_infer.onnx',
+            'fastspeech2_csmsc_am_decoder.onnx',
+            'fastspeech2_csmsc_am_postnet.onnx',
+        ],
+        'speech_stats':
+        'speech_stats.npy',
+        'phones_dict':
+        'phone_id_map.txt',
+        'sample_rate':
+        24000,
+    },
+
+    # mb_melgan
+    "mb_melgan_csmsc_onnx-zh": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip',
+        'md5':
+        '5b83ec746e8414bc29032d954ffd07ec',
+        'ckpt':
+        'mb_melgan_csmsc.onnx',
+        'sample_rate':
+        24000,
+    },
+
+    # hifigan
+    "hifigan_csmsc_onnx-zh": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip',
+        'md5':
+        '1a7dc0385875889e46952e50c0994a6b',
+        'ckpt':
+        'hifigan_csmsc.onnx',
+        'sample_rate':
+        24000,
+    },
+}
+
+model_alias = {
+    # acoustic model
+    "fastspeech2":
+    "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
+    "fastspeech2_inference":
+    "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
+
+    # voc
+    "mb_melgan":
+    "paddlespeech.t2s.models.melgan:MelGANGenerator",
+    "mb_melgan_inference":
+    "paddlespeech.t2s.models.melgan:MelGANInference",
+    "hifigan":
+    "paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
+    "hifigan_inference":
+    "paddlespeech.t2s.models.hifigan:HiFiGANInference",
+}
+
+
+class TTSServerExecutor(TTSExecutor):
+    def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample):
+        super().__init__()
+        self.am_block = am_block
+        self.am_pad = am_pad
+        self.voc_block = voc_block
+        self.voc_pad = voc_pad
+        self.voc_upsample = voc_upsample
+
+        self.pretrained_models = pretrained_models
+        self.model_alias = model_alias
+
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return the path of the pretrained resources for the current task.
+        """
+        support_models = list(pretrained_models.keys())
+        assert tag in pretrained_models, 'The model "{}" you want to use is not supported, please choose another model.\nThe supported models include:\n\t\t{}\n'.format(
+            tag, '\n\t\t'.join(support_models))
+
+        res_path = os.path.join(MODEL_HOME, tag)
+        decompressed_path = download_and_decompress(pretrained_models[tag],
+                                                    res_path)
+        decompressed_path = os.path.abspath(decompressed_path)
+        logger.info(
+            'Use pretrained model stored in: {}'.format(decompressed_path))
+        return decompressed_path
+
+    def _init_from_path(
+            self,
+            am: str='fastspeech2_csmsc_onnx',
+            am_ckpt: Optional[list]=None,
+            am_stat: Optional[os.PathLike]=None,
+            phones_dict: Optional[os.PathLike]=None,
+            tones_dict: Optional[os.PathLike]=None,
+            speaker_dict: Optional[os.PathLike]=None,
+            am_sample_rate: int=24000,
+            am_sess_conf: dict=None,
+            voc: str='mb_melgan_csmsc_onnx',
+            voc_ckpt: Optional[os.PathLike]=None,
+            voc_sample_rate: int=24000,
+            voc_sess_conf: dict=None,
+            lang: str='zh', ):
+        """
+        Init model and other resources from a specific path.
+        """
+        if (hasattr(self, 'am_sess') or
+            (hasattr(self, 'am_encoder_infer_sess') and
+             hasattr(self, 'am_decoder_sess') and hasattr(
+                 self, 'am_postnet_sess'))) and hasattr(self, 'voc_inference'):
+            logger.info('Models have been initialized.')
+            return
+        # am
+        am_tag = am + '-' + lang
+        if am == "fastspeech2_csmsc_onnx":
+            # get model info
+            if am_ckpt is None or phones_dict is None:
+                am_res_path = self._get_pretrained_path(am_tag)
+                self.am_res_path = am_res_path
+                self.am_ckpt = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['ckpt'][0])
+                # must have phones_dict in acoustic
+                self.phones_dict = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['phones_dict'])
+
+            else:
+                self.am_ckpt = os.path.abspath(am_ckpt[0])
+                self.phones_dict = os.path.abspath(phones_dict)
+                self.am_res_path = os.path.dirname(
+                    os.path.abspath(self.am_ckpt))
+
+            # create am sess
+            self.am_sess = get_sess(self.am_ckpt, am_sess_conf)
+
+        elif am == "fastspeech2_cnndecoder_csmsc_onnx":
+            if am_ckpt is None or am_stat is None or phones_dict is None:
+                am_res_path = self._get_pretrained_path(am_tag)
+                self.am_res_path = am_res_path
+                self.am_encoder_infer = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['ckpt'][0])
+                self.am_decoder = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['ckpt'][1])
+                self.am_postnet = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['ckpt'][2])
+                # must have phones_dict in acoustic
+                self.phones_dict = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['phones_dict'])
+                self.am_stat = os.path.join(
+                    am_res_path, pretrained_models[am_tag]['speech_stats'])
+
+            else:
+                self.am_encoder_infer = os.path.abspath(am_ckpt[0])
+                self.am_decoder = os.path.abspath(am_ckpt[1])
+                self.am_postnet = os.path.abspath(am_ckpt[2])
+                self.phones_dict = os.path.abspath(phones_dict)
+                self.am_stat = os.path.abspath(am_stat)
+                self.am_res_path = os.path.dirname(
+                    os.path.abspath(self.am_encoder_infer))
+
+            # create am sess
+            self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
+                                                  am_sess_conf)
+            self.am_decoder_sess = get_sess(self.am_decoder, am_sess_conf)
+            self.am_postnet_sess = get_sess(self.am_postnet, am_sess_conf)
+
+            self.am_mu, self.am_std = np.load(self.am_stat)
+
+        logger.info(f"self.phones_dict: {self.phones_dict}")
+        logger.info(f"am model dir: {self.am_res_path}")
+        logger.info("Create am sess successfully.")
+
+        # voc model info
+        voc_tag = voc + '-' + lang
+        if voc_ckpt is None:
+            voc_res_path = self._get_pretrained_path(voc_tag)
+            self.voc_res_path = voc_res_path
+            self.voc_ckpt = os.path.join(voc_res_path,
+                                         pretrained_models[voc_tag]['ckpt'])
+        else:
+            self.voc_ckpt = os.path.abspath(voc_ckpt)
+            self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
+        logger.info(self.voc_res_path)
+
+        # create voc sess
+        self.voc_sess = get_sess(self.voc_ckpt, voc_sess_conf)
+        logger.info("Create voc sess successfully.")
+
+        with open(self.phones_dict, "r") as f:
+            phn_id = [line.strip().split() for line in f.readlines()]
+        self.vocab_size = len(phn_id)
+        logger.info(f"vocab_size: {self.vocab_size}")
+
+        # frontend
+        self.tones_dict = None
+        if lang == 'zh':
+            self.frontend = Frontend(
+                phone_vocab_path=self.phones_dict,
+                tone_vocab_path=self.tones_dict)
+
+        elif lang == 'en':
+            self.frontend = English(phone_vocab_path=self.phones_dict)
+        logger.info("frontend done!")
+
+    def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
+        """
+        Remove the padding context introduced by chunked streaming inference.
+        """
+        front_pad = min(chunk_id * block, pad)
+        # first chunk
+        if chunk_id == 0:
+            data = data[:block * upsample]
+        # last chunk
+        elif chunk_id == chunk_num - 1:
+            data = data[front_pad * upsample:]
+        # middle chunk
+        else:
+            data = data[front_pad * upsample:(front_pad + block) * upsample]
+
+        return data
+
+    @paddle.no_grad()
+    def infer(
+            self,
+            text: str,
+            lang: str='zh',
+            am: str='fastspeech2_csmsc_onnx',
+            spk_id: int=0, ):
+        """
+        Model inference, yielding audio chunks as they are synthesized.
+        """
+        am_block = self.am_block
+        am_pad = self.am_pad
+        am_upsample = 1
+        voc_block = self.voc_block
+        voc_pad = self.voc_pad
+        voc_upsample = self.voc_upsample
+        # first_flag marks the first response packet
+        first_flag = 1
+        get_tone_ids = False
+        merge_sentences = False
+
+        # front
+        frontend_st = time.time()
+        if lang == 'zh':
+            input_ids = self.frontend.get_input_ids(
+                text,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids)
+            phone_ids = input_ids["phone_ids"]
+            if get_tone_ids:
+                tone_ids = input_ids["tone_ids"]
+        elif lang == 'en':
+            input_ids = self.frontend.get_input_ids(
+                text, merge_sentences=merge_sentences)
+            phone_ids = input_ids["phone_ids"]
+        else:
+            logger.error("lang should be in {'zh', 'en'}!")
+        frontend_et = time.time()
+        self.frontend_time = frontend_et - frontend_st
+
+        for i in range(len(phone_ids)):
+            part_phone_ids = phone_ids[i].numpy()
+            voc_chunk_id = 0
+
+            # fastspeech2_csmsc
+            if am == "fastspeech2_csmsc_onnx":
+                # am
+                mel = self.am_sess.run(
+                    output_names=None, input_feed={'text': part_phone_ids})
+                mel = mel[0]
+                if first_flag == 1:
+                    first_am_et = time.time()
+                    self.first_am_infer = first_am_et - frontend_et
+
+                # voc streaming
+                mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
+                voc_chunk_num = len(mel_chunks)
+                voc_st = time.time()
+                for i, mel_chunk in enumerate(mel_chunks):
+                    sub_wav = self.voc_sess.run(
+                        output_names=None, input_feed={'logmel': mel_chunk})
+                    sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
+                                             voc_block, voc_pad, voc_upsample)
+                    if first_flag == 1:
+                        first_voc_et = time.time()
+                        self.first_voc_infer = first_voc_et - first_am_et
+                        self.first_response_time = first_voc_et - frontend_st
+                        first_flag = 0
+
+                    yield sub_wav
+
+            # fastspeech2_cnndecoder_csmsc
+            elif am == "fastspeech2_cnndecoder_csmsc_onnx":
+                # am
+                orig_hs = self.am_encoder_infer_sess.run(
+                    None, input_feed={'text': part_phone_ids})
+                orig_hs = orig_hs[0]
+
+                # streaming voc chunk info
+                mel_len = orig_hs.shape[1]
+                voc_chunk_num = math.ceil(mel_len / self.voc_block)
+                start = 0
+                end = min(self.voc_block + self.voc_pad, mel_len)
+
+                # streaming am
+                hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
+                am_chunk_num = len(hss)
+                for i, hs in enumerate(hss):
+                    am_decoder_output = self.am_decoder_sess.run(
+                        None, input_feed={'xs': hs})
+                    am_postnet_output = self.am_postnet_sess.run(
+                        None,
+                        input_feed={
+                            'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
+                        })
+                    am_output_data = am_decoder_output + np.transpose(
+                        am_postnet_output[0], (0, 2, 1))
+                    normalized_mel = am_output_data[0][0]
+
+                    sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
+                    sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
+                                             am_pad, am_upsample)
+
+                    if i == 0:
+                        mel_streaming = sub_mel
+                    else:
+                        mel_streaming = np.concatenate(
+                            (mel_streaming, sub_mel), axis=0)
+
+                    # streaming voc
+                    # once the number of mel frames produced by streaming am inference
+                    # reaches the chunk size of streaming voc inference, start streaming voc inference
+                    while (mel_streaming.shape[0] >= end and
+                           voc_chunk_id < voc_chunk_num):
+                        if first_flag == 1:
+                            first_am_et = time.time()
+                            self.first_am_infer = first_am_et - frontend_et
+                        voc_chunk = mel_streaming[start:end, :]
+
+                        sub_wav = self.voc_sess.run(
+                            output_names=None, input_feed={'logmel': voc_chunk})
+                        sub_wav = self.depadding(sub_wav[0], voc_chunk_num,
+                                                 voc_chunk_id, voc_block,
+                                                 voc_pad, voc_upsample)
+                        if first_flag == 1:
+                            first_voc_et = time.time()
+                            self.first_voc_infer = first_voc_et - first_am_et
+                            self.first_response_time = first_voc_et - frontend_st
+                            first_flag = 0
+
+                        yield sub_wav
+
+                        voc_chunk_id += 1
+                        start = max(0, voc_chunk_id * voc_block - voc_pad)
+                        end = min((voc_chunk_id + 1) * voc_block + voc_pad,
+                                  mel_len)
+
+            else:
+                logger.error(
+                    "Only fastspeech2_csmsc and fastspeech2_cnndecoder_csmsc are supported for streaming tts."
+                )
+
+        self.final_response_time = time.time() - frontend_st
+
+
+class TTSEngine(BaseEngine):
+    """TTS server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
+    """
+
+    def __init__(self, name=None):
+        """Initialize TTS server engine
+        """
+        super().__init__()
+
+    def init(self, config: dict) -> bool:
+        self.config = config
+
+        assert (
+            self.config.am == "fastspeech2_csmsc_onnx" or
+            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
+        ) and (
+            self.config.voc == "hifigan_csmsc_onnx" or
+            self.config.voc == "mb_melgan_csmsc_onnx"
+        ), 'Please check config: am supports fastspeech2_csmsc_onnx or fastspeech2_cnndecoder_csmsc_onnx, voc supports mb_melgan_csmsc_onnx or hifigan_csmsc_onnx.'
+
+        assert (
+            self.config.voc_block > 0 and self.config.voc_pad > 0
+        ), "Please set correct voc_block and voc_pad, they should be greater than 0."
+
+        assert (
+            self.config.voc_sample_rate == self.config.am_sample_rate
+        ), "The sample rates of the AM and the vocoder are different, please check the models."
+
+        self.executor = TTSServerExecutor(
+            self.config.am_block, self.config.am_pad, self.config.voc_block,
+            self.config.voc_pad, self.config.voc_upsample)
+
+        if "cpu" in self.config.am_sess_conf.device or "cpu" in self.config.voc_sess_conf.device:
+            paddle.set_device("cpu")
+        else:
+            paddle.set_device(self.config.am_sess_conf.device)
+
+        try:
+            self.executor._init_from_path(
+                am=self.config.am,
+                am_ckpt=self.config.am_ckpt,
+                am_stat=self.config.am_stat,
+                phones_dict=self.config.phones_dict,
+                tones_dict=self.config.tones_dict,
+                speaker_dict=self.config.speaker_dict,
+                am_sample_rate=self.config.am_sample_rate,
+                am_sess_conf=self.config.am_sess_conf,
+                voc=self.config.voc,
+                voc_ckpt=self.config.voc_ckpt,
+                voc_sample_rate=self.config.voc_sample_rate,
+                voc_sess_conf=self.config.voc_sess_conf,
+                lang=self.config.lang)
+
+        except Exception as e:
+            logger.error("Failed to get model related files.")
+            logger.error("Failed to initialize TTS server engine on device: %s."
+                         % (self.config.voc_sess_conf.device))
+            return False
+
+        logger.info("Initialize TTS server engine successfully on device: %s." %
+                    (self.config.voc_sess_conf.device))
+
+        # warm up
+        try:
+            self.warm_up()
+        except Exception as e:
+            logger.error("Failed to warm up on tts engine.")
+            return False
+
+        return True
+
+    def warm_up(self):
+        """warm up
+        """
+        if self.config.lang == 'zh':
+            sentence = "您好,欢迎使用语音合成服务。"
+        elif self.config.lang == 'en':
+            sentence = "Hello and welcome to the speech synthesis service."
+        logger.info(
+            "*******************************warm up ********************************"
+        )
+        for i in range(3):
+            for wav in self.executor.infer(
+                    text=sentence,
+                    lang=self.config.lang,
+                    am=self.config.am,
+                    spk_id=0, ):
+                logger.info(
+                    f"The first response time of warm-up {i}: {self.executor.first_response_time} s"
+                )
+                break
+        logger.info(
+            "**********************************************************************"
+        )
+
+    def preprocess(self, text_base64: str=None, text_bytes: bytes=None):
+        # Convert bytes to text
+        if text_base64:
+            text_bytes = base64.b64decode(text_base64)  # base64 to bytes
+        text = text_bytes.decode('utf-8')  # bytes to text
+
+        return text
+
+    def run(self,
+            sentence: str,
+            spk_id: int=0,
+            speed: float=1.0,
+            volume: float=1.0,
+            sample_rate: int=0,
+            save_path: str=None):
+        """ run include inference and postprocess.
+
+        Args:
+            sentence (str): text to be synthesized
+            spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
+            speed (float, optional): speed. Defaults to 1.0.
+            volume (float, optional): volume. Defaults to 1.0.
+            sample_rate (int, optional): target sample rate for synthesized audio,
+                0 means the same as the model sampling rate. Defaults to 0.
+            save_path (str, optional): The save path of the synthesized audio.
+                None means do not save audio. Defaults to None.
+
+        Returns:
+            wav_base64: The base64 format of the synthesized audio.
+        """
+        wav_list = []
+
+        for wav in self.executor.infer(
+                text=sentence,
+                lang=self.config.lang,
+                am=self.config.am,
+                spk_id=spk_id, ):
+
+            # wav type: float32, convert to pcm (base64)
+            wav = float2pcm(wav)  # float32 to int16
+            wav_bytes = wav.tobytes()  # to bytes
+            wav_base64 = base64.b64encode(wav_bytes).decode('utf8')  # to base64
+            wav_list.append(wav)
+
+            yield wav_base64
+
+        wav_all = np.concatenate(wav_list, axis=0)
+        duration = len(wav_all) / self.config.voc_sample_rate
+        logger.info(f"sentence: {sentence}")
+        logger.info(f"The duration of the audio is: {duration} s")
+        logger.info(
+            f"first response time: {self.executor.first_response_time} s")
+        logger.info(
+            f"final response time: {self.executor.final_response_time} s")
+        logger.info(f"RTF: {self.executor.final_response_time / duration}")
+        logger.info(
+            f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s."
+        )
diff --git a/paddlespeech/server/engine/tts/online/python/__init__.py b/paddlespeech/server/engine/tts/online/python/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/server/engine/tts/online/python/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
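The block/pad scheme used by the engine above (chunks produced with `get_chunks`, context trimmed with `depadding`) can be sanity-checked in isolation. The sketch below makes the round trip explicit; `chunk_with_pad` is an assumed stand-in for `paddlespeech.server.utils.util.get_chunks`, which is imported but not shown in this diff, while `depadding` follows the same logic as the method above. With `upsample=1` (the acoustic-model case, no vocoder upsampling) the depadded chunks concatenate back to the input exactly:

import numpy as np


def chunk_with_pad(data, block, pad):
    # Assumed equivalent of get_chunks: split along axis 0 into `block`-sized
    # chunks, each carrying up to `pad` frames of context on either side.
    n = data.shape[0]
    chunk_num = int(np.ceil(n / block))
    return [
        data[max(0, i * block - pad):min((i + 1) * block + pad, n)]
        for i in range(chunk_num)
    ]


def depadding(data, chunk_num, chunk_id, block, pad, upsample):
    # Same logic as TTSServerExecutor.depadding: drop the padded context.
    front_pad = min(chunk_id * block, pad)
    if chunk_id == 0:  # first chunk
        return data[:block * upsample]
    elif chunk_id == chunk_num - 1:  # last chunk
        return data[front_pad * upsample:]
    return data[front_pad * upsample:(front_pad + block) * upsample]


mel = np.random.randn(100, 80).astype(np.float32)
chunks = chunk_with_pad(mel, block=14, pad=14)  # voc_block / voc_pad defaults
out = [
    depadding(c, len(chunks), i, 14, 14, upsample=1)
    for i, c in enumerate(chunks)
]
assert np.array_equal(np.concatenate(out, axis=0), mel)  # lossless round trip

In the vocoder case, `upsample` is the hop size: with voc_block=14, voc_pad=14 and voc_upsample=300 from the config, each depadded vocoder chunk contributes 14 * 300 = 4200 samples, i.e. 0.175 s of audio at 24 kHz.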
diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/python/tts_engine.py
similarity index 97%
rename from paddlespeech/server/engine/tts/online/tts_engine.py
rename to paddlespeech/server/engine/tts/online/python/tts_engine.py
index c9135b884..6981a4132 100644
--- a/paddlespeech/server/engine/tts/online/tts_engine.py
+++ b/paddlespeech/server/engine/tts/online/python/tts_engine.py
@@ -302,23 +302,6 @@ class TTSServerExecutor(TTSExecutor):
         self.voc_inference.eval()
         print("voc done!")
 
-    def get_phone(self, sentence, lang, merge_sentences, get_tone_ids):
-        tone_ids = None
-        if lang == 'zh':
-            input_ids = self.frontend.get_input_ids(
-                sentence,
-                merge_sentences=merge_sentences,
-                get_tone_ids=get_tone_ids)
-            phone_ids = input_ids["phone_ids"]
-            if get_tone_ids:
-                tone_ids = input_ids["tone_ids"]
-        elif lang == 'en':
-            input_ids = self.frontend.get_input_ids(
-                sentence, merge_sentences=merge_sentences)
-            phone_ids = input_ids["phone_ids"]
-        else:
-            print("lang should in {'zh', 'en'}!")
-
     def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
         """
         Streaming inference removes the result of pad inference
diff --git a/paddlespeech/server/utils/onnx_infer.py b/paddlespeech/server/utils/onnx_infer.py
new file mode 100644
index 000000000..ac11c534b
--- /dev/null
+++ b/paddlespeech/server/utils/onnx_infer.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Optional
+
+import onnxruntime as ort
+
+
+def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
+    sess_options = ort.SessionOptions()
+    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+
+    if "gpu" in sess_conf["device"]:
+        # fastspeech2/mb_melgan can't use trt now!
+        if sess_conf["use_trt"]:
+            providers = ['TensorrtExecutionProvider']
+        else:
+            providers = ['CUDAExecutionProvider']
+    else:
+        # fall back to CPU so that `providers` is always defined
+        providers = ['CPUExecutionProvider']
+        sess_options.intra_op_num_threads = sess_conf["cpu_threads"]
+    sess = ort.InferenceSession(
+        model_path, providers=providers, sess_options=sess_options)
+    return sess
diff --git a/paddlespeech/server/ws/tts_socket.py b/paddlespeech/server/ws/tts_socket.py
index 11458b3cf..699ee412b 100644
--- a/paddlespeech/server/ws/tts_socket.py
+++ b/paddlespeech/server/ws/tts_socket.py
@@ -51,7 +51,6 @@ async def websocket_endpoint(websocket: WebSocket):
                 tts_results = next(wav_generator)
                 resp = {"status": 1, "audio": tts_results}
                 await websocket.send_json(resp)
-                logger.info("streaming audio...")
             except StopIteration as e:
                 resp = {"status": 2, "audio": ''}
                 await websocket.send_json(resp)
diff --git a/setup.py b/setup.py
index 82ff63412..8053c9b2f 100644
--- a/setup.py
+++ b/setup.py
@@ -42,6 +42,7 @@ base = [
     "loguru",
     "matplotlib",
     "nara_wpe",
+    "onnxruntime",
     "pandas",
     "paddleaudio",
     "paddlenlp",
@@ -64,12 +65,16 @@ base = [
     "webrtcvad",
     "yacs~=0.1.8",
     "prettytable",
+    "zhon",
 ]
 
 server = [
     "fastapi",
     "uvicorn",
     "pattern_singleton",
+    "websockets",
+    "websocket",
+    "websocket-client",
 ]
 
 requirements = {
@@ -90,7 +95,6 @@ requirements = {
         "unidecode",
         "yq",
         "pre-commit",
-        "zhon",
     ]
 }
diff --git a/tests/unit/server/change_yaml.py b/tests/unit/server/offline/change_yaml.py
similarity index 100%
rename from tests/unit/server/change_yaml.py
rename to tests/unit/server/offline/change_yaml.py
diff --git a/tests/unit/server/conf/application.yaml b/tests/unit/server/offline/conf/application.yaml
similarity index 100%
rename from tests/unit/server/conf/application.yaml
rename to tests/unit/server/offline/conf/application.yaml
diff --git a/tests/unit/server/test_server_client.sh b/tests/unit/server/offline/test_server_client.sh
similarity index 100%
rename from tests/unit/server/test_server_client.sh
rename to tests/unit/server/offline/test_server_client.sh
diff --git a/tests/unit/server/online/tts/change_yaml.py b/tests/unit/server/online/tts/change_yaml.py
new file mode 100644
index 000000000..98fd13722
--- /dev/null
+++ b/tests/unit/server/online/tts/change_yaml.py
@@ -0,0 +1,140 @@
+#!/usr/bin/python
+import argparse
+import os
+
+import yaml
+
+
+def change_value(args):
+    yamlfile = args.config_file
+    change_type = args.change_type
+    engine_type = args.engine_type
+    target_key = args.target_key
+    target_value = args.target_value
+
+    tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
+    os.system("cp %s %s" % (yamlfile, tmp_yamlfile))
+
+    with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw:
+        y = yaml.safe_load(f)
+
+        if change_type == "model":
+            if engine_type == "tts_online-onnx":
+                target_value = target_value + "_onnx"
+            y[engine_type][target_key] = target_value
+        elif change_type == "protocol":
+            assert (target_key == "protocol" and (
+                target_value == "http" or target_value == "websocket"
+            )), "if change_type is protocol, target_key must be set protocol."
+            y[target_key] = target_value
+        elif change_type == "engine_type":
+            assert (
+                target_key == "engine_list"
+            ), "if change_type is engine_type, target_key must be set engine_list."
+            y[target_key] = [target_value]
+        elif change_type == "device":
+            assert (
+                target_key == "device"
+            ), "if change_type is device, target_key must be set device."
+            if y["engine_list"][0] == "tts_online":
+                y["tts_online"]["device"] = target_value
+            elif y["engine_list"][0] == "tts_online-onnx":
+                y["tts_online-onnx"]["am_sess_conf"]["device"] = target_value
+                y["tts_online-onnx"]["voc_sess_conf"]["device"] = target_value
+            else:
+                print(
+                    "Invalid engine_list, please set tts_online or tts_online-onnx"
+                )
+
+        else:
+            print("Invalid change_type, please set a correct change_type.")
+
+        print(yaml.dump(y, default_flow_style=False, sort_keys=False))
+        yaml.dump(y, fw, allow_unicode=True)
+    os.system("rm %s" % (tmp_yamlfile))
+    print(f"Change key: {target_key} to value: {target_value} successfully.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--config_file',
+        type=str,
+        default='./conf/application.yaml',
+        help='server yaml file.')
+    parser.add_argument(
+        '--change_type',
+        type=str,
+        default="model",
+        choices=["model", "protocol", "engine_type", "device"],
+        help='change type', )
+    parser.add_argument(
+        '--engine_type',
+        type=str,
+        default="tts_online",
+        help='engine type',
+        choices=["tts_online", "tts_online-onnx"])
+    parser.add_argument(
+        '--target_key',
+        type=str,
+        default=None,
+        help='Change key',
+        required=True)
+    parser.add_argument(
+        '--target_value',
+        type=str,
+        default=None,
+        help='target value',
+        required=True)
+
+    args = parser.parse_args()
+
+    change_value(args)
diff --git a/tests/unit/server/online/tts/conf/application.yaml b/tests/unit/server/online/tts/conf/application.yaml
new file mode 100644
index 000000000..347411b66
--- /dev/null
+++ b/tests/unit/server/online/tts/conf/application.yaml
@@ -0,0 +1,88 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+#                             SERVER SETTING                                    #
+#################################################################################
+host: 127.0.0.1
+port: 8092
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['tts_online', 'tts_online-onnx']
+# protocol = ['websocket', 'http'] (only one can be selected).
+protocol: 'http'
+engine_list: ['tts_online']
+
+
+#################################################################################
+#                               ENGINE CONFIG                                   #
+#################################################################################
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: online #######################
+tts_online:
+    # am (acoustic model) choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc']
+    am: 'fastspeech2_cnndecoder_csmsc'
+    am_config:
+    am_ckpt:
+    am_stat:
+    phones_dict:
+    tones_dict:
+    speaker_dict:
+    spk_id: 0
+
+    # voc (vocoder) choices=['mb_melgan_csmsc', 'hifigan_csmsc']
+    voc: 'mb_melgan_csmsc'
+    voc_config:
+    voc_ckpt:
+    voc_stat:
+
+    # others
+    lang: 'zh'
+    device: 'cpu' # set 'gpu:id' or 'cpu'
+    am_block: 42
+    am_pad: 12
+    voc_block: 14
+    voc_pad: 14
+
+
+
+#################################################################################
+#                               ENGINE CONFIG                                   #
+#################################################################################
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: online-onnx #######################
+tts_online-onnx:
+    # am (acoustic model) choices=['fastspeech2_csmsc_onnx', 'fastspeech2_cnndecoder_csmsc_onnx']
+    am: 'fastspeech2_cnndecoder_csmsc_onnx'
+    # am_ckpt is a list, if am is fastspeech2_cnndecoder_csmsc_onnx, am_ckpt = [encoder model, decoder model, postnet model];
+    # if am is fastspeech2_csmsc_onnx, am_ckpt = [ckpt model];
+    am_ckpt: # list
+    am_stat:
+    phones_dict:
+    tones_dict:
+    speaker_dict:
+    spk_id: 0
+    am_sample_rate: 24000
+    am_sess_conf:
+        device: "cpu" # set 'gpu:id' or 'cpu'
+        use_trt: False
+        cpu_threads: 1
+
+    # voc (vocoder) choices=['mb_melgan_csmsc_onnx', 'hifigan_csmsc_onnx']
+    voc: 'mb_melgan_csmsc_onnx'
+    voc_ckpt:
+    voc_sample_rate: 24000
+    voc_sess_conf:
+        device: "cpu" # set 'gpu:id' or 'cpu'
+        use_trt: False
+        cpu_threads: 1
+
+    # others
+    lang: 'zh'
+    am_block: 42
+    am_pad: 12
+    voc_block: 14
+    voc_pad: 14
+    voc_upsample: 300
+
diff --git a/tests/unit/server/online/tts/http_client.py b/tests/unit/server/online/tts/http_client.py
new file mode 100644
index 000000000..cbc1f5c02
--- /dev/null
+++ b/tests/unit/server/online/tts/http_client.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import base64
+import json
+import os
+import time
+
+import requests
+
+from paddlespeech.server.utils.audio_process import pcm2wav
+
+
+def save_audio(buffer, audio_path) -> bool:
+    if audio_path.endswith("pcm"):
+        with open(audio_path, "wb") as f:
+            f.write(buffer)
+    elif audio_path.endswith("wav"):
+        with open("./tmp.pcm", "wb") as f:
+            f.write(buffer)
+        pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000)
+        os.system("rm ./tmp.pcm")
+    else:
+        print("Only pcm and wav formats are supported for saving audio.")
+        return False
+
+    return True
+
+
+def test(args):
+    params = {
+        "text": args.text,
+        "spk_id": args.spk_id,
+        "speed": args.speed,
+        "volume": args.volume,
+        "sample_rate": args.sample_rate,
+        "save_path": ''
+    }
+
+    buffer = b''
+    flag = 1
+    url = "http://" + str(args.server) + ":" + str(
+        args.port) + "/paddlespeech/streaming/tts"
+    st = time.time()
+    res = requests.post(url, json.dumps(params), stream=True)
+    for chunk in res.iter_content(chunk_size=1024):
+        chunk = base64.b64decode(chunk)  # bytes
+        if flag:
+            first_response = time.time() - st
+            print(f"First response: {first_response} s")
+            flag = 0
+        buffer += chunk
+
+    final_response = time.time() - st
+    duration = len(buffer) / 2.0 / 24000
+
+    print(f"Final response: {final_response} s")
+    print(f"Audio duration: {duration} s")
+    print(f"RTF: {final_response / duration}")
+
+    if args.save_path is not None:
+        if save_audio(buffer, args.save_path):
+            print("Audio saved to:", args.save_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--text',
+        type=str,
+        default="您好,欢迎使用语音合成服务。",
+        help='A sentence to be synthesized')
+    parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
+    parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
+    parser.add_argument(
+        '--volume', type=float, default=1.0, help='Audio volume')
+    parser.add_argument(
+        '--sample_rate',
+        type=int,
+        default=0,
+        help='Sampling rate, the default is the same as the model')
+    parser.add_argument(
+        "--server", type=str, help="server ip", default="127.0.0.1")
+    parser.add_argument("--port", type=int, help="server port", default=8092)
+    parser.add_argument(
+        "--save_path", type=str, help="save audio path", default=None)
+
+    args = parser.parse_args()
+    test(args)
diff --git a/tests/unit/server/online/tts/test.sh b/tests/unit/server/online/tts/test.sh
new file mode 100644
index 000000000..54e274f1f
--- /dev/null
+++ b/tests/unit/server/online/tts/test.sh
@@ -0,0 +1,315 @@
+#!/bin/bash
+# bash test.sh
+
+StartService(){
+    # Start service
+    paddlespeech_server start --config_file $config_file 1>>$log/server.log 2>>$log/server.log.wf &
+    echo $! > pid
+
+    start_num=$(cat $log/server.log.wf | grep "INFO: Uvicorn running on http://" -c)
+    flag="normal"
+    while [[ $start_num -lt $target_start_num && $flag == "normal" ]]
+    do
+        start_num=$(cat $log/server.log.wf | grep "INFO: Uvicorn running on http://" -c)
+        # start service failed
+        if [ $(cat $log/server.log.wf | grep -i "Failed to warm up on tts engine." -c) -gt $error_time ];then
+            echo "Service failed to start." | tee -a $log/test_result.log
+            error_time=$(cat $log/server.log.wf | grep -i "Failed to warm up on tts engine." -c)
+            flag="abnormal"
+
+        elif [ $(cat $log/server.log.wf | grep -i "AssertionError" -c) -gt $error_time ];then
+            echo "Service failed to start." | tee -a $log/test_result.log
+            error_time=$(cat $log/server.log.wf | grep -i "AssertionError" -c)
+            flag="abnormal"
+        fi
+    done
+}
+
+ClientTest_http(){
+    for ((i=1; i<=3;i++))
+    do
+        python http_client.py --save_path ./out_http.wav
+        ((http_test_times+=1))
+    done
+}
+
+ClientTest_ws(){
+    for ((i=1; i<=3;i++))
+    do
+        python ws_client.py
+        ((ws_test_times+=1))
+    done
+}
+
+GetTestResult_http() {
+    # Determine if the test was successful
+    http_response_success_time=$(cat $log/server.log | grep "200 OK" -c)
+    if (( $http_response_success_time == $http_test_times )) ; then
+        echo "Testing successfully. $info" | tee -a $log/test_result.log
+    else
+        echo "Testing failed. $info" | tee -a $log/test_result.log
+    fi
+    http_test_times=$http_response_success_time
+}
+
+GetTestResult_ws() {
+    # Determine if the test was successful
+    ws_response_success_time=$(cat $log/server.log.wf | grep "Complete the transmission of audio streams" -c)
+    if (( $ws_response_success_time == $ws_test_times )) ; then
+        echo "Testing successfully. $info" | tee -a $log/test_result.log
+    else
+        echo "Testing failed. $info" | tee -a $log/test_result.log
+    fi
+    ws_test_times=$ws_response_success_time
+}
+
+
+engine_type=$1
+log=$2
+mkdir -p $log
+rm -rf $log/server.log.wf
+rm -rf $log/server.log
+rm -rf $log/test_result.log
+
+config_file=./conf/application.yaml
+server_ip=$(cat $config_file | grep "host" | awk -F " " '{print $2}')
+port=$(cat $config_file | grep "port" | awk '/port:/ {print $2}')
+
+echo "Service ip: $server_ip" | tee $log/test_result.log
+echo "Service port: $port" | tee -a $log/test_result.log
+
+# whether a process is listening on $port
+pid=`lsof -i :"$port"|grep -v "PID" | awk '{print $2}'`
+if [ "$pid" != "" ]; then
+    echo "The port: $port is occupied, please change another port"
+    exit
+fi
+
+
+
+target_start_num=0  # the number of service starts
+test_times=0  # the number of client tests
+error_time=0  # the number of errors found in server.log.wf when startup fails
+
+# start server: engine: tts_online, protocol: http, am: fastspeech2_cnndecoder_csmsc, voc: mb_melgan_csmsc
+info="start server: engine: $engine_type, protocol: http, am: fastspeech2_cnndecoder_csmsc, voc: mb_melgan_csmsc."
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_http
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_http
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+
+
+python change_yaml.py --engine_type $engine_type --target_key voc --target_value hifigan_csmsc    # change voc: mb_melgan_csmsc -> hifigan_csmsc
+# start server: engine: tts_online, protocol: http, am: fastspeech2_cnndecoder_csmsc, voc: hifigan_csmsc
+info="start server: engine: $engine_type, protocol: http, am: fastspeech2_cnndecoder_csmsc, voc: hifigan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_http
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_http
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+
+python change_yaml.py --engine_type $engine_type --target_key am --target_value fastspeech2_csmsc    # change am: fastspeech2_cnndecoder_csmsc -> fastspeech2_csmsc
+# start server: engine: tts_online, protocol: http, am: fastspeech2_csmsc, voc: hifigan_csmsc
+info="start server: engine: $engine_type, protocol: http, am: fastspeech2_csmsc, voc: hifigan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_http
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_http
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+python change_yaml.py --engine_type $engine_type --target_key voc --target_value mb_melgan_csmsc    # change voc: hifigan_csmsc -> mb_melgan_csmsc
+# start server: engine: tts_online, protocol: http, am: fastspeech2_csmsc, voc: mb_melgan_csmsc
+info="start server: engine: $engine_type, protocol: http, am: fastspeech2_csmsc, voc: mb_melgan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_http
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_http
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+echo "********************************************* websocket **********************************************************"
+
+python change_yaml.py --engine_type $engine_type --change_type protocol --target_key protocol --target_value websocket
+# start server: engine: tts_online, protocol: websocket, am: fastspeech2_csmsc, voc: mb_melgan_csmsc
+info="start server: engine: $engine_type, protocol: websocket, am: fastspeech2_csmsc, voc: mb_melgan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_ws
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_ws
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+python change_yaml.py --engine_type $engine_type --target_key voc --target_value hifigan_csmsc    # change voc: mb_melgan_csmsc -> hifigan_csmsc
+# start server: engine: tts_online, protocol: websocket, am: fastspeech2_csmsc, voc: hifigan_csmsc
+info="start server: engine: $engine_type, protocol: websocket, am: fastspeech2_csmsc, voc: hifigan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_ws
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_ws
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+python change_yaml.py --engine_type $engine_type --target_key am --target_value fastspeech2_cnndecoder_csmsc    # change am: fastspeech2_csmsc -> fastspeech2_cnndecoder_csmsc
+# start server: engine: tts_online, protocol: websocket, am: fastspeech2_cnndecoder_csmsc, voc: hifigan_csmsc
+info="start server: engine: $engine_type, protocol: websocket, am: fastspeech2_cnndecoder_csmsc, voc: hifigan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_ws
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_ws
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+
+python change_yaml.py --engine_type $engine_type --target_key voc --target_value mb_melgan_csmsc    # change voc: hifigan_csmsc -> mb_melgan_csmsc
+# start server: engine: tts_online, protocol: websocket, am: fastspeech2_cnndecoder_csmsc, voc: mb_melgan_csmsc
+info="start server: engine: $engine_type, protocol: websocket, am: fastspeech2_cnndecoder_csmsc, voc: mb_melgan_csmsc."
+
+echo "$info" | tee -a $log/test_result.log
+((target_start_num+=1))
+StartService
+
+if [[ $start_num -eq $target_start_num && $flag == "normal" ]]; then
+    echo "Service started successfully." | tee -a $log/test_result.log
+    ClientTest_ws
+    echo "This round of testing is over." | tee -a $log/test_result.log
+
+    GetTestResult_ws
+else
+    echo "Service failed to start, no client test."
+    target_start_num=$start_num
+
+fi
+
+kill -9 `cat pid`
+rm -rf pid
+sleep 2s
+echo "**************************************************************************************" | tee -a $log/test_result.log
+
+
+
+echo "All tests completed." | tee -a $log/test_result.log
+
+
+# show all the test results
+echo "***************** Here are all the test results ********************"
+cat $log/test_result.log
+
+# Restore conf to be the same as demos/speech_server
+cp ./tts_online_application.yaml ./conf/application.yaml -rf
+sleep 2s
\ No newline at end of file
diff --git a/tests/unit/server/online/tts/test_all.sh b/tests/unit/server/online/tts/test_all.sh
new file mode 100644
index 000000000..8e490255d
--- /dev/null
+++ b/tests/unit/server/online/tts/test_all.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# bash test_all.sh
+
+log_all_dir=./log
+
+bash test.sh tts_online $log_all_dir/log_tts_online_cpu
+
+python change_yaml.py --change_type engine_type --target_key engine_list --target_value tts_online-onnx
+bash test.sh tts_online-onnx $log_all_dir/log_tts_online-onnx_cpu
+
+
+python change_yaml.py --change_type device --target_key device --target_value gpu:3
+bash test.sh tts_online $log_all_dir/log_tts_online_gpu
+
+python change_yaml.py --change_type engine_type --target_key engine_list --target_value tts_online-onnx
+python change_yaml.py --change_type device --target_key device --target_value gpu:3
+bash test.sh tts_online-onnx $log_all_dir/log_tts_online-onnx_gpu
+
+echo "************************************** show all test results ****************************************"
+cat $log_all_dir/log_tts_online_cpu/test_result.log
+cat $log_all_dir/log_tts_online-onnx_cpu/test_result.log
+cat $log_all_dir/log_tts_online_gpu/test_result.log
+cat $log_all_dir/log_tts_online-onnx_gpu/test_result.log
diff --git a/tests/unit/server/online/tts/tts_online_application.yaml b/tests/unit/server/online/tts/tts_online_application.yaml
new file mode 100644
index 000000000..347411b66
--- /dev/null
+++ b/tests/unit/server/online/tts/tts_online_application.yaml
@@ -0,0 +1,88 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+#                             SERVER SETTING                                    #
+#################################################################################
+host: 127.0.0.1
+port: 8092
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['tts_online', 'tts_online-onnx']
+# protocol = ['websocket', 'http'] (only one can be selected).
+protocol: 'http'
+engine_list: ['tts_online']
+
+
+#################################################################################
+#                               ENGINE CONFIG                                   #
+#################################################################################
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: online #######################
+tts_online:
+    # am (acoustic model) choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc']
+    am: 'fastspeech2_cnndecoder_csmsc'
+    am_config:
+    am_ckpt:
+    am_stat:
+    phones_dict:
+    tones_dict:
+    speaker_dict:
+    spk_id: 0
+
+    # voc (vocoder) choices=['mb_melgan_csmsc', 'hifigan_csmsc']
+    voc: 'mb_melgan_csmsc'
+    voc_config:
+    voc_ckpt:
+    voc_stat:
+
+    # others
+    lang: 'zh'
+    device: 'cpu' # set 'gpu:id' or 'cpu'
+    am_block: 42
+    am_pad: 12
+    voc_block: 14
+    voc_pad: 14
+
+
+
+#################################################################################
+#                               ENGINE CONFIG                                   #
+#################################################################################
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: online-onnx #######################
+tts_online-onnx:
+    # am (acoustic model) choices=['fastspeech2_csmsc_onnx', 'fastspeech2_cnndecoder_csmsc_onnx']
+    am: 'fastspeech2_cnndecoder_csmsc_onnx'
+    # am_ckpt is a list, if am is fastspeech2_cnndecoder_csmsc_onnx, am_ckpt = [encoder model, decoder model, postnet model];
+    # if am is fastspeech2_csmsc_onnx, am_ckpt = [ckpt model];
+    am_ckpt: # list
+    am_stat:
+    phones_dict:
+    tones_dict:
+    speaker_dict:
+    spk_id: 0
+    am_sample_rate: 24000
+    am_sess_conf:
+        device: "cpu" # set 'gpu:id' or 'cpu'
+        use_trt: False
+        cpu_threads: 1
+
+    # voc (vocoder) choices=['mb_melgan_csmsc_onnx', 'hifigan_csmsc_onnx']
+    voc: 'mb_melgan_csmsc_onnx'
+    voc_ckpt:
+    voc_sample_rate: 24000
+    voc_sess_conf:
+        device: "cpu" # set 'gpu:id' or 'cpu'
+        use_trt: False
+        cpu_threads: 1
+
+    # others
+    lang: 'zh'
+    am_block: 42
+    am_pad: 12
+    voc_block: 14
+    voc_pad: 14
+    voc_upsample: 300
+
diff --git a/tests/unit/server/online/tts/ws_client.py b/tests/unit/server/online/tts/ws_client.py
new file mode 100644
index 000000000..eef010cf2
--- /dev/null
+++ b/tests/unit/server/online/tts/ws_client.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import _thread as thread
+import argparse
+import base64
+import json
+import ssl
+import time
+
+import websocket
+
+flag = 1
+st = 0.0
+all_bytes = b''
+
+
+class WsParam(object):
+    # initialization
+    def __init__(self, text, server="127.0.0.1", port=8090):
+        self.server = server
+        self.port = port
+        self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
+        self.text = text
+
+    # generate the url
+    def create_url(self):
+        return self.url
+
+
+def on_message(ws, message):
+    global flag
+    global st
+    global all_bytes
+
+    try:
+        message = json.loads(message)
+        audio = message["audio"]
+        audio = base64.b64decode(audio)  # bytes
+        status = message["status"]
+        all_bytes += audio
+
+        if status == 0:
+            print("created successfully.")
+        elif status == 1:
+            if flag:
+                print(f"First response: {time.time() - st} s")
+                flag = 0
+        elif status == 2:
+            final_response = time.time() - st
+            duration = len(all_bytes) / 2.0 / 24000
+            print(f"Final response: {final_response} s")
+            print(f"Audio duration: {duration} s")
+            print(f"RTF: {final_response / duration}")
+            with open("./out.pcm", "wb") as f:
+                f.write(all_bytes)
+            print("ws is closed")
+            ws.close()
+        else:
+            print("infer error")
+
+    except Exception as e:
+        print("Received a message, but failed to parse it:", e)
+
+
+# handle a websocket error
+def on_error(ws, error):
+    print("### error:", error)
+
+
+# handle the websocket connection closing
+def on_close(ws):
+    print("### closed ###")
+
+
+# handle the websocket connection being established
+def on_open(ws):
+    def run(*args):
+        global st
+        text_base64 = str(
+            base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
+        d = {"text": text_base64}
+        d = json.dumps(d)
+        print("Start sending text data")
+        st = time.time()
+        ws.send(d)
+
+    thread.start_new_thread(run, ())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--text",
+        type=str,
+        help="A sentence to be synthesized",
+        default="您好,欢迎使用语音合成服务。")
+    parser.add_argument(
+        "--server", type=str, help="server ip", default="127.0.0.1")
+    parser.add_argument("--port", type=int, help="server port", default=8092)
+    args = parser.parse_args()
+
+    print("***************************************")
+    print("Server ip: ", args.server)
+    print("Server port: ", args.port)
+    print("Sentence to be synthesized: ", args.text)
+    print("***************************************")
+
+    wsParam = WsParam(text=args.text, server=args.server, port=args.port)
+
+    websocket.enableTrace(False)
+    wsUrl = wsParam.create_url()
+    ws = websocket.WebSocketApp(
+        wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
+    ws.on_open = on_open
+    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
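Both test clients accumulate the stream as raw 16-bit mono PCM at 24 kHz, which is why they compute duration as len(bytes) / 2.0 / 24000. As a small companion sketch, the saved ./out.pcm from ws_client.py can be turned into a playable wav with only the standard library; the helper name and paths below are illustrative, and the http client's pcm2wav presumably does the equivalent:

import wave


def pcm_to_wav(pcm_path="./out.pcm", wav_path="./out.wav",
               channels=1, sample_width=2, sample_rate=24000):
    # sample_width=2 bytes == 16-bit samples, matching float2pcm on the server
    with open(pcm_path, "rb") as f:
        pcm = f.read()
    with wave.open(wav_path, "wb") as w:
        w.setnchannels(channels)
        w.setsampwidth(sample_width)
        w.setframerate(sample_rate)
        w.writeframes(pcm)


if __name__ == "__main__":
    pcm_to_wav()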