# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import math
import os
import time
from typing import Optional

import numpy as np
import paddle

from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.onnx_infer import get_sess
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend

__all__ = ['TTSEngine']

# support online model
pretrained_models = {
    # fastspeech2
    "fastspeech2_csmsc_onnx-zh": {
        'url':
        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
        'md5': 'fd3ad38d83273ad51f0ea4f4abf3ab4e',
        'ckpt': ['fastspeech2_csmsc.onnx'],
        'phones_dict': 'phone_id_map.txt',
        'sample_rate': 24000,
    },
    "fastspeech2_cnndecoder_csmsc_onnx-zh": {
        'url':
        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip',
        'md5': '5f70e1a6bcd29d72d54e7931aa86f266',
        'ckpt': [
            'fastspeech2_csmsc_am_encoder_infer.onnx',
            'fastspeech2_csmsc_am_decoder.onnx',
            'fastspeech2_csmsc_am_postnet.onnx',
        ],
        'speech_stats': 'speech_stats.npy',
        'phones_dict': 'phone_id_map.txt',
        'sample_rate': 24000,
    },
    # mb_melgan
    "mb_melgan_csmsc_onnx-zh": {
        'url':
        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip',
        'md5': '5b83ec746e8414bc29032d954ffd07ec',
        'ckpt': 'mb_melgan_csmsc.onnx',
        'sample_rate': 24000,
    },
    # hifigan
    "hifigan_csmsc_onnx-zh": {
        'url':
        'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip',
        'md5': '1a7dc0385875889e46952e50c0994a6b',
        'ckpt': 'hifigan_csmsc.onnx',
        'sample_rate': 24000,
    },
}

model_alias = {
    # acoustic model
    "fastspeech2": "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
    "fastspeech2_inference":
    "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",

    # voc
    "mb_melgan": "paddlespeech.t2s.models.melgan:MelGANGenerator",
    "mb_melgan_inference": "paddlespeech.t2s.models.melgan:MelGANInference",
    "hifigan": "paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
    "hifigan_inference": "paddlespeech.t2s.models.hifigan:HiFiGANInference",
}


