improve, test=doc

pull/1882/head
lym0302 3 years ago
parent 018dda6ee9
commit d4f863dc97

@ -18,6 +18,7 @@ import io
import json import json
import os import os
import random import random
import sys
import time import time
from typing import List from typing import List
@ -32,6 +33,7 @@ from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import compute_delay from paddlespeech.server.utils.util import compute_delay
from paddlespeech.server.utils.util import network_reachable
from paddlespeech.server.utils.util import wav2base64 from paddlespeech.server.utils.util import wav2base64
__all__ = [ __all__ = [
@ -128,6 +130,7 @@ class TTSClientExecutor(BaseExecutor):
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to synthesized audio.") logger.error("Failed to synthesized audio.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -154,6 +157,12 @@ class TTSClientExecutor(BaseExecutor):
"save_path": output "save_path": output
} }
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
res = requests.post(url, json.dumps(request)) res = requests.post(url, json.dumps(request))
response_dict = res.json() response_dict = res.json()
if output is not None: if output is not None:
@ -236,6 +245,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to synthesized audio.") logger.error("Failed to synthesized audio.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -254,6 +264,12 @@ class TTSOnlineClientExecutor(BaseExecutor):
Python API to call an executor. Python API to call an executor.
""" """
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
if protocol == "http": if protocol == "http":
logger.info("tts http client start") logger.info("tts http client start")
from paddlespeech.server.utils.audio_handler import TTSHttpHandler from paddlespeech.server.utils.audio_handler import TTSHttpHandler
@ -275,7 +291,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
else: else:
logger.error("Please set correct protocol, http or websocket") logger.error("Please set correct protocol, http or websocket")
return False sys.exit(-1)
logger.info(f"sentence: {input}") logger.info(f"sentence: {input}")
logger.info(f"duration: {duration} s") logger.info(f"duration: {duration} s")
@ -399,6 +415,13 @@ class ASRClientExecutor(BaseExecutor):
# and paddlespeech_client asr only support http protocol # and paddlespeech_client asr only support http protocol
protocol = "http" protocol = "http"
if protocol.lower() == "http": if protocol.lower() == "http":
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(
f"{network} unreachable, please check the ip address.")
sys.exit(-1)
from paddlespeech.server.utils.audio_handler import ASRHttpHandler from paddlespeech.server.utils.audio_handler import ASRHttpHandler
logger.info("asr http client start") logger.info("asr http client start")
handler = ASRHttpHandler(server_ip=server_ip, port=port) handler = ASRHttpHandler(server_ip=server_ip, port=port)
@ -503,6 +526,13 @@ class ASROnlineClientExecutor(BaseExecutor):
Returns: Returns:
str: the audio text str: the audio text
""" """
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
logger.info("asr websocket client start") logger.info("asr websocket client start")
handler = ASRWsAudioHandler( handler = ASRWsAudioHandler(
server_ip, server_ip,
@ -555,6 +585,7 @@ class CLSClientExecutor(BaseExecutor):
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to speech classification.") logger.error("Failed to speech classification.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -567,6 +598,12 @@ class CLSClientExecutor(BaseExecutor):
Python API to call an executor. Python API to call an executor.
""" """
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls' url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
audio = wav2base64(input) audio = wav2base64(input)
data = {"audio": audio, "topk": topk} data = {"audio": audio, "topk": topk}
@ -632,6 +669,12 @@ class TextClientExecutor(BaseExecutor):
str: the punctuation text str: the punctuation text
""" """
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/text' url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/text'
request = { request = {
"text": input, "text": input,
@ -728,6 +771,13 @@ class VectorClientExecutor(BaseExecutor):
Returns: Returns:
str: the audio embedding or score between enroll and test audio str: the audio embedding or score between enroll and test audio
""" """
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
if task == "spk": if task == "spk":
from paddlespeech.server.utils.audio_handler import VectorHttpHandler from paddlespeech.server.utils.audio_handler import VectorHttpHandler
logger.info("vector http client start") logger.info("vector http client start")

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
from typing import List from typing import List
import uvicorn import uvicorn
@ -79,10 +80,12 @@ class ServerExecutor(BaseExecutor):
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
config = get_config(args.config_file) try:
self(args.config_file, args.log_file)
if self.init(config): except Exception as e:
uvicorn.run(app, host=config.host, port=config.port, debug=True) logger.error("Failed to start server.")
logger.error(e)
sys.exit(-1)
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

@ -304,18 +304,24 @@ class TTSWsHandler:
receive_time_list = [] receive_time_list = []
chunk_duration_list = [] chunk_duration_list = []
# 1. Send websocket handshake protocal # 1. Send websocket handshake request
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# 2. Server has already received handshake protocal # 2. Server has already received handshake response, send start request
# send text to engine start_request = json.dumps({"task": "tts", "signal": "start"})
await ws.send(start_request)
msg = await ws.recv()
logger.info(f"client receive msg={msg}")
msg = json.loads(msg)
session = msg["session"]
# 3. send speech synthesis request
text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8") text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8")
d = {"text": text_base64} request = json.dumps({"text": text_base64})
d = json.dumps(d)
st = time.time() st = time.time()
await ws.send(d) await ws.send(request)
logging.info("send a message to the server") logging.info("send a message to the server")
# 3. Process the received response # Process the received response
message = await ws.recv() message = await ws.recv()
first_response = time.time() - st first_response = time.time() - st
message = json.loads(message) message = json.loads(message)
@ -348,6 +354,15 @@ class TTSWsHandler:
save_audio_success = save_audio(all_bytes, output) save_audio_success = save_audio(all_bytes, output)
else: else:
save_audio_success = False save_audio_success = False
# 5. send end request
end_request = json.dumps({
"task": "tts",
"signal": "end",
"session": session
})
await ws.send(end_request)
else: else:
logger.error("infer error") logger.error("infer error")
@ -458,6 +473,7 @@ class TTSHttpHandler:
final_response = time.time() - st final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000 duration = len(all_bytes) / 2.0 / 24000
html.close() # when stream=True
if output is not None: if output is not None:
save_audio_success = save_audio(all_bytes, output) save_audio_success = save_audio(all_bytes, output)

@ -13,6 +13,8 @@
import base64 import base64
import math import math
import requests
def wav2base64(wav_file: str): def wav2base64(wav_file: str):
""" """
@ -146,3 +148,21 @@ def count_engine(logfile: str="./nohup.out"):
print( print(
f"max final response: {max(final_response_list)} s, min final response: {min(final_response_list)} s" f"max final response: {max(final_response_list)} s, min final response: {min(final_response_list)} s"
) )
def network_reachable(url: str, timeout: int=5) -> bool:
"""Check if the network is reachable
Args:
url (str): http://server_ip:port or ws://server_ip:port
timeout (int, optional): timeout. Defaults to 5.
Returns:
bool: Whether the network is reachable.
"""
try:
request = requests.get(url, timeout=timeout)
return True
except (requests.ConnectionError, requests.Timeout) as exception:
print(exception)
return False

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import uuid
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
@ -26,36 +27,71 @@ router = APIRouter()
@router.websocket('/paddlespeech/tts/streaming') @router.websocket('/paddlespeech/tts/streaming')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online TTS Server api
Args:
websocket (WebSocket): the websocket instance
"""
#1. the interface wait to accept the websocket protocal header
# and only we receive the header, it establish the connection with specific thread
await websocket.accept() await websocket.accept()
#2. if we accept the websocket headers, we will get the online tts engine instance
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
try: try:
# careful here, changed the source code from starlette.websockets while True:
assert websocket.application_state == WebSocketState.CONNECTED # careful here, changed the source code from starlette.websockets
message = await websocket.receive() assert websocket.application_state == WebSocketState.CONNECTED
websocket._raise_on_disconnect(message) message = await websocket.receive()
websocket._raise_on_disconnect(message)
message = json.loads(message["text"])
# get engine if 'signal' in message:
engine_pool = get_engine_pool() # start request
tts_engine = engine_pool['tts'] if message['signal'] == 'start':
session = uuid.uuid1().hex
resp = {
"status": 0,
"signal": "server ready",
"session": session
}
await websocket.send_json(resp)
# 获取 message 并转文本 # end request
message = json.loads(message["text"]) elif message['signal'] == 'end':
text_bese64 = message["text"] resp = {
sentence = tts_engine.preprocess(text_bese64=text_bese64) "status": 0,
"signal": "connection will be closed",
"session": session
}
await websocket.send_json(resp)
# run # speech synthesis request
wav_generator = tts_engine.run(sentence) elif 'text' in message:
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
while True: # run
try: wav_generator = tts_engine.run(sentence)
tts_results = next(wav_generator)
resp = {"status": 1, "audio": tts_results} while True:
await websocket.send_json(resp) try:
except StopIteration as e: tts_results = next(wav_generator)
resp = {"status": 2, "audio": ''} resp = {"status": 1, "audio": tts_results}
await websocket.send_json(resp) await websocket.send_json(resp)
logger.info("Complete the transmission of audio streams") except StopIteration as e:
break resp = {"status": 2, "audio": ''}
await websocket.send_json(resp)
logger.info(
"Complete the transmission of audio streams")
break
else:
logger.error(
"Invalid request, please check if the request is correct.")
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

Loading…
Cancel
Save