Merge pull request #1882 from lym0302/streaming_tts_server

[server] improve code
pull/1887/head
liangym 3 years ago committed by GitHub
commit 22b67ed051
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,7 @@ import io
import json import json
import os import os
import random import random
import sys
import time import time
from typing import List from typing import List
@ -91,7 +92,7 @@ class TTSClientExecutor(BaseExecutor):
temp_wav = str(random.getrandbits(128)) + ".wav" temp_wav = str(random.getrandbits(128)) + ".wav"
soundfile.write(temp_wav, samples, sample_rate) soundfile.write(temp_wav, samples, sample_rate)
wav2pcm(temp_wav, outfile, data_type=np.int16) wav2pcm(temp_wav, outfile, data_type=np.int16)
os.system("rm %s" % (temp_wav)) os.remove(temp_wav)
else: else:
logger.error("The format for saving audio only supports wav or pcm") logger.error("The format for saving audio only supports wav or pcm")
@ -128,6 +129,7 @@ class TTSClientExecutor(BaseExecutor):
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to synthesized audio.") logger.error("Failed to synthesized audio.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -236,6 +238,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to synthesized audio.") logger.error("Failed to synthesized audio.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -275,7 +278,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
else: else:
logger.error("Please set correct protocol, http or websocket") logger.error("Please set correct protocol, http or websocket")
return False sys.exit(-1)
logger.info(f"sentence: {input}") logger.info(f"sentence: {input}")
logger.info(f"duration: {duration} s") logger.info(f"duration: {duration} s")
@ -503,6 +506,7 @@ class ASROnlineClientExecutor(BaseExecutor):
Returns: Returns:
str: the audio text str: the audio text
""" """
logger.info("asr websocket client start") logger.info("asr websocket client start")
handler = ASRWsAudioHandler( handler = ASRWsAudioHandler(
server_ip, server_ip,
@ -555,6 +559,7 @@ class CLSClientExecutor(BaseExecutor):
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to speech classification.") logger.error("Failed to speech classification.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -728,6 +733,7 @@ class VectorClientExecutor(BaseExecutor):
Returns: Returns:
str: the audio embedding or score between enroll and test audio str: the audio embedding or score between enroll and test audio
""" """
if task == "spk": if task == "spk":
from paddlespeech.server.utils.audio_handler import VectorHttpHandler from paddlespeech.server.utils.audio_handler import VectorHttpHandler
logger.info("vector http client start") logger.info("vector http client start")

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
from typing import List from typing import List
import uvicorn import uvicorn
@ -79,10 +80,12 @@ class ServerExecutor(BaseExecutor):
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
config = get_config(args.config_file) try:
self(args.config_file, args.log_file)
if self.init(config): except Exception as e:
uvicorn.run(app, host=config.host, port=config.port, debug=True) logger.error("Failed to start server.")
logger.error(e)
sys.exit(-1)
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

@ -61,7 +61,7 @@ def tts_client(args):
temp_wav = str(random.getrandbits(128)) + ".wav" temp_wav = str(random.getrandbits(128)) + ".wav"
soundfile.write(temp_wav, samples, sample_rate) soundfile.write(temp_wav, samples, sample_rate)
wav2pcm(temp_wav, outfile, data_type=np.int16) wav2pcm(temp_wav, outfile, data_type=np.int16)
os.system("rm %s" % (temp_wav)) os.remove(temp_wav)
else: else:
print("The format for saving audio only supports wav or pcm") print("The format for saving audio only supports wav or pcm")

@ -304,52 +304,80 @@ class TTSWsHandler:
receive_time_list = [] receive_time_list = []
chunk_duration_list = [] chunk_duration_list = []
# 1. Send websocket handshake protocal # 1. Send websocket handshake request
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# 2. Server has already received handshake protocal # 2. Server has already received handshake response, send start request
# send text to engine start_request = json.dumps({"task": "tts", "signal": "start"})
await ws.send(start_request)
msg = await ws.recv()
logger.info(f"client receive msg={msg}")
msg = json.loads(msg)
session = msg["session"]
# 3. send speech synthesis request
text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8") text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8")
d = {"text": text_base64} request = json.dumps({"text": text_base64})
d = json.dumps(d)
st = time.time() st = time.time()
await ws.send(d) await ws.send(request)
logging.info("send a message to the server") logging.info("send a message to the server")
# 3. Process the received response # 4. Process the received response
message = await ws.recv() message = await ws.recv()
first_response = time.time() - st first_response = time.time() - st
message = json.loads(message) message = json.loads(message)
status = message["status"] status = message["status"]
while True:
# When an exception is thrown
if status == -1:
# send end request
end_request = json.dumps({
"task": "tts",
"signal": "end",
"session": session
})
await ws.send(end_request)
break
# Return last packet normally, no audio information
elif status == 2:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
if output is not None:
save_audio_success = save_audio(all_bytes, output)
else:
save_audio_success = False
# send end request
end_request = json.dumps({
"task": "tts",
"signal": "end",
"session": session
})
await ws.send(end_request)
break
# Return the audio stream normally
elif status == 1:
receive_time_list.append(time.time())
audio = message["audio"]
audio = base64.b64decode(audio) # bytes
chunk_duration_list.append(len(audio) / 2.0 / 24000)
all_bytes += audio
if self.play:
self.mutex.acquire()
self.buffer += audio
self.mutex.release()
if self.start_play:
self.t.start()
self.start_play = False
message = await ws.recv()
message = json.loads(message)
status = message["status"]
while (status == 1):
receive_time_list.append(time.time())
audio = message["audio"]
audio = base64.b64decode(audio) # bytes
chunk_duration_list.append(len(audio) / 2.0 / 24000)
all_bytes += audio
if self.play:
self.mutex.acquire()
self.buffer += audio
self.mutex.release()
if self.start_play:
self.t.start()
self.start_play = False
message = await ws.recv()
message = json.loads(message)
status = message["status"]
# 4. Last packet, no audio information
if status == 2:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
if output is not None:
save_audio_success = save_audio(all_bytes, output)
else: else:
save_audio_success = False logger.error("infer error, return status is invalid.")
else:
logger.error("infer error")
if self.play: if self.play:
self.t.join() self.t.join()
@ -458,6 +486,7 @@ class TTSHttpHandler:
final_response = time.time() - st final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000 duration = len(all_bytes) / 2.0 / 24000
html.close() # when stream=True
if output is not None: if output is not None:
save_audio_success = save_audio(all_bytes, output) save_audio_success = save_audio(all_bytes, output)

@ -167,7 +167,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool:
channels=1, channels=1,
bits=16, bits=16,
sample_rate=sample_rate) sample_rate=sample_rate)
os.system("rm ./tmp.pcm") os.remove("./tmp.pcm")
else: else:
print("Only supports saved audio format is pcm or wav") print("Only supports saved audio format is pcm or wav")
return False return False

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import uuid
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
@ -26,36 +26,79 @@ router = APIRouter()
@router.websocket('/paddlespeech/tts/streaming') @router.websocket('/paddlespeech/tts/streaming')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online TTS Server api
Args:
websocket (WebSocket): the websocket instance
"""
#1. the interface waits to accept the websocket protocol header
# and only once we receive the header does it establish the connection with a specific thread
await websocket.accept() await websocket.accept()
#2. if we accept the websocket headers, we will get the online tts engine instance
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
try: try:
# careful here, changed the source code from starlette.websockets while True:
assert websocket.application_state == WebSocketState.CONNECTED # careful here, changed the source code from starlette.websockets
message = await websocket.receive() assert websocket.application_state == WebSocketState.CONNECTED
websocket._raise_on_disconnect(message) message = await websocket.receive()
websocket._raise_on_disconnect(message)
message = json.loads(message["text"])
# get engine if 'signal' in message:
engine_pool = get_engine_pool() # start request
tts_engine = engine_pool['tts'] if message['signal'] == 'start':
session = uuid.uuid1().hex
resp = {
"status": 0,
"signal": "server ready",
"session": session
}
await websocket.send_json(resp)
# 获取 message 并转文本 # end request
message = json.loads(message["text"]) elif message['signal'] == 'end':
text_bese64 = message["text"] resp = {
sentence = tts_engine.preprocess(text_bese64=text_bese64) "status": 0,
"signal": "connection will be closed",
"session": session
}
await websocket.send_json(resp)
break
else:
resp = {"status": 0, "signal": "no valid json data"}
await websocket.send_json(resp)
# run # speech synthesis request
wav_generator = tts_engine.run(sentence) elif 'text' in message:
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
while True: # run
try: wav_generator = tts_engine.run(sentence)
tts_results = next(wav_generator)
resp = {"status": 1, "audio": tts_results} while True:
await websocket.send_json(resp) try:
except StopIteration as e: tts_results = next(wav_generator)
resp = {"status": 2, "audio": ''} resp = {"status": 1, "audio": tts_results}
await websocket.send_json(resp) await websocket.send_json(resp)
logger.info("Complete the transmission of audio streams") except StopIteration as e:
break resp = {"status": 2, "audio": ''}
await websocket.send_json(resp)
except WebSocketDisconnect: logger.info(
pass "Complete the synthesis of the audio streams")
break
except Exception as e:
resp = {"status": -1, "audio": ''}
await websocket.send_json(resp)
break
else:
logger.error(
"Invalid request, please check if the request is correct.")
except Exception as e:
logger.error(e)

@ -1,6 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
import argparse import argparse
import os import os
import shutil
import yaml import yaml
@ -14,7 +15,7 @@ def change_device(yamlfile: str, engine: str, device: str):
model_type (dict): change model type model_type (dict): change model type
""" """
tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml" tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
os.system("cp %s %s" % (yamlfile, tmp_yamlfile)) shutil.copyfile(yamlfile, tmp_yamlfile)
if device == 'cpu': if device == 'cpu':
set_device = 'cpu' set_device = 'cpu'
@ -41,7 +42,7 @@ def change_device(yamlfile: str, engine: str, device: str):
print(yaml.dump(y, default_flow_style=False, sort_keys=False)) print(yaml.dump(y, default_flow_style=False, sort_keys=False))
yaml.dump(y, fw, allow_unicode=True) yaml.dump(y, fw, allow_unicode=True)
os.system("rm %s" % (tmp_yamlfile)) os.remove(tmp_yamlfile)
print("Change %s successfully." % (yamlfile)) print("Change %s successfully." % (yamlfile))
@ -52,7 +53,7 @@ def change_engine_type(yamlfile: str, engine_type):
task (str): asr or tts task (str): asr or tts
""" """
tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml" tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
os.system("cp %s %s" % (yamlfile, tmp_yamlfile)) shutil.copyfile(yamlfile, tmp_yamlfile)
speech_task = engine_type.split("_")[0] speech_task = engine_type.split("_")[0]
with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw: with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw:
@ -65,7 +66,7 @@ def change_engine_type(yamlfile: str, engine_type):
y['engine_list'] = engine_list y['engine_list'] = engine_list
print(yaml.dump(y, default_flow_style=False, sort_keys=False)) print(yaml.dump(y, default_flow_style=False, sort_keys=False))
yaml.dump(y, fw, allow_unicode=True) yaml.dump(y, fw, allow_unicode=True)
os.system("rm %s" % (tmp_yamlfile)) os.remove(tmp_yamlfile)
print("Change %s successfully." % (yamlfile)) print("Change %s successfully." % (yamlfile))

@ -1,6 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
import argparse import argparse
import os import os
import shutil
import yaml import yaml
@ -13,7 +14,7 @@ def change_value(args):
target_value = args.target_value target_value = args.target_value
tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml" tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
os.system("cp %s %s" % (yamlfile, tmp_yamlfile)) shutil.copyfile(yamlfile, tmp_yamlfile)
with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw: with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw:
y = yaml.safe_load(f) y = yaml.safe_load(f)
@ -51,7 +52,7 @@ def change_value(args):
print(yaml.dump(y, default_flow_style=False, sort_keys=False)) print(yaml.dump(y, default_flow_style=False, sort_keys=False))
yaml.dump(y, fw, allow_unicode=True) yaml.dump(y, fw, allow_unicode=True)
os.system("rm %s" % (tmp_yamlfile)) os.remove(tmp_yamlfile)
print(f"Change key: {target_key} to value: {target_value} successfully.") print(f"Change key: {target_key} to value: {target_value} successfully.")

@ -75,8 +75,8 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
os.system("rm -rf %s" % (args.output_dir)) if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir) os.makedirs(args.output_dir)
first_response_list = [] first_response_list = []
final_response_list = [] final_response_list = []

Loading…
Cancel
Save