add stream tts server, test=doc

pull/1652/head
lym0302 3 years ago
parent 3c8f30c7a4
commit 603e565ab1

@ -0,0 +1,46 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 127.0.0.1
port: 8092
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
protocol: 'http'
engine_list: ['tts_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### TTS #########################################
################### speech task: tts; engine_type: online #######################
tts_online:
# am (acoustic model) choices=['fastspeech2_csmsc']
am: 'fastspeech2_csmsc'
am_config:
am_ckpt:
am_stat:
phones_dict:
tones_dict:
speaker_dict:
spk_id: 0
# voc (vocoder) choices=['mb_melgan_csmsc']
voc: 'mb_melgan_csmsc'
voc_config:
voc_ckpt:
voc_stat:
# others
lang: 'zh'
device: # set 'gpu:id' or 'cpu'
am_block: 42
am_pad: 12
voc_block: 14
voc_pad: 14

@ -34,6 +34,9 @@ class EngineFactory(object):
elif engine_name == 'tts' and engine_type == 'python': elif engine_name == 'tts' and engine_type == 'python':
from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
return TTSEngine() return TTSEngine()
elif engine_name == 'tts' and engine_type == 'online':
from paddlespeech.server.engine.tts.online.tts_engine import TTSEngine
return TTSEngine()
elif engine_name == 'cls' and engine_type == 'inference': elif engine_name == 'cls' and engine_type == 'inference':
from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine
return CLSEngine() return CLSEngine()

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,305 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import time
import librosa
import numpy as np
import paddle
import soundfile as sf
from scipy.io import wavfile
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
import math
__all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
super().__init__()
pass
@paddle.no_grad()
def infer(self,
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
spk_id: int=0,
am_block: int=42,
am_pad: int=12,
voc_block: int=14,
voc_pad: int=14,):
"""
Model inference and result stored in self.output.
"""
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
merge_sentences = False
frontend_st = time.time()
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
self.frontend_time = time.time() - frontend_st
for i in range(len(phone_ids)):
am_st = time.time()
part_phone_ids = phone_ids[i]
# am
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = self.am_inference(part_phone_ids, part_tone_ids)
# fastspeech2
else:
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(part_phone_ids)
am_et = time.time()
# voc streaming
voc_upsample = self.voc_config.n_shift
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
front_pad = min(i*voc_block, voc_pad)
if i == 0:
sub_wav = sub_wav[: voc_block * voc_upsample]
elif i == chunk_num - 1:
sub_wav = sub_wav[front_pad * voc_upsample : ]
else:
sub_wav = sub_wav[front_pad * voc_upsample: (front_pad + voc_block) * voc_upsample]
yield sub_wav
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
try:
self.config = config
if self.config.device:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
try:
self.executor._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_config=self.config.voc_config,
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
self.am_block = self.config.am_block
self.am_pad = self.config.am_pad
self.voc_block = self.config.voc_block
self.voc_pad = self.config.voc_pad
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
text_bytes = base64.b64decode(text_bese64) # base64 to bytes
text = text_bytes.decode('utf-8') # bytes to text
return text
def postprocess(self,
wav,
original_fs: int,
target_fs: int=0,
volume: float=1.0,
speed: float=1.0,
audio_path: str=None):
"""Post-processing operations, including speech, volume, sample rate, save audio file
Args:
wav (numpy(float)): Synthesized audio sample points
original_fs (int): original audio sample rate
target_fs (int): target audio sample rate
volume (float): target volume
speed (float): target speed
Raises:
ServerBaseException: Throws an exception if the change speed unsuccessfully.
Returns:
target_fs: target sample rate for synthesized audio.
wav_base64: The base64 format of the synthesized audio.
"""
# transform sample_rate
if target_fs == 0 or target_fs > original_fs:
target_fs = original_fs
wav_tar_fs = wav
logger.info(
"The sample rate of synthesized audio is the same as model, which is {}Hz".
format(original_fs))
else:
wav_tar_fs = librosa.resample(
np.squeeze(wav), original_fs, target_fs)
logger.info(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
format(original_fs, target_fs))
# transform volume
wav_vol = wav_tar_fs * volume
logger.info("Transform the volume of the audio successfully.")
# transform speed
try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs)
logger.info("Transform the speed of the audio successfully.")
except ServerBaseException:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Failed to transform speed.")
# wav to base64
buf = io.BytesIO()
wavfile.write(buf, target_fs, wav_speed)
base64_bytes = base64.b64encode(buf.read())
wav_base64 = base64_bytes.decode('utf-8')
logger.info("Audio to string successfully.")
# save audio
if audio_path is not None:
if audio_path.endswith(".wav"):
sf.write(audio_path, wav_speed, target_fs)
elif audio_path.endswith(".pcm"):
wav_norm = wav_speed * (32767 / max(0.001,
np.max(np.abs(wav_speed))))
with open(audio_path, "wb") as f:
f.write(wav_norm.astype(np.int16))
logger.info("Save audio to {} successfully.".format(audio_path))
else:
logger.info("There is no need to save audio.")
return target_fs, wav_base64
def run(self,
sentence: str,
spk_id: int=0,
speed: float=1.0,
volume: float=1.0,
sample_rate: int=0,
save_path: str=None):
""" run include inference and postprocess.
Args:
sentence (str): text to be synthesized
spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
speed (float, optional): speed. Defaults to 1.0.
volume (float, optional): volume. Defaults to 1.0.
sample_rate (int, optional): target sample rate for synthesized audio,
0 means the same as the model sampling rate. Defaults to 0.
save_path (str, optional): The save path of the synthesized audio.
None means do not save audio. Defaults to None.
Raises:
ServerBaseException: Throws an exception if tts inference unsuccessfully.
ServerBaseException: Throws an exception if postprocess unsuccessfully.
Returns:
lang: model language
target_sample_rate: target sample rate for synthesized audio.
wav_base64: The base64 format of the synthesized audio.
"""
lang = self.config.lang
wav_list = []
for wav in self.executor.infer(text=sentence, lang=lang, am=self.config.am, spk_id=spk_id, am_block=self.am_block, am_pad=self.am_pad, voc_block=self.voc_block, voc_pad=self.voc_pad):
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64
wav_list.append(wav)
yield wav_base64
wav_all = np.concatenate(wav_list, axis=0)
logger.info("The durations of audio is: {} s".format(len(wav_all)/self.executor.am_config.fs))