class TTSServerExecutor(TTSExecutor):
    """TTS executor that runs ONNX acoustic-model (AM) and vocoder sessions
    and yields the synthesized waveform chunk by chunk."""

    def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample):
        """
        Args:
            am_block: frames kept per streaming AM chunk.
            am_pad: context frames around each AM chunk (removed by `depadding`).
            voc_block: frames kept per streaming vocoder chunk.
            voc_pad: context frames around each vocoder chunk.
            voc_upsample: output samples produced by the vocoder per mel frame.
        """
        super().__init__()
        self.am_block = am_block
        self.am_pad = am_pad
        self.voc_block = voc_block
        self.voc_pad = voc_pad
        self.voc_upsample = voc_upsample

        self.pretrained_models = pretrained_models
        self.model_alias = model_alias

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download (if needed) the pretrained resources for `tag` and return the
        absolute path of the decompressed directory.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
            tag, '\n\t\t'.join(support_models))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))
        return decompressed_path

    def _init_from_path(
            self,
            am: str='fastspeech2_csmsc_onnx',
            am_ckpt: Optional[list]=None,
            am_stat: Optional[os.PathLike]=None,
            phones_dict: Optional[os.PathLike]=None,
            tones_dict: Optional[os.PathLike]=None,
            speaker_dict: Optional[os.PathLike]=None,
            am_sample_rate: int=24000,
            am_sess_conf: dict=None,
            voc: str='mb_melgan_csmsc_onnx',
            voc_ckpt: Optional[os.PathLike]=None,
            voc_sample_rate: int=24000,
            voc_sess_conf: dict=None,
            lang: str='zh', ):
        """
        Init ONNX sessions, the phone vocabulary and the text frontend.

        When a checkpoint/dict argument is missing, the pretrained resources
        for `<am>-<lang>` / `<voc>-<lang>` are downloaded instead.

        Raises:
            ValueError: if `am` is not a supported streaming acoustic model.
        """
        am_initialized = hasattr(self, 'am_sess') or (
            hasattr(self, 'am_encoder_infer_sess') and
            hasattr(self, 'am_decoder_sess') and
            hasattr(self, 'am_postnet_sess'))
        # The vocoder session is stored as `voc_sess` below, so that is the
        # attribute to check (not the parent's `voc_inference`).
        if am_initialized and hasattr(self, 'voc_sess'):
            logger.info('Models had been initialized.')
            return

        # am
        am_tag = am + '-' + lang
        if am == "fastspeech2_csmsc_onnx":
            # get model info
            if am_ckpt is None or phones_dict is None:
                am_res_path = self._get_pretrained_path(am_tag)
                self.am_res_path = am_res_path
                self.am_ckpt = os.path.join(
                    am_res_path, pretrained_models[am_tag]['ckpt'][0])
                # must have phones_dict in acoustic
                self.phones_dict = os.path.join(
                    am_res_path, pretrained_models[am_tag]['phones_dict'])
            else:
                self.am_ckpt = os.path.abspath(am_ckpt[0])
                self.phones_dict = os.path.abspath(phones_dict)
                self.am_res_path = os.path.dirname(
                    os.path.abspath(self.am_ckpt))

            # create am sess
            self.am_sess = get_sess(self.am_ckpt, am_sess_conf)

        elif am == "fastspeech2_cnndecoder_csmsc_onnx":
            if am_ckpt is None or am_stat is None or phones_dict is None:
                am_res_path = self._get_pretrained_path(am_tag)
                self.am_res_path = am_res_path
                self.am_encoder_infer = os.path.join(
                    am_res_path, pretrained_models[am_tag]['ckpt'][0])
                self.am_decoder = os.path.join(
                    am_res_path, pretrained_models[am_tag]['ckpt'][1])
                self.am_postnet = os.path.join(
                    am_res_path, pretrained_models[am_tag]['ckpt'][2])
                # must have phones_dict in acoustic
                self.phones_dict = os.path.join(
                    am_res_path, pretrained_models[am_tag]['phones_dict'])
                self.am_stat = os.path.join(
                    am_res_path, pretrained_models[am_tag]['speech_stats'])
            else:
                self.am_encoder_infer = os.path.abspath(am_ckpt[0])
                self.am_decoder = os.path.abspath(am_ckpt[1])
                self.am_postnet = os.path.abspath(am_ckpt[2])
                self.phones_dict = os.path.abspath(phones_dict)
                self.am_stat = os.path.abspath(am_stat)
                # `self.am_ckpt` is never set on this path; use the encoder.
                self.am_res_path = os.path.dirname(
                    os.path.abspath(self.am_encoder_infer))

            # create am sess
            self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
                                                  am_sess_conf)
            self.am_decoder_sess = get_sess(self.am_decoder, am_sess_conf)
            self.am_postnet_sess = get_sess(self.am_postnet, am_sess_conf)

            # stats used by `denorm` in `infer`
            self.am_mu, self.am_std = np.load(self.am_stat)

        else:
            raise ValueError(
                f"Unsupported am: {am}. Only fastspeech2_csmsc_onnx and "
                f"fastspeech2_cnndecoder_csmsc_onnx are supported.")

        logger.info(f"self.phones_dict: {self.phones_dict}")
        logger.info(f"am model dir: {self.am_res_path}")
        logger.info("Create am sess successfully.")

        # voc model info
        voc_tag = voc + '-' + lang
        if voc_ckpt is None:
            voc_res_path = self._get_pretrained_path(voc_tag)
            self.voc_res_path = voc_res_path
            self.voc_ckpt = os.path.join(voc_res_path,
                                         pretrained_models[voc_tag]['ckpt'])
        else:
            self.voc_ckpt = os.path.abspath(voc_ckpt)
            self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
        logger.info(self.voc_res_path)

        # create voc sess
        self.voc_sess = get_sess(self.voc_ckpt, voc_sess_conf)
        logger.info("Create voc sess successfully.")

        with open(self.phones_dict, "r", encoding='utf-8') as f:
            phn_id = [line.strip().split() for line in f]
        self.vocab_size = len(phn_id)
        logger.info(f"vocab_size: {self.vocab_size}")

        # frontend
        # `infer` never requests tone ids (get_tone_ids is False), so the
        # `tones_dict` argument is not used here.
        self.tones_dict = None
        if lang == 'zh':
            self.frontend = Frontend(
                phone_vocab_path=self.phones_dict,
                tone_vocab_path=self.tones_dict)

        elif lang == 'en':
            self.frontend = English(phone_vocab_path=self.phones_dict)
        logger.info("frontend done!")

    def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
        """
        Strip the outputs produced by the padding context of a streaming chunk.

        Args:
            data: model output for one chunk (indexed along its first axis).
            chunk_num: total number of chunks.
            chunk_id: index of this chunk.
            block: frames kept per chunk.
            pad: context frames on each side of a chunk.
            upsample: output elements per input frame (1 for the AM).
        """
        # Leading context actually present: earlier chunks may have less than
        # `pad` frames before them.
        front_pad = min(chunk_id * block, pad)
        # first chunk
        if chunk_id == 0:
            data = data[:block * upsample]
        # last chunk
        elif chunk_id == chunk_num - 1:
            data = data[front_pad * upsample:]
        # middle chunk
        else:
            data = data[front_pad * upsample:(front_pad + block) * upsample]

        return data

    @paddle.no_grad()
    def infer(
            self,
            text: str,
            lang: str='zh',
            am: str='fastspeech2_csmsc_onnx',
            spk_id: int=0, ):
        """
        Synthesize `text` and yield waveform chunks as soon as they are ready.

        Besides yielding audio, this records latency statistics on `self`:
        `frontend_time`, `first_am_infer`, `first_voc_infer`,
        `first_response_time` and `final_response_time` (seconds).
        `spk_id` is currently unused.

        Raises:
            ValueError: if `lang` is neither 'zh' nor 'en'.
        """
        am_block = self.am_block
        am_pad = self.am_pad
        am_upsample = 1
        voc_block = self.voc_block
        voc_pad = self.voc_pad
        voc_upsample = self.voc_upsample
        # first_flag marks the first audio packet (used for latency stats)
        first_flag = 1
        get_tone_ids = False
        merge_sentences = False

        # front
        frontend_st = time.time()
        if lang == 'zh':
            input_ids = self.frontend.get_input_ids(
                text,
                merge_sentences=merge_sentences,
                get_tone_ids=get_tone_ids)
            phone_ids = input_ids["phone_ids"]
        elif lang == 'en':
            input_ids = self.frontend.get_input_ids(
                text, merge_sentences=merge_sentences)
            phone_ids = input_ids["phone_ids"]
        else:
            logger.error("lang should in {'zh', 'en'}!")
            raise ValueError("lang should in {'zh', 'en'}!")
        frontend_et = time.time()
        self.frontend_time = frontend_et - frontend_st

        for sentence_phone_ids in phone_ids:
            part_phone_ids = sentence_phone_ids.numpy()

            # fastspeech2_csmsc
            if am == "fastspeech2_csmsc_onnx":
                # am: the whole mel of the sentence is produced at once
                mel = self.am_sess.run(
                    output_names=None, input_feed={'text': part_phone_ids})
                mel = mel[0]
                if first_flag == 1:
                    first_am_et = time.time()
                    self.first_am_infer = first_am_et - frontend_et

                # voc streaming
                mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
                voc_chunk_num = len(mel_chunks)
                for voc_chunk_id, mel_chunk in enumerate(mel_chunks):
                    sub_wav = self.voc_sess.run(
                        output_names=None, input_feed={'logmel': mel_chunk})
                    sub_wav = self.depadding(sub_wav[0], voc_chunk_num,
                                             voc_chunk_id, voc_block, voc_pad,
                                             voc_upsample)
                    if first_flag == 1:
                        first_voc_et = time.time()
                        self.first_voc_infer = first_voc_et - first_am_et
                        self.first_response_time = first_voc_et - frontend_st
                        first_flag = 0

                    yield sub_wav

            # fastspeech2_cnndecoder_csmsc
            elif am == "fastspeech2_cnndecoder_csmsc_onnx":
                # am encoder
                orig_hs = self.am_encoder_infer_sess.run(
                    None, input_feed={'text': part_phone_ids})
                orig_hs = orig_hs[0]

                # streaming voc chunk info
                mel_len = orig_hs.shape[1]
                voc_chunk_num = math.ceil(mel_len / voc_block)
                voc_chunk_id = 0
                start = 0
                end = min(voc_block + voc_pad, mel_len)

                # streaming am
                hss = get_chunks(orig_hs, am_block, am_pad, "am")
                am_chunk_num = len(hss)
                for am_chunk_id, hs in enumerate(hss):
                    am_decoder_output = self.am_decoder_sess.run(
                        None, input_feed={'xs': hs})
                    decoder_out = am_decoder_output[0]
                    am_postnet_output = self.am_postnet_sess.run(
                        None,
                        input_feed={
                            'xs': np.transpose(decoder_out, (0, 2, 1))
                        })
                    # The postnet consumes the transposed decoder output; its
                    # result is transposed back and added to the decoder output.
                    am_output_data = decoder_out + np.transpose(
                        am_postnet_output[0], (0, 2, 1))
                    normalized_mel = am_output_data[0]

                    sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
                    sub_mel = self.depadding(sub_mel, am_chunk_num,
                                             am_chunk_id, am_block, am_pad,
                                             am_upsample)

                    if am_chunk_id == 0:
                        mel_streaming = sub_mel
                    else:
                        mel_streaming = np.concatenate(
                            (mel_streaming, sub_mel), axis=0)

                    # streaming voc
                    # Once the streaming AM has produced enough mel frames for
                    # the next vocoder chunk, run streaming vocoder inference.
                    while (mel_streaming.shape[0] >= end and
                           voc_chunk_id < voc_chunk_num):
                        if first_flag == 1:
                            first_am_et = time.time()
                            self.first_am_infer = first_am_et - frontend_et
                        voc_chunk = mel_streaming[start:end, :]

                        sub_wav = self.voc_sess.run(
                            output_names=None,
                            input_feed={'logmel': voc_chunk})
                        sub_wav = self.depadding(sub_wav[0], voc_chunk_num,
                                                 voc_chunk_id, voc_block,
                                                 voc_pad, voc_upsample)
                        if first_flag == 1:
                            first_voc_et = time.time()
                            self.first_voc_infer = first_voc_et - first_am_et
                            self.first_response_time = first_voc_et - frontend_st
                            first_flag = 0

                        yield sub_wav

                        voc_chunk_id += 1
                        start = max(0, voc_chunk_id * voc_block - voc_pad)
                        end = min((voc_chunk_id + 1) * voc_block + voc_pad,
                                  mel_len)

            else:
                logger.error(
                    "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
                )

        self.final_response_time = time.time() - frontend_st


