Merge pull request #1882 from lym0302/streaming_tts_server

[server] improve code
pull/1887/head
liangym 3 years ago committed by GitHub
commit 22b67ed051
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,7 @@ import io
import json
import os
import random
import sys
import time
from typing import List
@ -91,7 +92,7 @@ class TTSClientExecutor(BaseExecutor):
temp_wav = str(random.getrandbits(128)) + ".wav"
soundfile.write(temp_wav, samples, sample_rate)
wav2pcm(temp_wav, outfile, data_type=np.int16)
os.system("rm %s" % (temp_wav))
os.remove(temp_wav)
else:
logger.error("The format for saving audio only supports wav or pcm")
@ -128,6 +129,7 @@ class TTSClientExecutor(BaseExecutor):
return True
except Exception as e:
logger.error("Failed to synthesized audio.")
logger.error(e)
return False
@stats_wrapper
@ -236,6 +238,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
return True
except Exception as e:
logger.error("Failed to synthesized audio.")
logger.error(e)
return False
@stats_wrapper
@ -275,7 +278,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
else:
logger.error("Please set correct protocol, http or websocket")
return False
sys.exit(-1)
logger.info(f"sentence: {input}")
logger.info(f"duration: {duration} s")
@ -503,6 +506,7 @@ class ASROnlineClientExecutor(BaseExecutor):
Returns:
str: the audio text
"""
logger.info("asr websocket client start")
handler = ASRWsAudioHandler(
server_ip,
@ -555,6 +559,7 @@ class CLSClientExecutor(BaseExecutor):
return True
except Exception as e:
logger.error("Failed to speech classification.")
logger.error(e)
return False
@stats_wrapper
@ -728,6 +733,7 @@ class VectorClientExecutor(BaseExecutor):
Returns:
str: the audio embedding or score between enroll and test audio
"""
if task == "spk":
from paddlespeech.server.utils.audio_handler import VectorHttpHandler
logger.info("vector http client start")

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
from typing import List
import uvicorn
@ -79,10 +80,12 @@ class ServerExecutor(BaseExecutor):
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
config = get_config(args.config_file)
if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True)
try:
self(args.config_file, args.log_file)
except Exception as e:
logger.error("Failed to start server.")
logger.error(e)
sys.exit(-1)
@stats_wrapper
def __call__(self,

@ -61,7 +61,7 @@ def tts_client(args):
temp_wav = str(random.getrandbits(128)) + ".wav"
soundfile.write(temp_wav, samples, sample_rate)
wav2pcm(temp_wav, outfile, data_type=np.int16)
os.system("rm %s" % (temp_wav))
os.remove(temp_wav)
else:
print("The format for saving audio only supports wav or pcm")

@ -304,24 +304,61 @@ class TTSWsHandler:
receive_time_list = []
chunk_duration_list = []
# 1. Send websocket handshake protocal
# 1. Send websocket handshake request
async with websockets.connect(self.url) as ws:
# 2. Server has already received handshake protocal
# send text to engine
# 2. Server has already received handshake response, send start request
start_request = json.dumps({"task": "tts", "signal": "start"})
await ws.send(start_request)
msg = await ws.recv()
logger.info(f"client receive msg={msg}")
msg = json.loads(msg)
session = msg["session"]
# 3. send speech synthesis request
text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8")
d = {"text": text_base64}
d = json.dumps(d)
request = json.dumps({"text": text_base64})
st = time.time()
await ws.send(d)
await ws.send(request)
logging.info("send a message to the server")
# 3. Process the received response
# 4. Process the received response
message = await ws.recv()
first_response = time.time() - st
message = json.loads(message)
status = message["status"]
while True:
# When throw an exception
if status == -1:
# send end request
end_request = json.dumps({
"task": "tts",
"signal": "end",
"session": session
})
await ws.send(end_request)
break
while (status == 1):
# Rerutn last packet normally, no audio information
elif status == 2:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
if output is not None:
save_audio_success = save_audio(all_bytes, output)
else:
save_audio_success = False
# send end request
end_request = json.dumps({
"task": "tts",
"signal": "end",
"session": session
})
await ws.send(end_request)
break
# Return the audio stream normally
elif status == 1:
receive_time_list.append(time.time())
audio = message["audio"]
audio = base64.b64decode(audio) # bytes
@ -339,17 +376,8 @@ class TTSWsHandler:
message = json.loads(message)
status = message["status"]
# 4. Last packet, no audio information
if status == 2:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
if output is not None:
save_audio_success = save_audio(all_bytes, output)
else:
save_audio_success = False
else:
logger.error("infer error")
logger.error("infer error, return status is invalid.")
if self.play:
self.t.join()
@ -458,6 +486,7 @@ class TTSHttpHandler:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
html.close() # when stream=True
if output is not None:
save_audio_success = save_audio(all_bytes, output)

@ -167,7 +167,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool:
channels=1,
bits=16,
sample_rate=sample_rate)
os.system("rm ./tmp.pcm")
os.remove("./tmp.pcm")
else:
print("Only supports saved audio format is pcm or wav")
return False

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import uuid
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.cli.log import logger
@ -26,20 +26,54 @@ router = APIRouter()
@router.websocket('/paddlespeech/tts/streaming')
async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online TTS Server api
Args:
websocket (WebSocket): the websocket instance
"""
#1. the interface wait to accept the websocket protocal header
# and only we receive the header, it establish the connection with specific thread
await websocket.accept()
#2. if we accept the websocket headers, we will get the online tts engine instance
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
try:
while True:
# careful here, changed the source code from starlette.websockets
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
message = json.loads(message["text"])
# get engine
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
if 'signal' in message:
# start request
if message['signal'] == 'start':
session = uuid.uuid1().hex
resp = {
"status": 0,
"signal": "server ready",
"session": session
}
await websocket.send_json(resp)
# 获取 message 并转文本
message = json.loads(message["text"])
# end request
elif message['signal'] == 'end':
resp = {
"status": 0,
"signal": "connection will be closed",
"session": session
}
await websocket.send_json(resp)
break
else:
resp = {"status": 0, "signal": "no valid json data"}
await websocket.send_json(resp)
# speech synthesis request
elif 'text' in message:
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
@ -54,8 +88,17 @@ async def websocket_endpoint(websocket: WebSocket):
except StopIteration as e:
resp = {"status": 2, "audio": ''}
await websocket.send_json(resp)
logger.info("Complete the transmission of audio streams")
logger.info(
"Complete the synthesis of the audio streams")
break
except Exception as e:
resp = {"status": -1, "audio": ''}
await websocket.send_json(resp)
break
except WebSocketDisconnect:
pass
else:
logger.error(
"Invalid request, please check if the request is correct.")
except Exception as e:
logger.error(e)

@ -1,6 +1,7 @@
#!/usr/bin/python
import argparse
import os
import shutil
import yaml
@ -14,7 +15,7 @@ def change_device(yamlfile: str, engine: str, device: str):
model_type (dict): change model type
"""
tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
os.system("cp %s %s" % (yamlfile, tmp_yamlfile))
shutil.copyfile(yamlfile, tmp_yamlfile)
if device == 'cpu':
set_device = 'cpu'
@ -41,7 +42,7 @@ def change_device(yamlfile: str, engine: str, device: str):
print(yaml.dump(y, default_flow_style=False, sort_keys=False))
yaml.dump(y, fw, allow_unicode=True)
os.system("rm %s" % (tmp_yamlfile))
os.remove(tmp_yamlfile)
print("Change %s successfully." % (yamlfile))
@ -52,7 +53,7 @@ def change_engine_type(yamlfile: str, engine_type):
task (str): asr or tts
"""
tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
os.system("cp %s %s" % (yamlfile, tmp_yamlfile))
shutil.copyfile(yamlfile, tmp_yamlfile)
speech_task = engine_type.split("_")[0]
with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw:
@ -65,7 +66,7 @@ def change_engine_type(yamlfile: str, engine_type):
y['engine_list'] = engine_list
print(yaml.dump(y, default_flow_style=False, sort_keys=False))
yaml.dump(y, fw, allow_unicode=True)
os.system("rm %s" % (tmp_yamlfile))
os.remove(tmp_yamlfile)
print("Change %s successfully." % (yamlfile))

@ -1,6 +1,7 @@
#!/usr/bin/python
import argparse
import os
import shutil
import yaml
@ -13,7 +14,7 @@ def change_value(args):
target_value = args.target_value
tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml"
os.system("cp %s %s" % (yamlfile, tmp_yamlfile))
shutil.copyfile(yamlfile, tmp_yamlfile)
with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw:
y = yaml.safe_load(f)
@ -51,7 +52,7 @@ def change_value(args):
print(yaml.dump(y, default_flow_style=False, sort_keys=False))
yaml.dump(y, fw, allow_unicode=True)
os.system("rm %s" % (tmp_yamlfile))
os.remove(tmp_yamlfile)
print(f"Change key: {target_key} to value: {target_value} successfully.")

@ -75,8 +75,8 @@ if __name__ == "__main__":
args = parser.parse_args()
os.system("rm -rf %s" % (args.output_dir))
os.mkdir(args.output_dir)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
first_response_list = []
final_response_list = []

Loading…
Cancel
Save