@ -15,6 +15,7 @@ import traceback
from typing import Union from typing import Union
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
@ -125,3 +126,14 @@ def tts(request_body: TTSRequest):
traceback.print_exc() traceback.print_exc()
return response return response
@router.post("/paddlespeech/streaming/tts")
async def stream_tts(request_body: TTSRequest):
text = request_body.text
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
return StreamingResponse(tts_engine.run(sentence=text))

@ -33,7 +33,8 @@ def tts_client(args):
text: A sentence to be synthesized text: A sentence to be synthesized
outfile: Synthetic audio file outfile: Synthetic audio file
""" """
url = 'http://127.0.0.1:8090/paddlespeech/tts' url = "http://" + str(args.server) + ":" + str(
args.port) + "/paddlespeech/tts"
request = { request = {
"text": args.text, "text": args.text,
"spk_id": args.spk_id, "spk_id": args.spk_id,
@ -72,7 +73,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
'--text', '--text',
type=str, type=str,
default="你好,欢迎使用语音合成服务", default="您好,欢迎使用语音合成服务。",
help='A sentence to be synthesized') help='A sentence to be synthesized')
parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
@ -88,6 +89,9 @@ if __name__ == "__main__":
type=str, type=str,
default="./out.wav", default="./out.wav",
help='Synthesized audio file') help='Synthesized audio file')
parser.add_argument(
"--server", type=str, help="server ip", default="127.0.0.1")
parser.add_argument("--port", type=int, help="server port", default=8090)
args = parser.parse_args() args = parser.parse_args()
st = time.time() st = time.time()