class TTSEngine(BaseEngine):
    """Streaming TTS server engine backed by ONNX acoustic and vocoder models."""

    def __init__(self, name=None):
        """Initialize TTS server engine
        """
        super().__init__()

    def init(self, config: dict) -> bool:
        """Validate `config`, load the models and warm them up.

        Returns:
            True on success, False if model loading or warm-up failed.
        """
        self.config = config
        assert (
            self.config.am == "fastspeech2_csmsc_onnx" or
            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
        ) and (
            self.config.voc == "hifigan_csmsc_onnx" or
            self.config.voc == "mb_melgan_csmsc_onnx"
        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'

        assert (
            self.config.voc_block > 0 and self.config.voc_pad > 0
        ), "Please set correct voc_block and voc_pad, they should be more than 0."

        assert (
            self.config.voc_sample_rate == self.config.am_sample_rate
        ), "The sample rate of AM and Vocoder model are different, please check model."

        self.executor = TTSServerExecutor(
            self.config.am_block, self.config.am_pad, self.config.voc_block,
            self.config.voc_pad, self.config.voc_upsample)

        if "cpu" in self.config.am_sess_conf.device or "cpu" in self.config.voc_sess_conf.device:
            paddle.set_device("cpu")
        else:
            paddle.set_device(self.config.am_sess_conf.device)

        try:
            self.executor._init_from_path(
                am=self.config.am,
                am_ckpt=self.config.am_ckpt,
                am_stat=self.config.am_stat,
                phones_dict=self.config.phones_dict,
                tones_dict=self.config.tones_dict,
                speaker_dict=self.config.speaker_dict,
                am_sample_rate=self.config.am_sample_rate,
                am_sess_conf=self.config.am_sess_conf,
                voc=self.config.voc,
                voc_ckpt=self.config.voc_ckpt,
                voc_sample_rate=self.config.voc_sample_rate,
                voc_sess_conf=self.config.voc_sess_conf,
                lang=self.config.lang)

        except Exception as e:
            logger.error("Failed to get model related files.")
            logger.error(e)
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.config.voc_sess_conf.device))
            return False

        logger.info("Initialize TTS server engine successfully on device: %s." %
                    (self.config.voc_sess_conf.device))

        # warm up
        try:
            self.warm_up()
        except Exception as e:
            logger.error("Failed to warm up on tts engine.")
            logger.error(e)
            return False

        return True

    def warm_up(self):
        """Run the first chunk of a short sentence three times to warm up.

        Raises:
            ValueError: if `config.lang` is neither 'zh' nor 'en'.
        """
        if self.config.lang == 'zh':
            sentence = "您好,欢迎使用语音合成服务。"
        elif self.config.lang == 'en':
            sentence = "Hello and welcome to the speech synthesis service."
        else:
            raise ValueError(
                f"Unsupported lang for warm up: {self.config.lang}")
        logger.info(
            "*******************************warm up ********************************"
        )
        for i in range(3):
            for wav in self.executor.infer(
                    text=sentence,
                    lang=self.config.lang,
                    am=self.config.am,
                    spk_id=0, ):
                logger.info(
                    f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
                )
                # only the first chunk is needed to measure latency
                break

        logger.info(
            "**********************************************************************"
        )

    def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
        """Decode the request text from base64 (preferred) or raw UTF-8 bytes."""
        # Convert byte to text
        if text_bese64:
            text_bytes = base64.b64decode(text_bese64)  # base64 to bytes
        text = text_bytes.decode('utf-8')  # bytes to text

        return text

    def run(self,
            sentence: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            save_path: str=None):
        """ run include inference and postprocess.

        Args:
            sentence (str): text to be synthesized
            spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
            speed (float, optional): speed. Defaults to 1.0. Currently unused.
            volume (float, optional): volume. Defaults to 1.0. Currently unused.
            sample_rate (int, optional): target sample rate for synthesized audio,
            0 means the same as the model sampling rate. Defaults to 0. Currently unused.
            save_path (str, optional): The save path of the synthesized audio.
            None means do not save audio. Defaults to None. Currently unused.

        Yields:
            wav_base64: base64-encoded int16 PCM for each synthesized chunk.
        """
        wav_list = []

        for wav in self.executor.infer(
                text=sentence,
                lang=self.config.lang,
                am=self.config.am,
                spk_id=spk_id, ):

            # wav type: float32, convert to pcm (base64)
            wav = float2pcm(wav)  # float32 to int16
            wav_bytes = wav.tobytes()  # to bytes
            wav_base64 = base64.b64encode(wav_bytes).decode('utf8')  # to base64
            wav_list.append(wav)

            yield wav_base64

        # np.concatenate rejects an empty list; nothing to report then.
        if not wav_list:
            logger.warning(f"No audio was synthesized for sentence: {sentence}")
            return

        wav_all = np.concatenate(wav_list, axis=0)
        duration = len(wav_all) / self.config.voc_sample_rate

        logger.info(f"sentence: {sentence}")
        logger.info(f"The durations of audio is: {duration} s")
        logger.info(
            f"first response time: {self.executor.first_response_time} s")
        logger.info(
            f"final response time: {self.executor.final_response_time} s")
        if duration > 0:
            logger.info(f"RTF: {self.executor.final_response_time / duration}")
        logger.info(
            f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
        )