parent
3c8f30c7a4
commit
603e565ab1
@ -0,0 +1,46 @@
|
||||
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||
|
||||
#################################################################################
|
||||
# SERVER SETTING #
|
||||
#################################################################################
|
||||
host: 127.0.0.1
|
||||
port: 8092
|
||||
|
||||
# The task format in the engin_list is: <speech task>_<engine type>
|
||||
# task choices = ['asr_online', 'tts_online']
|
||||
# protocol = ['websocket', 'http'] (only one can be selected).
|
||||
protocol: 'http'
|
||||
engine_list: ['tts_online']
|
||||
|
||||
|
||||
#################################################################################
|
||||
# ENGINE CONFIG #
|
||||
#################################################################################
|
||||
|
||||
################################### TTS #########################################
|
||||
################### speech task: tts; engine_type: online #######################
|
||||
tts_online:
|
||||
# am (acoustic model) choices=['fastspeech2_csmsc']
|
||||
am: 'fastspeech2_csmsc'
|
||||
am_config:
|
||||
am_ckpt:
|
||||
am_stat:
|
||||
phones_dict:
|
||||
tones_dict:
|
||||
speaker_dict:
|
||||
spk_id: 0
|
||||
|
||||
# voc (vocoder) choices=['mb_melgan_csmsc']
|
||||
voc: 'mb_melgan_csmsc'
|
||||
voc_config:
|
||||
voc_ckpt:
|
||||
voc_stat:
|
||||
|
||||
# others
|
||||
lang: 'zh'
|
||||
device: # set 'gpu:id' or 'cpu'
|
||||
am_block: 42
|
||||
am_pad: 12
|
||||
voc_block: 14
|
||||
voc_pad: 14
|
||||
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,305 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import soundfile as sf
|
||||
from scipy.io import wavfile
|
||||
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.tts.infer import TTSExecutor
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
from paddlespeech.server.utils.audio_process import change_speed
|
||||
from paddlespeech.server.utils.errors import ErrorCode
|
||||
from paddlespeech.server.utils.exception import ServerBaseException
|
||||
from paddlespeech.server.utils.audio_process import float2pcm
|
||||
from paddlespeech.server.utils.config import get_config
|
||||
from paddlespeech.server.utils.util import denorm
|
||||
from paddlespeech.server.utils.util import get_chunks
|
||||
|
||||
|
||||
import math
|
||||
|
||||
__all__ = ['TTSEngine']
|
||||
|
||||
|
||||
class TTSServerExecutor(TTSExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self,
|
||||
text: str,
|
||||
lang: str='zh',
|
||||
am: str='fastspeech2_csmsc',
|
||||
spk_id: int=0,
|
||||
am_block: int=42,
|
||||
am_pad: int=12,
|
||||
voc_block: int=14,
|
||||
voc_pad: int=14,):
|
||||
"""
|
||||
Model inference and result stored in self.output.
|
||||
"""
|
||||
am_name = am[:am.rindex('_')]
|
||||
am_dataset = am[am.rindex('_') + 1:]
|
||||
get_tone_ids = False
|
||||
merge_sentences = False
|
||||
frontend_st = time.time()
|
||||
if am_name == 'speedyspeech':
|
||||
get_tone_ids = True
|
||||
if lang == 'zh':
|
||||
input_ids = self.frontend.get_input_ids(
|
||||
text,
|
||||
merge_sentences=merge_sentences,
|
||||
get_tone_ids=get_tone_ids)
|
||||
phone_ids = input_ids["phone_ids"]
|
||||
if get_tone_ids:
|
||||
tone_ids = input_ids["tone_ids"]
|
||||
elif lang == 'en':
|
||||
input_ids = self.frontend.get_input_ids(
|
||||
text, merge_sentences=merge_sentences)
|
||||
phone_ids = input_ids["phone_ids"]
|
||||
else:
|
||||
print("lang should in {'zh', 'en'}!")
|
||||
self.frontend_time = time.time() - frontend_st
|
||||
|
||||
for i in range(len(phone_ids)):
|
||||
am_st = time.time()
|
||||
part_phone_ids = phone_ids[i]
|
||||
# am
|
||||
if am_name == 'speedyspeech':
|
||||
part_tone_ids = tone_ids[i]
|
||||
mel = self.am_inference(part_phone_ids, part_tone_ids)
|
||||
# fastspeech2
|
||||
else:
|
||||
# multi speaker
|
||||
if am_dataset in {"aishell3", "vctk"}:
|
||||
mel = self.am_inference(
|
||||
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
|
||||
else:
|
||||
mel = self.am_inference(part_phone_ids)
|
||||
am_et = time.time()
|
||||
|
||||
# voc streaming
|
||||
voc_upsample = self.voc_config.n_shift
|
||||
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
|
||||
chunk_num = len(mel_chunks)
|
||||
voc_st = time.time()
|
||||
for i, mel_chunk in enumerate(mel_chunks):
|
||||
sub_wav = self.voc_inference(mel_chunk)
|
||||
front_pad = min(i*voc_block, voc_pad)
|
||||
|
||||
if i == 0:
|
||||
sub_wav = sub_wav[: voc_block * voc_upsample]
|
||||
elif i == chunk_num - 1:
|
||||
sub_wav = sub_wav[front_pad * voc_upsample : ]
|
||||
else:
|
||||
sub_wav = sub_wav[front_pad * voc_upsample: (front_pad + voc_block) * voc_upsample]
|
||||
|
||||
yield sub_wav
|
||||
|
||||
class TTSEngine(BaseEngine):
|
||||
"""TTS server engine
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self, name=None):
|
||||
"""Initialize TTS server engine
|
||||
"""
|
||||
super(TTSEngine, self).__init__()
|
||||
|
||||
def init(self, config: dict) -> bool:
|
||||
self.executor = TTSServerExecutor()
|
||||
|
||||
try:
|
||||
self.config = config
|
||||
if self.config.device:
|
||||
self.device = self.config.device
|
||||
else:
|
||||
self.device = paddle.get_device()
|
||||
paddle.set_device(self.device)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
|
||||
)
|
||||
logger.error("Initialize TTS server engine Failed on device: %s." %
|
||||
(self.device))
|
||||
return False
|
||||
|
||||
try:
|
||||
self.executor._init_from_path(
|
||||
am=self.config.am,
|
||||
am_config=self.config.am_config,
|
||||
am_ckpt=self.config.am_ckpt,
|
||||
am_stat=self.config.am_stat,
|
||||
phones_dict=self.config.phones_dict,
|
||||
tones_dict=self.config.tones_dict,
|
||||
speaker_dict=self.config.speaker_dict,
|
||||
voc=self.config.voc,
|
||||
voc_config=self.config.voc_config,
|
||||
voc_ckpt=self.config.voc_ckpt,
|
||||
voc_stat=self.config.voc_stat,
|
||||
lang=self.config.lang)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get model related files.")
|
||||
logger.error("Initialize TTS server engine Failed on device: %s." %
|
||||
(self.device))
|
||||
return False
|
||||
|
||||
self.am_block = self.config.am_block
|
||||
self.am_pad = self.config.am_pad
|
||||
self.voc_block = self.config.voc_block
|
||||
self.voc_pad = self.config.voc_pad
|
||||
|
||||
logger.info("Initialize TTS server engine successfully on device: %s." %
|
||||
(self.device))
|
||||
return True
|
||||
|
||||
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
|
||||
# Convert byte to text
|
||||
if text_bese64:
|
||||
text_bytes = base64.b64decode(text_bese64) # base64 to bytes
|
||||
text = text_bytes.decode('utf-8') # bytes to text
|
||||
|
||||
return text
|
||||
|
||||
def postprocess(self,
|
||||
wav,
|
||||
original_fs: int,
|
||||
target_fs: int=0,
|
||||
volume: float=1.0,
|
||||
speed: float=1.0,
|
||||
audio_path: str=None):
|
||||
"""Post-processing operations, including speech, volume, sample rate, save audio file
|
||||
|
||||
Args:
|
||||
wav (numpy(float)): Synthesized audio sample points
|
||||
original_fs (int): original audio sample rate
|
||||
target_fs (int): target audio sample rate
|
||||
volume (float): target volume
|
||||
speed (float): target speed
|
||||
|
||||
Raises:
|
||||
ServerBaseException: Throws an exception if the change speed unsuccessfully.
|
||||
|
||||
Returns:
|
||||
target_fs: target sample rate for synthesized audio.
|
||||
wav_base64: The base64 format of the synthesized audio.
|
||||
"""
|
||||
|
||||
# transform sample_rate
|
||||
if target_fs == 0 or target_fs > original_fs:
|
||||
target_fs = original_fs
|
||||
wav_tar_fs = wav
|
||||
logger.info(
|
||||
"The sample rate of synthesized audio is the same as model, which is {}Hz".
|
||||
format(original_fs))
|
||||
else:
|
||||
wav_tar_fs = librosa.resample(
|
||||
np.squeeze(wav), original_fs, target_fs)
|
||||
logger.info(
|
||||
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
|
||||
format(original_fs, target_fs))
|
||||
# transform volume
|
||||
wav_vol = wav_tar_fs * volume
|
||||
logger.info("Transform the volume of the audio successfully.")
|
||||
|
||||
# transform speed
|
||||
try: # windows not support soxbindings
|
||||
wav_speed = change_speed(wav_vol, speed, target_fs)
|
||||
logger.info("Transform the speed of the audio successfully.")
|
||||
except ServerBaseException:
|
||||
raise ServerBaseException(
|
||||
ErrorCode.SERVER_INTERNAL_ERR,
|
||||
"Failed to transform speed. Can not install soxbindings on your system. \
|
||||
You need to set speed value 1.0.")
|
||||
except BaseException:
|
||||
logger.error("Failed to transform speed.")
|
||||
|
||||
# wav to base64
|
||||
buf = io.BytesIO()
|
||||
wavfile.write(buf, target_fs, wav_speed)
|
||||
base64_bytes = base64.b64encode(buf.read())
|
||||
wav_base64 = base64_bytes.decode('utf-8')
|
||||
logger.info("Audio to string successfully.")
|
||||
|
||||
# save audio
|
||||
if audio_path is not None:
|
||||
if audio_path.endswith(".wav"):
|
||||
sf.write(audio_path, wav_speed, target_fs)
|
||||
elif audio_path.endswith(".pcm"):
|
||||
wav_norm = wav_speed * (32767 / max(0.001,
|
||||
np.max(np.abs(wav_speed))))
|
||||
with open(audio_path, "wb") as f:
|
||||
f.write(wav_norm.astype(np.int16))
|
||||
logger.info("Save audio to {} successfully.".format(audio_path))
|
||||
else:
|
||||
logger.info("There is no need to save audio.")
|
||||
|
||||
return target_fs, wav_base64
|
||||
|
||||
def run(self,
|
||||
sentence: str,
|
||||
spk_id: int=0,
|
||||
speed: float=1.0,
|
||||
volume: float=1.0,
|
||||
sample_rate: int=0,
|
||||
save_path: str=None):
|
||||
""" run include inference and postprocess.
|
||||
|
||||
Args:
|
||||
sentence (str): text to be synthesized
|
||||
spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
|
||||
speed (float, optional): speed. Defaults to 1.0.
|
||||
volume (float, optional): volume. Defaults to 1.0.
|
||||
sample_rate (int, optional): target sample rate for synthesized audio,
|
||||
0 means the same as the model sampling rate. Defaults to 0.
|
||||
save_path (str, optional): The save path of the synthesized audio.
|
||||
None means do not save audio. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ServerBaseException: Throws an exception if tts inference unsuccessfully.
|
||||
ServerBaseException: Throws an exception if postprocess unsuccessfully.
|
||||
|
||||
Returns:
|
||||
lang: model language
|
||||
target_sample_rate: target sample rate for synthesized audio.
|
||||
wav_base64: The base64 format of the synthesized audio.
|
||||
"""
|
||||
|
||||
lang = self.config.lang
|
||||
wav_list = []
|
||||
|
||||
for wav in self.executor.infer(text=sentence, lang=lang, am=self.config.am, spk_id=spk_id, am_block=self.am_block, am_pad=self.am_pad, voc_block=self.voc_block, voc_pad=self.voc_pad):
|
||||
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
|
||||
wav = float2pcm(wav) # float32 to int16
|
||||
wav_bytes = wav.tobytes() # to bytes
|
||||
wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64
|
||||
wav_list.append(wav)
|
||||
|
||||
yield wav_base64
|
||||
|
||||
wav_all = np.concatenate(wav_list, axis=0)
|
||||
logger.info("The durations of audio is: {} s".format(len(wav_all)/self.executor.am_config.fs))
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from paddlespeech.server.utils.audio_process import pcm2wav
|
||||
|
||||
|
||||
def save_audio(buffer, audio_path) -> bool:
|
||||
if args.save_path.endswith("pcm"):
|
||||
with open(args.save_path, "wb") as f:
|
||||
f.write(buffer)
|
||||
elif args.save_path.endswith("wav"):
|
||||
with open("./tmp.pcm", "wb") as f:
|
||||
f.write(buffer)
|
||||
pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000)
|
||||
os.system("rm ./tmp.pcm")
|
||||
else:
|
||||
print("Only supports saved audio format is pcm or wav")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test(args):
|
||||
params = {
|
||||
"text": args.text,
|
||||
"spk_id": args.spk_id,
|
||||
"speed": args.speed,
|
||||
"volume": args.volume,
|
||||
"sample_rate": args.sample_rate,
|
||||
"save_path": ''
|
||||
}
|
||||
|
||||
buffer = b''
|
||||
flag = 1
|
||||
url = "http://" + str(args.server) + ":" + str(
|
||||
args.port) + "/paddlespeech/streaming/tts"
|
||||
st = time.time()
|
||||
html = requests.post(url, json.dumps(params), stream=True)
|
||||
for chunk in html.iter_content(chunk_size=1024):
|
||||
chunk = base64.b64decode(chunk) # bytes
|
||||
if flag:
|
||||
first_response = time.time() - st
|
||||
print(f"首包响应:{first_response} s")
|
||||
flag = 0
|
||||
buffer += chunk
|
||||
|
||||
final_response = time.time() - st
|
||||
duration = len(buffer) / 2.0 / 24000
|
||||
|
||||
print(f"尾包响应:{final_response} s")
|
||||
print(f"音频时长:{duration} s")
|
||||
print(f"RTF: {final_response / duration}")
|
||||
|
||||
if args.save_path is not None:
|
||||
if save_audio(buffer, args.save_path):
|
||||
print("音频保存至:", args.save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--text',
|
||||
type=str,
|
||||
default="您好,欢迎使用语音合成服务。",
|
||||
help='A sentence to be synthesized')
|
||||
parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
|
||||
parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
|
||||
parser.add_argument(
|
||||
'--volume', type=float, default=1.0, help='Audio volume')
|
||||
parser.add_argument(
|
||||
'--sample_rate',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Sampling rate, the default is the same as the model')
|
||||
parser.add_argument(
|
||||
"--server", type=str, help="server ip", default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, help="server port", default=8092)
|
||||
parser.add_argument(
|
||||
"--save_path", type=str, help="save audio path", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
test(args)
|
@ -0,0 +1,112 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pyaudio
|
||||
import requests
|
||||
|
||||
mutex = threading.Lock()
|
||||
buffer = b''
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(
|
||||
format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
|
||||
max_fail = 50
|
||||
|
||||
|
||||
def play_audio():
|
||||
global stream
|
||||
global buffer
|
||||
global max_fail
|
||||
while True:
|
||||
if not buffer:
|
||||
max_fail -= 1
|
||||
time.sleep(0.05)
|
||||
if max_fail < 0:
|
||||
break
|
||||
mutex.acquire()
|
||||
stream.write(buffer)
|
||||
buffer = b''
|
||||
mutex.release()
|
||||
|
||||
|
||||
def test(args):
|
||||
global mutex
|
||||
global buffer
|
||||
params = {
|
||||
"text": args.text,
|
||||
"spk_id": args.spk_id,
|
||||
"speed": args.speed,
|
||||
"volume": args.volume,
|
||||
"sample_rate": args.sample_rate,
|
||||
"save_path": ''
|
||||
}
|
||||
|
||||
all_bytes = 0.0
|
||||
t = threading.Thread(target=play_audio)
|
||||
flag = 1
|
||||
url = "http://" + str(args.server) + ":" + str(
|
||||
args.port) + "/paddlespeech/streaming/tts"
|
||||
st = time.time()
|
||||
html = requests.post(url, json.dumps(params), stream=True)
|
||||
for chunk in html.iter_content(chunk_size=1024):
|
||||
mutex.acquire()
|
||||
chunk = base64.b64decode(chunk) # bytes
|
||||
buffer += chunk
|
||||
mutex.release()
|
||||
if flag:
|
||||
first_response = time.time() - st
|
||||
print(f"首包响应:{first_response} s")
|
||||
flag = 0
|
||||
t.start()
|
||||
all_bytes += len(chunk)
|
||||
|
||||
final_response = time.time() - st
|
||||
duration = all_bytes / 2 / 24000
|
||||
|
||||
print(f"尾包响应:{final_response} s")
|
||||
print(f"音频时长:{duration} s")
|
||||
print(f"RTF: {final_response / duration}")
|
||||
|
||||
t.join()
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--text',
|
||||
type=str,
|
||||
default="您好,欢迎使用语音合成服务。",
|
||||
help='A sentence to be synthesized')
|
||||
parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
|
||||
parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
|
||||
parser.add_argument(
|
||||
'--volume', type=float, default=1.0, help='Audio volume')
|
||||
parser.add_argument(
|
||||
'--sample_rate',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Sampling rate, the default is the same as the model')
|
||||
parser.add_argument(
|
||||
"--server", type=str, help="server ip", default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, help="server port", default=8092)
|
||||
|
||||
args = parser.parse_args()
|
||||
test(args)
|
Binary file not shown.
@ -0,0 +1,126 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import _thread as thread
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
|
||||
import websocket
|
||||
|
||||
flag = 1
|
||||
st = 0.0
|
||||
all_bytes = b''
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, text, server="127.0.0.1", port=8090):
|
||||
self.server = server
|
||||
self.port = port
|
||||
self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
|
||||
self.text = text
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
return self.url
|
||||
|
||||
|
||||
def on_message(ws, message):
|
||||
global flag
|
||||
global st
|
||||
global all_bytes
|
||||
|
||||
try:
|
||||
message = json.loads(message)
|
||||
audio = message["audio"]
|
||||
audio = base64.b64decode(audio) # bytes
|
||||
status = message["status"]
|
||||
all_bytes += audio
|
||||
|
||||
if status == 0:
|
||||
print("create successfully.")
|
||||
elif status == 1:
|
||||
if flag:
|
||||
print(f"首包响应:{time.time() - st} s")
|
||||
flag = 0
|
||||
elif status == 2:
|
||||
final_response = time.time() - st
|
||||
duration = len(all_bytes) / 2.0 / 24000
|
||||
print(f"尾包响应:{final_response} s")
|
||||
print(f"音频时长:{duration} s")
|
||||
print(f"RTF: {final_response / duration}")
|
||||
with open("./out.pcm", "wb") as f:
|
||||
f.write(all_bytes)
|
||||
print("ws is closed")
|
||||
ws.close()
|
||||
else:
|
||||
print("infer error")
|
||||
|
||||
except Exception as e:
|
||||
print("receive msg,but parse exception:", e)
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws):
|
||||
print("### closed ###")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
def run(*args):
|
||||
global st
|
||||
text_base64 = str(
|
||||
base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
|
||||
d = {"text": text_base64}
|
||||
d = json.dumps(d)
|
||||
print("Start sending text data")
|
||||
st = time.time()
|
||||
ws.send(d)
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
help="A sentence to be synthesized",
|
||||
default="您好,欢迎使用语音合成服务。")
|
||||
parser.add_argument(
|
||||
"--server", type=str, help="server ip", default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, help="server port", default=8092)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("***************************************")
|
||||
print("Server ip: ", args.server)
|
||||
print("Server port: ", args.port)
|
||||
print("Sentence to be synthesized: ", args.text)
|
||||
print("***************************************")
|
||||
|
||||
wsParam = Ws_Param(text=args.text, server=args.server, port=args.port)
|
||||
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(
|
||||
wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
|
||||
ws.on_open = on_open
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
@ -0,0 +1,160 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import _thread as thread
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pyaudio
|
||||
import websocket
|
||||
|
||||
mutex = threading.Lock()
|
||||
buffer = b''
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(
|
||||
format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
|
||||
flag = 1
|
||||
st = 0.0
|
||||
all_bytes = 0.0
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, text, server="127.0.0.1", port=8090):
|
||||
self.server = server
|
||||
self.port = port
|
||||
self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
|
||||
self.text = text
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
return self.url
|
||||
|
||||
|
||||
def play_audio():
|
||||
global stream
|
||||
global buffer
|
||||
while True:
|
||||
time.sleep(0.05)
|
||||
if not buffer: # buffer 为空
|
||||
break
|
||||
mutex.acquire()
|
||||
stream.write(buffer)
|
||||
buffer = b''
|
||||
mutex.release()
|
||||
|
||||
|
||||
t = threading.Thread(target=play_audio)
|
||||
|
||||
|
||||
def on_message(ws, message):
|
||||
global flag
|
||||
global t
|
||||
global buffer
|
||||
global st
|
||||
global all_bytes
|
||||
|
||||
try:
|
||||
message = json.loads(message)
|
||||
audio = message["audio"]
|
||||
audio = base64.b64decode(audio) # bytes
|
||||
status = message["status"]
|
||||
all_bytes += len(audio)
|
||||
|
||||
if status == 0:
|
||||
print("create successfully.")
|
||||
elif status == 1:
|
||||
mutex.acquire()
|
||||
buffer += audio
|
||||
mutex.release()
|
||||
if flag:
|
||||
print(f"首包响应:{time.time() - st} s")
|
||||
flag = 0
|
||||
print("Start playing audio")
|
||||
t.start()
|
||||
elif status == 2:
|
||||
final_response = time.time() - st
|
||||
duration = all_bytes / 2 / 24000
|
||||
print(f"尾包响应:{final_response} s")
|
||||
print(f"音频时长:{duration} s")
|
||||
print(f"RTF: {final_response / duration}")
|
||||
print("ws is closed")
|
||||
ws.close()
|
||||
else:
|
||||
print("infer error")
|
||||
|
||||
except Exception as e:
|
||||
print("receive msg,but parse exception:", e)
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws):
|
||||
print("### closed ###")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
def run(*args):
|
||||
global st
|
||||
text_base64 = str(
|
||||
base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
|
||||
d = {"text": text_base64}
|
||||
d = json.dumps(d)
|
||||
print("Start sending text data")
|
||||
st = time.time()
|
||||
ws.send(d)
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
help="A sentence to be synthesized",
|
||||
default="您好,欢迎使用语音合成服务。")
|
||||
parser.add_argument(
|
||||
"--server", type=str, help="server ip", default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, help="server port", default=8092)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("***************************************")
|
||||
print("Server ip: ", args.server)
|
||||
print("Server port: ", args.port)
|
||||
print("Sentence to be synthesized: ", args.text)
|
||||
print("***************************************")
|
||||
|
||||
wsParam = Ws_Param(text=args.text, server=args.server, port=args.port)
|
||||
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(
|
||||
wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
|
||||
ws.on_open = on_open
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
t.join()
|
||||
print("End of playing audio")
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from starlette.websockets import WebSocketState as WebSocketState
|
||||
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.server.engine.engine_pool import get_engine_pool
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket('/ws/tts')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
# careful here, changed the source code from starlette.websockets
|
||||
assert websocket.application_state == WebSocketState.CONNECTED
|
||||
message = await websocket.receive()
|
||||
websocket._raise_on_disconnect(message)
|
||||
|
||||
# get engine
|
||||
engine_pool = get_engine_pool()
|
||||
tts_engine = engine_pool['tts']
|
||||
|
||||
# 获取 message 并转文本
|
||||
message = json.loads(message["text"])
|
||||
text_bese64 = message["text"]
|
||||
sentence = tts_engine.preprocess(text_bese64=text_bese64)
|
||||
|
||||
# run
|
||||
wav = tts_engine.run(sentence)
|
||||
|
||||
while True:
|
||||
try:
|
||||
tts_results = next(wav)
|
||||
resp = {"status": 1, "audio": tts_results}
|
||||
await websocket.send_json(resp)
|
||||
logger.info("streaming audio...")
|
||||
except StopIteration as e:
|
||||
resp = {"status": 2, "audio": ''}
|
||||
await websocket.send_json(resp)
|
||||
logger.info("Complete the transmission of audio streams")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
Loading…
Reference in new issue