@ -0,0 +1,100 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import json
import os
import time
import requests
from paddlespeech.server.utils.audio_process import pcm2wav
def save_audio(buffer, audio_path) -> bool:
if args.save_path.endswith("pcm"):
with open(args.save_path, "wb") as f:
f.write(buffer)
elif args.save_path.endswith("wav"):
with open("./tmp.pcm", "wb") as f:
f.write(buffer)
pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000)
os.system("rm ./tmp.pcm")
else:
print("Only supports saved audio format is pcm or wav")
return False
return True
def test(args):
params = {
"text": args.text,
"spk_id": args.spk_id,
"speed": args.speed,
"volume": args.volume,
"sample_rate": args.sample_rate,
"save_path": ''
}
buffer = b''
flag = 1
url = "http://" + str(args.server) + ":" + str(
args.port) + "/paddlespeech/streaming/tts"
st = time.time()
html = requests.post(url, json.dumps(params), stream=True)
for chunk in html.iter_content(chunk_size=1024):
chunk = base64.b64decode(chunk) # bytes
if flag:
first_response = time.time() - st
print(f"首包响应:{first_response} s")
flag = 0
buffer += chunk
final_response = time.time() - st
duration = len(buffer) / 2.0 / 24000
print(f"尾包响应:{final_response} s")
print(f"音频时长:{duration} s")
print(f"RTF: {final_response / duration}")
if args.save_path is not None:
if save_audio(buffer, args.save_path):
print("音频保存至:", args.save_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
type=str,
default="您好,欢迎使用语音合成服务。",
help='A sentence to be synthesized')
parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
parser.add_argument(
'--volume', type=float, default=1.0, help='Audio volume')
parser.add_argument(
'--sample_rate',
type=int,
default=0,
help='Sampling rate, the default is the same as the model')
parser.add_argument(
"--server", type=str, help="server ip", default="127.0.0.1")
parser.add_argument("--port", type=int, help="server port", default=8092)
parser.add_argument(
"--save_path", type=str, help="save audio path", default=None)
args = parser.parse_args()
test(args)

@ -0,0 +1,112 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import json
import threading
import time
import pyaudio
import requests
mutex = threading.Lock()
buffer = b''
p = pyaudio.PyAudio()
stream = p.open(
format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
max_fail = 50
def play_audio():
global stream
global buffer
global max_fail
while True:
if not buffer:
max_fail -= 1
time.sleep(0.05)
if max_fail < 0:
break
mutex.acquire()
stream.write(buffer)
buffer = b''
mutex.release()
def test(args):
global mutex
global buffer
params = {
"text": args.text,
"spk_id": args.spk_id,
"speed": args.speed,
"volume": args.volume,
"sample_rate": args.sample_rate,
"save_path": ''
}
all_bytes = 0.0
t = threading.Thread(target=play_audio)
flag = 1
url = "http://" + str(args.server) + ":" + str(
args.port) + "/paddlespeech/streaming/tts"
st = time.time()
html = requests.post(url, json.dumps(params), stream=True)
for chunk in html.iter_content(chunk_size=1024):
mutex.acquire()
chunk = base64.b64decode(chunk) # bytes
buffer += chunk
mutex.release()
if flag:
first_response = time.time() - st
print(f"首包响应:{first_response} s")
flag = 0
t.start()
all_bytes += len(chunk)
final_response = time.time() - st
duration = all_bytes / 2 / 24000
print(f"尾包响应:{final_response} s")
print(f"音频时长:{duration} s")
print(f"RTF: {final_response / duration}")
t.join()
stream.stop_stream()
stream.close()
p.terminate()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
type=str,
default="您好,欢迎使用语音合成服务。",
help='A sentence to be synthesized')
parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
parser.add_argument(
'--volume', type=float, default=1.0, help='Audio volume')
parser.add_argument(
'--sample_rate',
type=int,
default=0,
help='Sampling rate, the default is the same as the model')
parser.add_argument(
"--server", type=str, help="server ip", default="127.0.0.1")
parser.add_argument("--port", type=int, help="server port", default=8092)
args = parser.parse_args()
test(args)

