diff --git a/paddlespeech/server/conf/tts_online_application.yaml b/paddlespeech/server/conf/tts_online_application.yaml new file mode 100644 index 000000000..a80b3ecec --- /dev/null +++ b/paddlespeech/server/conf/tts_online_application.yaml @@ -0,0 +1,46 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 127.0.0.1 +port: 8092 + +# The task format in the engin_list is: _ +# task choices = ['asr_online', 'tts_online'] +# protocol = ['websocket', 'http'] (only one can be selected). +protocol: 'http' +engine_list: ['tts_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### TTS ######################################### +################### speech task: tts; engine_type: online ####################### +tts_online: + # am (acoustic model) choices=['fastspeech2_csmsc'] + am: 'fastspeech2_csmsc' + am_config: + am_ckpt: + am_stat: + phones_dict: + tones_dict: + speaker_dict: + spk_id: 0 + + # voc (vocoder) choices=['mb_melgan_csmsc'] + voc: 'mb_melgan_csmsc' + voc_config: + voc_ckpt: + voc_stat: + + # others + lang: 'zh' + device: # set 'gpu:id' or 'cpu' + am_block: 42 + am_pad: 12 + voc_block: 14 + voc_pad: 14 + diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 2a39fb79b..e147a29a6 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -34,6 +34,9 @@ class EngineFactory(object): elif engine_name == 'tts' and engine_type == 'python': from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine return TTSEngine() + elif engine_name == 'tts' and engine_type == 'online': + from paddlespeech.server.engine.tts.online.tts_engine import TTSEngine + return TTSEngine() elif engine_name == 'cls' and engine_type == 'inference': from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine return CLSEngine() diff --git a/paddlespeech/server/engine/tts/online/__init__.py b/paddlespeech/server/engine/tts/online/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/server/engine/tts/online/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py new file mode 100644 index 000000000..2f068b3b9 --- /dev/null +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -0,0 +1,305 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import io +import time + +import librosa +import numpy as np +import paddle +import soundfile as sf +from scipy.io import wavfile + +from paddlespeech.cli.log import logger +from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import change_speed +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.exception import ServerBaseException +from paddlespeech.server.utils.audio_process import float2pcm +from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.util import denorm +from paddlespeech.server.utils.util import get_chunks + + +import math + +__all__ = ['TTSEngine'] + + +class TTSServerExecutor(TTSExecutor): + def __init__(self): + super().__init__() + pass + + @paddle.no_grad() + def infer(self, + text: str, + lang: str='zh', + am: str='fastspeech2_csmsc', + spk_id: int=0, + am_block: int=42, + am_pad: int=12, + voc_block: int=14, + voc_pad: int=14,): + """ + Model inference and result stored in self.output. + """ + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] + get_tone_ids = False + merge_sentences = False + frontend_st = time.time() + if am_name == 'speedyspeech': + get_tone_ids = True + if lang == 'zh': + input_ids = self.frontend.get_input_ids( + text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif lang == 'en': + input_ids = self.frontend.get_input_ids( + text, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + self.frontend_time = time.time() - frontend_st + + for i in range(len(phone_ids)): + am_st = time.time() + part_phone_ids = phone_ids[i] + # am + if am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + mel = self.am_inference(part_phone_ids, part_tone_ids) + # fastspeech2 + else: + # multi speaker + if am_dataset in {"aishell3", "vctk"}: + mel = self.am_inference( + part_phone_ids, spk_id=paddle.to_tensor(spk_id)) + else: + mel = self.am_inference(part_phone_ids) + am_et = time.time() + + # voc streaming + voc_upsample = self.voc_config.n_shift + mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + chunk_num = len(mel_chunks) + voc_st = time.time() + for i, mel_chunk in enumerate(mel_chunks): + sub_wav = self.voc_inference(mel_chunk) + front_pad = min(i*voc_block, voc_pad) + + if i == 0: + sub_wav = sub_wav[: voc_block * voc_upsample] + elif i == chunk_num - 1: + sub_wav = sub_wav[front_pad * voc_upsample : ] + else: + sub_wav = sub_wav[front_pad * voc_upsample: (front_pad + voc_block) * voc_upsample] + + yield sub_wav + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self, name=None): + """Initialize TTS server engine + """ + super(TTSEngine, self).__init__() + + def init(self, config: dict) -> bool: + self.executor = TTSServerExecutor() + + try: + self.config = config + if self.config.device: + self.device = self.config.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + return False + + try: + self.executor._init_from_path( + am=self.config.am, + am_config=self.config.am_config, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_config=self.config.voc_config, + voc_ckpt=self.config.voc_ckpt, + voc_stat=self.config.voc_stat, + lang=self.config.lang) + except Exception as e: + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + return False + + self.am_block = self.config.am_block + self.am_pad = self.config.am_pad + self.voc_block = self.config.voc_block + self.voc_pad = self.config.voc_pad + + logger.info("Initialize TTS server engine successfully on device: %s." % + (self.device)) + return True + + def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): + # Convert byte to text + if text_bese64: + text_bytes = base64.b64decode(text_bese64) # base64 to bytes + text = text_bytes.decode('utf-8') # bytes to text + + return text + + def postprocess(self, + wav, + original_fs: int, + target_fs: int=0, + volume: float=1.0, + speed: float=1.0, + audio_path: str=None): + """Post-processing operations, including speech, volume, sample rate, save audio file + + Args: + wav (numpy(float)): Synthesized audio sample points + original_fs (int): original audio sample rate + target_fs (int): target audio sample rate + volume (float): target volume + speed (float): target speed + + Raises: + ServerBaseException: Throws an exception if the change speed unsuccessfully. + + Returns: + target_fs: target sample rate for synthesized audio. + wav_base64: The base64 format of the synthesized audio. + """ + + # transform sample_rate + if target_fs == 0 or target_fs > original_fs: + target_fs = original_fs + wav_tar_fs = wav + logger.info( + "The sample rate of synthesized audio is the same as model, which is {}Hz". + format(original_fs)) + else: + wav_tar_fs = librosa.resample( + np.squeeze(wav), original_fs, target_fs) + logger.info( + "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.". + format(original_fs, target_fs)) + # transform volume + wav_vol = wav_tar_fs * volume + logger.info("Transform the volume of the audio successfully.") + + # transform speed + try: # windows not support soxbindings + wav_speed = change_speed(wav_vol, speed, target_fs) + logger.info("Transform the speed of the audio successfully.") + except ServerBaseException: + raise ServerBaseException( + ErrorCode.SERVER_INTERNAL_ERR, + "Failed to transform speed. Can not install soxbindings on your system. \ + You need to set speed value 1.0.") + except BaseException: + logger.error("Failed to transform speed.") + + # wav to base64 + buf = io.BytesIO() + wavfile.write(buf, target_fs, wav_speed) + base64_bytes = base64.b64encode(buf.read()) + wav_base64 = base64_bytes.decode('utf-8') + logger.info("Audio to string successfully.") + + # save audio + if audio_path is not None: + if audio_path.endswith(".wav"): + sf.write(audio_path, wav_speed, target_fs) + elif audio_path.endswith(".pcm"): + wav_norm = wav_speed * (32767 / max(0.001, + np.max(np.abs(wav_speed)))) + with open(audio_path, "wb") as f: + f.write(wav_norm.astype(np.int16)) + logger.info("Save audio to {} successfully.".format(audio_path)) + else: + logger.info("There is no need to save audio.") + + return target_fs, wav_base64 + + def run(self, + sentence: str, + spk_id: int=0, + speed: float=1.0, + volume: float=1.0, + sample_rate: int=0, + save_path: str=None): + """ run include inference and postprocess. + + Args: + sentence (str): text to be synthesized + spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0. + speed (float, optional): speed. Defaults to 1.0. + volume (float, optional): volume. Defaults to 1.0. + sample_rate (int, optional): target sample rate for synthesized audio, + 0 means the same as the model sampling rate. Defaults to 0. + save_path (str, optional): The save path of the synthesized audio. + None means do not save audio. Defaults to None. + + Raises: + ServerBaseException: Throws an exception if tts inference unsuccessfully. + ServerBaseException: Throws an exception if postprocess unsuccessfully. + + Returns: + lang: model language + target_sample_rate: target sample rate for synthesized audio. + wav_base64: The base64 format of the synthesized audio. + """ + + lang = self.config.lang + wav_list = [] + + for wav in self.executor.infer(text=sentence, lang=lang, am=self.config.am, spk_id=spk_id, am_block=self.am_block, am_pad=self.am_pad, voc_block=self.voc_block, voc_pad=self.voc_pad): + # wav type: float32, convert to pcm (base64) + wav = float2pcm(wav) # float32 to int16 + wav_bytes = wav.tobytes() # to bytes + wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64 + wav_list.append(wav) + + yield wav_base64 + + wav_all = np.concatenate(wav_list, axis=0) + logger.info("The durations of audio is: {} s".format(len(wav_all)/self.executor.am_config.fs)) + + + + diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py index 4e9bbe23e..d1268428a 100644 --- a/paddlespeech/server/restful/tts_api.py +++ b/paddlespeech/server/restful/tts_api.py @@ -15,6 +15,7 @@ import traceback from typing import Union from fastapi import APIRouter +from fastapi.responses import StreamingResponse from paddlespeech.cli.log import logger from paddlespeech.server.engine.engine_pool import get_engine_pool @@ -125,3 +126,14 @@ def tts(request_body: TTSRequest): traceback.print_exc() return response + + +@router.post("/paddlespeech/streaming/tts") +async def stream_tts(request_body: TTSRequest): + text = request_body.text + + engine_pool = get_engine_pool() + tts_engine = engine_pool['tts'] + logger.info("Get tts engine successfully.") + + return StreamingResponse(tts_engine.run(sentence=text)) diff --git a/paddlespeech/server/tests/tts/test_client.py b/paddlespeech/server/tests/tts/offline/http_client.py similarity index 90% rename from paddlespeech/server/tests/tts/test_client.py rename to paddlespeech/server/tests/tts/offline/http_client.py index e42c9bcfa..1bdee4c18 100644 --- a/paddlespeech/server/tests/tts/test_client.py +++ b/paddlespeech/server/tests/tts/offline/http_client.py @@ -33,7 +33,8 @@ def tts_client(args): text: A sentence to be synthesized outfile: Synthetic audio file """ - url = 'http://127.0.0.1:8090/paddlespeech/tts' + url = "http://" + str(args.server) + ":" + str( + args.port) + "/paddlespeech/tts" request = { "text": args.text, "spk_id": args.spk_id, @@ -72,7 +73,7 @@ if __name__ == "__main__": parser.add_argument( '--text', type=str, - default="你好,欢迎使用语音合成服务", + default="您好,欢迎使用语音合成服务。", help='A sentence to be synthesized') parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') @@ -88,6 +89,9 @@ if __name__ == "__main__": type=str, default="./out.wav", help='Synthesized audio file') + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8090) args = parser.parse_args() st = time.time() diff --git a/paddlespeech/server/tests/tts/online/http_client.py b/paddlespeech/server/tests/tts/online/http_client.py new file mode 100644 index 000000000..cbc1f5c02 --- /dev/null +++ b/paddlespeech/server/tests/tts/online/http_client.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import base64 +import json +import os +import time + +import requests + +from paddlespeech.server.utils.audio_process import pcm2wav + + +def save_audio(buffer, audio_path) -> bool: + if args.save_path.endswith("pcm"): + with open(args.save_path, "wb") as f: + f.write(buffer) + elif args.save_path.endswith("wav"): + with open("./tmp.pcm", "wb") as f: + f.write(buffer) + pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000) + os.system("rm ./tmp.pcm") + else: + print("Only supports saved audio format is pcm or wav") + return False + + return True + + +def test(args): + params = { + "text": args.text, + "spk_id": args.spk_id, + "speed": args.speed, + "volume": args.volume, + "sample_rate": args.sample_rate, + "save_path": '' + } + + buffer = b'' + flag = 1 + url = "http://" + str(args.server) + ":" + str( + args.port) + "/paddlespeech/streaming/tts" + st = time.time() + html = requests.post(url, json.dumps(params), stream=True) + for chunk in html.iter_content(chunk_size=1024): + chunk = base64.b64decode(chunk) # bytes + if flag: + first_response = time.time() - st + print(f"首包响应:{first_response} s") + flag = 0 + buffer += chunk + + final_response = time.time() - st + duration = len(buffer) / 2.0 / 24000 + + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + + if args.save_path is not None: + if save_audio(buffer, args.save_path): + print("音频保存至:", args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--text', + type=str, + default="您好,欢迎使用语音合成服务。", + help='A sentence to be synthesized') + parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') + parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') + parser.add_argument( + '--volume', type=float, default=1.0, help='Audio volume') + parser.add_argument( + '--sample_rate', + type=int, + default=0, + help='Sampling rate, the default is the same as the model') + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + parser.add_argument( + "--save_path", type=str, help="save audio path", default=None) + + args = parser.parse_args() + test(args) diff --git a/paddlespeech/server/tests/tts/online/http_client_playaudio.py b/paddlespeech/server/tests/tts/online/http_client_playaudio.py new file mode 100644 index 000000000..1e7e8064e --- /dev/null +++ b/paddlespeech/server/tests/tts/online/http_client_playaudio.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import base64 +import json +import threading +import time + +import pyaudio +import requests + +mutex = threading.Lock() +buffer = b'' +p = pyaudio.PyAudio() +stream = p.open( + format=p.get_format_from_width(2), channels=1, rate=24000, output=True) +max_fail = 50 + + +def play_audio(): + global stream + global buffer + global max_fail + while True: + if not buffer: + max_fail -= 1 + time.sleep(0.05) + if max_fail < 0: + break + mutex.acquire() + stream.write(buffer) + buffer = b'' + mutex.release() + + +def test(args): + global mutex + global buffer + params = { + "text": args.text, + "spk_id": args.spk_id, + "speed": args.speed, + "volume": args.volume, + "sample_rate": args.sample_rate, + "save_path": '' + } + + all_bytes = 0.0 + t = threading.Thread(target=play_audio) + flag = 1 + url = "http://" + str(args.server) + ":" + str( + args.port) + "/paddlespeech/streaming/tts" + st = time.time() + html = requests.post(url, json.dumps(params), stream=True) + for chunk in html.iter_content(chunk_size=1024): + mutex.acquire() + chunk = base64.b64decode(chunk) # bytes + buffer += chunk + mutex.release() + if flag: + first_response = time.time() - st + print(f"首包响应:{first_response} s") + flag = 0 + t.start() + all_bytes += len(chunk) + + final_response = time.time() - st + duration = all_bytes / 2 / 24000 + + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + + t.join() + stream.stop_stream() + stream.close() + p.terminate() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--text', + type=str, + default="您好,欢迎使用语音合成服务。", + help='A sentence to be synthesized') + parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') + parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') + parser.add_argument( + '--volume', type=float, default=1.0, help='Audio volume') + parser.add_argument( + '--sample_rate', + type=int, + default=0, + help='Sampling rate, the default is the same as the model') + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + + args = parser.parse_args() + test(args) diff --git a/paddlespeech/server/tests/tts/online/out.pcm b/paddlespeech/server/tests/tts/online/out.pcm new file mode 100644 index 000000000..a52377f82 Binary files /dev/null and b/paddlespeech/server/tests/tts/online/out.pcm differ diff --git a/paddlespeech/server/tests/tts/online/ws_client.py b/paddlespeech/server/tests/tts/online/ws_client.py new file mode 100644 index 000000000..e0f47b551 --- /dev/null +++ b/paddlespeech/server/tests/tts/online/ws_client.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import _thread as thread +import argparse +import base64 +import json +import ssl +import time + +import websocket + +flag = 1 +st = 0.0 +all_bytes = b'' + + +class Ws_Param(object): + # 初始化 + def __init__(self, text, server="127.0.0.1", port=8090): + self.server = server + self.port = port + self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts" + self.text = text + + # 生成url + def create_url(self): + return self.url + + +def on_message(ws, message): + global flag + global st + global all_bytes + + try: + message = json.loads(message) + audio = message["audio"] + audio = base64.b64decode(audio) # bytes + status = message["status"] + all_bytes += audio + + if status == 0: + print("create successfully.") + elif status == 1: + if flag: + print(f"首包响应:{time.time() - st} s") + flag = 0 + elif status == 2: + final_response = time.time() - st + duration = len(all_bytes) / 2.0 / 24000 + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + with open("./out.pcm", "wb") as f: + f.write(all_bytes) + print("ws is closed") + ws.close() + else: + print("infer error") + + except Exception as e: + print("receive msg,but parse exception:", e) + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws): + print("### closed ###") + + +# 收到websocket连接建立的处理 +def on_open(ws): + def run(*args): + global st + text_base64 = str( + base64.b64encode((wsParam.text).encode('utf-8')), "UTF8") + d = {"text": text_base64} + d = json.dumps(d) + print("Start sending text data") + st = time.time() + ws.send(d) + + thread.start_new_thread(run, ()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="A sentence to be synthesized", + default="您好,欢迎使用语音合成服务。") + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + args = parser.parse_args() + + print("***************************************") + print("Server ip: ", args.server) + print("Server port: ", args.port) + print("Sentence to be synthesized: ", args.text) + print("***************************************") + + wsParam = Ws_Param(text=args.text, server=args.server, port=args.port) + + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = websocket.WebSocketApp( + wsUrl, on_message=on_message, on_error=on_error, on_close=on_close) + ws.on_open = on_open + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) diff --git a/paddlespeech/server/tests/tts/online/ws_client_playaudio.py b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py new file mode 100644 index 000000000..4e1c538d1 --- /dev/null +++ b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import _thread as thread +import argparse +import base64 +import json +import ssl +import threading +import time + +import pyaudio +import websocket + +mutex = threading.Lock() +buffer = b'' +p = pyaudio.PyAudio() +stream = p.open( + format=p.get_format_from_width(2), channels=1, rate=24000, output=True) +flag = 1 +st = 0.0 +all_bytes = 0.0 + + +class Ws_Param(object): + # 初始化 + def __init__(self, text, server="127.0.0.1", port=8090): + self.server = server + self.port = port + self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts" + self.text = text + + # 生成url + def create_url(self): + return self.url + + +def play_audio(): + global stream + global buffer + while True: + time.sleep(0.05) + if not buffer: # buffer 为空 + break + mutex.acquire() + stream.write(buffer) + buffer = b'' + mutex.release() + + +t = threading.Thread(target=play_audio) + + +def on_message(ws, message): + global flag + global t + global buffer + global st + global all_bytes + + try: + message = json.loads(message) + audio = message["audio"] + audio = base64.b64decode(audio) # bytes + status = message["status"] + all_bytes += len(audio) + + if status == 0: + print("create successfully.") + elif status == 1: + mutex.acquire() + buffer += audio + mutex.release() + if flag: + print(f"首包响应:{time.time() - st} s") + flag = 0 + print("Start playing audio") + t.start() + elif status == 2: + final_response = time.time() - st + duration = all_bytes / 2 / 24000 + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + print("ws is closed") + ws.close() + else: + print("infer error") + + except Exception as e: + print("receive msg,but parse exception:", e) + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws): + print("### closed ###") + + +# 收到websocket连接建立的处理 +def on_open(ws): + def run(*args): + global st + text_base64 = str( + base64.b64encode((wsParam.text).encode('utf-8')), "UTF8") + d = {"text": text_base64} + d = json.dumps(d) + print("Start sending text data") + st = time.time() + ws.send(d) + + thread.start_new_thread(run, ()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="A sentence to be synthesized", + default="您好,欢迎使用语音合成服务。") + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + args = parser.parse_args() + + print("***************************************") + print("Server ip: ", args.server) + print("Server port: ", args.port) + print("Sentence to be synthesized: ", args.text) + print("***************************************") + + wsParam = Ws_Param(text=args.text, server=args.server, port=args.port) + + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = websocket.WebSocketApp( + wsUrl, on_message=on_message, on_error=on_error, on_close=on_close) + ws.on_open = on_open + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + + t.join() + print("End of playing audio") + stream.stop_stream() + stream.close() + p.terminate() diff --git a/paddlespeech/server/utils/audio_process.py b/paddlespeech/server/utils/audio_process.py index 3cbb495a6..1d4b158c5 100644 --- a/paddlespeech/server/utils/audio_process.py +++ b/paddlespeech/server/utils/audio_process.py @@ -103,3 +103,26 @@ def change_speed(sample_raw, speed_rate, sample_rate): sample_rate_in=sample_rate).squeeze(-1).astype(np.float32).copy() return sample_speed + + +def float2pcm(sig, dtype='int16'): + """Convert floating point signal with a range from -1 to 1 to PCM. + + Args: + sig (array): Input array, must have floating point type. + dtype (str, optional): Desired (integer) data type. Defaults to 'int16'. + + Returns: + numpy.ndarray: Integer data, scaled and clipped to the range of the given + """ + sig = np.asarray(sig) + if sig.dtype.kind != 'f': + raise TypeError("'sig' must be a float array") + dtype = np.dtype(dtype) + if dtype.kind not in 'iu': + raise TypeError("'dtype' must be an integer type") + + i = np.iinfo(dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype) diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py index e9104fa2d..c35939b74 100644 --- a/paddlespeech/server/utils/util.py +++ b/paddlespeech/server/utils/util.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the import base64 +import math def wav2base64(wav_file: str): @@ -31,3 +32,29 @@ def self_check(): """ self check resource """ return True + + +def denorm(data, mean, std): + return data * std + mean + + +def get_chunks(data, block_size, pad_size, step): + if step == "am": + data_len = data.shape[1] + elif step == "voc": + data_len = data.shape[0] + else: + print("Please set correct type to get chunks, am or voc") + + chunks = [] + n = math.ceil(data_len / block_size) + for i in range(n): + start = max(0, i * block_size - pad_size) + end = min((i + 1) * block_size + pad_size, data_len) + if step == "am": + chunks.append(data[:, start:end, :]) + elif step == "voc": + chunks.append(data[start:end, :]) + else: + print("Please set correct type to get chunks, am or voc") + return chunks diff --git a/paddlespeech/server/ws/api.py b/paddlespeech/server/ws/api.py index 10664d114..313fd16f5 100644 --- a/paddlespeech/server/ws/api.py +++ b/paddlespeech/server/ws/api.py @@ -16,6 +16,7 @@ from typing import List from fastapi import APIRouter from paddlespeech.server.ws.asr_socket import router as asr_router +from paddlespeech.server.ws.tts_socket import router as tts_router _router = APIRouter() @@ -31,7 +32,7 @@ def setup_router(api_list: List): if api_name == 'asr': _router.include_router(asr_router) elif api_name == 'tts': - pass + _router.include_router(tts_router) else: pass diff --git a/paddlespeech/server/ws/tts_socket.py b/paddlespeech/server/ws/tts_socket.py new file mode 100644 index 000000000..4df2850af --- /dev/null +++ b/paddlespeech/server/ws/tts_socket.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from fastapi import APIRouter +from fastapi import WebSocket +from fastapi import WebSocketDisconnect +from starlette.websockets import WebSocketState as WebSocketState + +from paddlespeech.cli.log import logger +from paddlespeech.server.engine.engine_pool import get_engine_pool + +router = APIRouter() + + +@router.websocket('/ws/tts') +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + try: + # careful here, changed the source code from starlette.websockets + assert websocket.application_state == WebSocketState.CONNECTED + message = await websocket.receive() + websocket._raise_on_disconnect(message) + + # get engine + engine_pool = get_engine_pool() + tts_engine = engine_pool['tts'] + + # 获取 message 并转文本 + message = json.loads(message["text"]) + text_bese64 = message["text"] + sentence = tts_engine.preprocess(text_bese64=text_bese64) + + # run + wav = tts_engine.run(sentence) + + while True: + try: + tts_results = next(wav) + resp = {"status": 1, "audio": tts_results} + await websocket.send_json(resp) + logger.info("streaming audio...") + except StopIteration as e: + resp = {"status": 2, "audio": ''} + await websocket.send_json(resp) + logger.info("Complete the transmission of audio streams") + break + + except WebSocketDisconnect: + pass