@ -0,0 +1,126 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import _thread as thread
import argparse
import base64
import json
import ssl
import time
import websocket
flag = 1
st = 0.0
all_bytes = b''
class Ws_Param(object):
# 初始化
def __init__(self, text, server="127.0.0.1", port=8090):
self.server = server
self.port = port
self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
self.text = text
# 生成url
def create_url(self):
return self.url
def on_message(ws, message):
global flag
global st
global all_bytes
try:
message = json.loads(message)
audio = message["audio"]
audio = base64.b64decode(audio) # bytes
status = message["status"]
all_bytes += audio
if status == 0:
print("create successfully.")
elif status == 1:
if flag:
print(f"首包响应:{time.time() - st} s")
flag = 0
elif status == 2:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
print(f"尾包响应:{final_response} s")
print(f"音频时长:{duration} s")
print(f"RTF: {final_response / duration}")
with open("./out.pcm", "wb") as f:
f.write(all_bytes)
print("ws is closed")
ws.close()
else:
print("infer error")
except Exception as e:
print("receive msg,but parse exception:", e)
# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)
# 收到websocket关闭的处理
def on_close(ws):
print("### closed ###")
# 收到websocket连接建立的处理
def on_open(ws):
def run(*args):
global st
text_base64 = str(
base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
d = {"text": text_base64}
d = json.dumps(d)
print("Start sending text data")
st = time.time()
ws.send(d)
thread.start_new_thread(run, ())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--text",
type=str,
help="A sentence to be synthesized",
default="您好,欢迎使用语音合成服务。")
parser.add_argument(
"--server", type=str, help="server ip", default="127.0.0.1")
parser.add_argument("--port", type=int, help="server port", default=8092)
args = parser.parse_args()
print("***************************************")
print("Server ip: ", args.server)
print("Server port: ", args.port)
print("Sentence to be synthesized: ", args.text)
print("***************************************")
wsParam = Ws_Param(text=args.text, server=args.server, port=args.port)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(
wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
ws.on_open = on_open
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

@ -0,0 +1,160 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import _thread as thread
import argparse
import base64
import json
import ssl
import threading
import time
import pyaudio
import websocket
mutex = threading.Lock()
buffer = b''
p = pyaudio.PyAudio()
stream = p.open(
format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
flag = 1
st = 0.0
all_bytes = 0.0
class Ws_Param(object):
# 初始化
def __init__(self, text, server="127.0.0.1", port=8090):
self.server = server
self.port = port
self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
self.text = text
# 生成url
def create_url(self):
return self.url
def play_audio():
global stream
global buffer
while True:
time.sleep(0.05)
if not buffer: # buffer 为空
break
mutex.acquire()
stream.write(buffer)
buffer = b''
mutex.release()
t = threading.Thread(target=play_audio)
def on_message(ws, message):
global flag
global t
global buffer
global st
global all_bytes
try:
message = json.loads(message)
audio = message["audio"]
audio = base64.b64decode(audio) # bytes
status = message["status"]
all_bytes += len(audio)
if status == 0:
print("create successfully.")
elif status == 1:
mutex.acquire()
buffer += audio
mutex.release()
if flag:
print(f"首包响应:{time.time() - st} s")
flag = 0
print("Start playing audio")
t.start()
elif status == 2:
final_response = time.time() - st
duration = all_bytes / 2 / 24000
print(f"尾包响应:{final_response} s")
print(f"音频时长:{duration} s")
print(f"RTF: {final_response / duration}")
print("ws is closed")
ws.close()
else:
print("infer error")
except Exception as e:
print("receive msg,but parse exception:", e)
# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)
# 收到websocket关闭的处理
def on_close(ws):
print("### closed ###")
# 收到websocket连接建立的处理
def on_open(ws):
def run(*args):
global st
text_base64 = str(
base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
d = {"text": text_base64}
d = json.dumps(d)
print("Start sending text data")
st = time.time()
ws.send(d)
thread.start_new_thread(run, ())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--text",
type=str,
help="A sentence to be synthesized",
default="您好,欢迎使用语音合成服务。")
parser.add_argument(
"--server", type=str, help="server ip", default="127.0.0.1")
parser.add_argument("--port", type=int, help="server port", default=8092)
args = parser.parse_args()
print("***************************************")
print("Server ip: ", args.server)
print("Server port: ", args.port)
print("Sentence to be synthesized: ", args.text)
print("***************************************")
wsParam = Ws_Param(text=args.text, server=args.server, port=args.port)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(
wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
ws.on_open = on_open
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
t.join()
print("End of playing audio")
stream.stop_stream()
stream.close()
p.terminate()

@ -103,3 +103,26 @@ def change_speed(sample_raw, speed_rate, sample_rate):
sample_rate_in=sample_rate).squeeze(-1).astype(np.float32).copy() sample_rate_in=sample_rate).squeeze(-1).astype(np.float32).copy()
return sample_speed return sample_speed
def float2pcm(sig, dtype='int16'):
"""Convert floating point signal with a range from -1 to 1 to PCM.
Args:
sig (array): Input array, must have floating point type.
dtype (str, optional): Desired (integer) data type. Defaults to 'int16'.
Returns:
numpy.ndarray: Integer data, scaled and clipped to the range of the given
"""
sig = np.asarray(sig)
if sig.dtype.kind != 'f':
raise TypeError("'sig' must be a float array")
dtype = np.dtype(dtype)
if dtype.kind not in 'iu':
raise TypeError("'dtype' must be an integer type")
i = np.iinfo(dtype)
abs_max = 2**(i.bits - 1)
offset = i.min + abs_max
return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype)

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the # See the License for the
import base64 import base64
import math
def wav2base64(wav_file: str): def wav2base64(wav_file: str):
@ -31,3 +32,29 @@ def self_check():
""" self check resource """ self check resource
""" """
return True return True
def denorm(data, mean, std):
return data * std + mean
def get_chunks(data, block_size, pad_size, step):
if step == "am":
data_len = data.shape[1]
elif step == "voc":
data_len = data.shape[0]
else:
print("Please set correct type to get chunks, am or voc")
chunks = []
n = math.ceil(data_len / block_size)
for i in range(n):
start = max(0, i * block_size - pad_size)
end = min((i + 1) * block_size + pad_size, data_len)
if step == "am":
chunks.append(data[:, start:end, :])
elif step == "voc":
chunks.append(data[start:end, :])
else:
print("Please set correct type to get chunks, am or voc")
return chunks

@ -16,6 +16,7 @@ from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.server.ws.asr_socket import router as asr_router from paddlespeech.server.ws.asr_socket import router as asr_router
from paddlespeech.server.ws.tts_socket import router as tts_router
_router = APIRouter() _router = APIRouter()
@ -31,7 +32,7 @@ def setup_router(api_list: List):
if api_name == 'asr': if api_name == 'asr':
_router.include_router(asr_router) _router.include_router(asr_router)
elif api_name == 'tts': elif api_name == 'tts':
pass _router.include_router(tts_router)
else: else:
pass pass

@ -0,0 +1,62 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter()
@router.websocket('/ws/tts')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
try:
# careful here, changed the source code from starlette.websockets
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
# get engine
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
# 获取 message 并转文本
message = json.loads(message["text"])
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
# run
wav = tts_engine.run(sentence)
while True:
try:
tts_results = next(wav)
resp = {"status": 1, "audio": tts_results}
await websocket.send_json(resp)
logger.info("streaming audio...")
except StopIteration as e:
resp = {"status": 2, "audio": ''}
await websocket.send_json(resp)
logger.info("Complete the transmission of audio streams")
break
except WebSocketDisconnect:
pass
Loading…
Cancel
Save