diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index 19bdc10b..3adf8015 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -18,6 +18,7 @@ import io import json import os import random +import sys import time from typing import List @@ -91,7 +92,7 @@ class TTSClientExecutor(BaseExecutor): temp_wav = str(random.getrandbits(128)) + ".wav" soundfile.write(temp_wav, samples, sample_rate) wav2pcm(temp_wav, outfile, data_type=np.int16) - os.system("rm %s" % (temp_wav)) + os.remove(temp_wav) else: logger.error("The format for saving audio only supports wav or pcm") @@ -128,6 +129,7 @@ class TTSClientExecutor(BaseExecutor): return True except Exception as e: logger.error("Failed to synthesized audio.") + logger.error(e) return False @stats_wrapper @@ -236,6 +238,7 @@ class TTSOnlineClientExecutor(BaseExecutor): return True except Exception as e: logger.error("Failed to synthesized audio.") + logger.error(e) return False @stats_wrapper @@ -275,7 +278,7 @@ class TTSOnlineClientExecutor(BaseExecutor): else: logger.error("Please set correct protocol, http or websocket") - return False + sys.exit(-1) logger.info(f"sentence: {input}") logger.info(f"duration: {duration} s") @@ -503,6 +506,7 @@ class ASROnlineClientExecutor(BaseExecutor): Returns: str: the audio text """ + logger.info("asr websocket client start") handler = ASRWsAudioHandler( server_ip, @@ -555,6 +559,7 @@ class CLSClientExecutor(BaseExecutor): return True except Exception as e: logger.error("Failed to speech classification.") + logger.error(e) return False @stats_wrapper @@ -728,6 +733,7 @@ class VectorClientExecutor(BaseExecutor): Returns: str: the audio embedding or score between enroll and test audio """ + if task == "spk": from paddlespeech.server.utils.audio_handler import VectorHttpHandler logger.info("vector http client start") diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index 9e3b0ed0..db92f179 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import sys from typing import List import uvicorn @@ -79,10 +80,12 @@ class ServerExecutor(BaseExecutor): def execute(self, argv: List[str]) -> bool: args = self.parser.parse_args(argv) - config = get_config(args.config_file) - - if self.init(config): - uvicorn.run(app, host=config.host, port=config.port, debug=True) + try: + self(args.config_file, args.log_file) + except Exception as e: + logger.error("Failed to start server.") + logger.error(e) + sys.exit(-1) @stats_wrapper def __call__(self, diff --git a/paddlespeech/server/tests/tts/offline/http_client.py b/paddlespeech/server/tests/tts/offline/http_client.py index 1bdee4c1..24109a0e 100644 --- a/paddlespeech/server/tests/tts/offline/http_client.py +++ b/paddlespeech/server/tests/tts/offline/http_client.py @@ -61,7 +61,7 @@ def tts_client(args): temp_wav = str(random.getrandbits(128)) + ".wav" soundfile.write(temp_wav, samples, sample_rate) wav2pcm(temp_wav, outfile, data_type=np.int16) - os.system("rm %s" % (temp_wav)) + os.remove(temp_wav) else: print("The format for saving audio only supports wav or pcm") diff --git a/paddlespeech/server/utils/audio_handler.py b/paddlespeech/server/utils/audio_handler.py index 2bce28e3..b85cf485 100644 --- a/paddlespeech/server/utils/audio_handler.py +++ b/paddlespeech/server/utils/audio_handler.py @@ -304,52 +304,80 @@ class TTSWsHandler: receive_time_list = [] chunk_duration_list = [] - # 1. Send websocket handshake protocal + # 1. Send websocket handshake request async with websockets.connect(self.url) as ws: - # 2. Server has already received handshake protocal - # send text to engine + # 2. Server has already received handshake response, send start request + start_request = json.dumps({"task": "tts", "signal": "start"}) + await ws.send(start_request) + msg = await ws.recv() + logger.info(f"client receive msg={msg}") + msg = json.loads(msg) + session = msg["session"] + + # 3. send speech synthesis request text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8") - d = {"text": text_base64} - d = json.dumps(d) + request = json.dumps({"text": text_base64}) st = time.time() - await ws.send(d) + await ws.send(request) logging.info("send a message to the server") - # 3. Process the received response + # 4. Process the received response message = await ws.recv() first_response = time.time() - st message = json.loads(message) status = message["status"] + while True: + # When throw an exception + if status == -1: + # send end request + end_request = json.dumps({ + "task": "tts", + "signal": "end", + "session": session + }) + await ws.send(end_request) + break + + # Rerutn last packet normally, no audio information + elif status == 2: + final_response = time.time() - st + duration = len(all_bytes) / 2.0 / 24000 + + if output is not None: + save_audio_success = save_audio(all_bytes, output) + else: + save_audio_success = False + + # send end request + end_request = json.dumps({ + "task": "tts", + "signal": "end", + "session": session + }) + await ws.send(end_request) + break + + # Return the audio stream normally + elif status == 1: + receive_time_list.append(time.time()) + audio = message["audio"] + audio = base64.b64decode(audio) # bytes + chunk_duration_list.append(len(audio) / 2.0 / 24000) + all_bytes += audio + if self.play: + self.mutex.acquire() + self.buffer += audio + self.mutex.release() + if self.start_play: + self.t.start() + self.start_play = False + + message = await ws.recv() + message = json.loads(message) + status = message["status"] - while (status == 1): - receive_time_list.append(time.time()) - audio = message["audio"] - audio = base64.b64decode(audio) # bytes - chunk_duration_list.append(len(audio) / 2.0 / 24000) - all_bytes += audio - if self.play: - self.mutex.acquire() - self.buffer += audio - self.mutex.release() - if self.start_play: - self.t.start() - self.start_play = False - - message = await ws.recv() - message = json.loads(message) - status = message["status"] - - # 4. Last packet, no audio information - if status == 2: - final_response = time.time() - st - duration = len(all_bytes) / 2.0 / 24000 - - if output is not None: - save_audio_success = save_audio(all_bytes, output) else: - save_audio_success = False - else: - logger.error("infer error") + logger.error("infer error, return status is invalid.") if self.play: self.t.join() @@ -458,6 +486,7 @@ class TTSHttpHandler: final_response = time.time() - st duration = len(all_bytes) / 2.0 / 24000 + html.close() # when stream=True if output is not None: save_audio_success = save_audio(all_bytes, output) diff --git a/paddlespeech/server/utils/audio_process.py b/paddlespeech/server/utils/audio_process.py index bb02d664..416d77ac 100644 --- a/paddlespeech/server/utils/audio_process.py +++ b/paddlespeech/server/utils/audio_process.py @@ -167,7 +167,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool: channels=1, bits=16, sample_rate=sample_rate) - os.system("rm ./tmp.pcm") + os.remove("./tmp.pcm") else: print("Only supports saved audio format is pcm or wav") return False diff --git a/paddlespeech/server/ws/tts_api.py b/paddlespeech/server/ws/tts_api.py index 20a63d4c..a3a4c4d4 100644 --- a/paddlespeech/server/ws/tts_api.py +++ b/paddlespeech/server/ws/tts_api.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import uuid from fastapi import APIRouter from fastapi import WebSocket -from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.cli.log import logger @@ -26,36 +26,79 @@ router = APIRouter() @router.websocket('/paddlespeech/tts/streaming') async def websocket_endpoint(websocket: WebSocket): + """PaddleSpeech Online TTS Server api + + Args: + websocket (WebSocket): the websocket instance + """ + + #1. the interface wait to accept the websocket protocal header + # and only we receive the header, it establish the connection with specific thread await websocket.accept() + #2. if we accept the websocket headers, we will get the online tts engine instance + engine_pool = get_engine_pool() + tts_engine = engine_pool['tts'] + try: - # careful here, changed the source code from starlette.websockets - assert websocket.application_state == WebSocketState.CONNECTED - message = await websocket.receive() - websocket._raise_on_disconnect(message) + while True: + # careful here, changed the source code from starlette.websockets + assert websocket.application_state == WebSocketState.CONNECTED + message = await websocket.receive() + websocket._raise_on_disconnect(message) + message = json.loads(message["text"]) - # get engine - engine_pool = get_engine_pool() - tts_engine = engine_pool['tts'] + if 'signal' in message: + # start request + if message['signal'] == 'start': + session = uuid.uuid1().hex + resp = { + "status": 0, + "signal": "server ready", + "session": session + } + await websocket.send_json(resp) - # 获取 message 并转文本 - message = json.loads(message["text"]) - text_bese64 = message["text"] - sentence = tts_engine.preprocess(text_bese64=text_bese64) + # end request + elif message['signal'] == 'end': + resp = { + "status": 0, + "signal": "connection will be closed", + "session": session + } + await websocket.send_json(resp) + break + else: + resp = {"status": 0, "signal": "no valid json data"} + await websocket.send_json(resp) - # run - wav_generator = tts_engine.run(sentence) + # speech synthesis request + elif 'text' in message: + text_bese64 = message["text"] + sentence = tts_engine.preprocess(text_bese64=text_bese64) - while True: - try: - tts_results = next(wav_generator) - resp = {"status": 1, "audio": tts_results} - await websocket.send_json(resp) - except StopIteration as e: - resp = {"status": 2, "audio": ''} - await websocket.send_json(resp) - logger.info("Complete the transmission of audio streams") - break - - except WebSocketDisconnect: - pass \ No newline at end of file + # run + wav_generator = tts_engine.run(sentence) + + while True: + try: + tts_results = next(wav_generator) + resp = {"status": 1, "audio": tts_results} + await websocket.send_json(resp) + except StopIteration as e: + resp = {"status": 2, "audio": ''} + await websocket.send_json(resp) + logger.info( + "Complete the synthesis of the audio streams") + break + except Exception as e: + resp = {"status": -1, "audio": ''} + await websocket.send_json(resp) + break + + else: + logger.error( + "Invalid request, please check if the request is correct.") + + except Exception as e: + logger.error(e) diff --git a/tests/unit/server/offline/change_yaml.py b/tests/unit/server/offline/change_yaml.py index d51a6259..ded7e3b4 100644 --- a/tests/unit/server/offline/change_yaml.py +++ b/tests/unit/server/offline/change_yaml.py @@ -1,6 +1,7 @@ #!/usr/bin/python import argparse import os +import shutil import yaml @@ -14,7 +15,7 @@ def change_device(yamlfile: str, engine: str, device: str): model_type (dict): change model type """ tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml" - os.system("cp %s %s" % (yamlfile, tmp_yamlfile)) + shutil.copyfile(yamlfile, tmp_yamlfile) if device == 'cpu': set_device = 'cpu' @@ -41,7 +42,7 @@ def change_device(yamlfile: str, engine: str, device: str): print(yaml.dump(y, default_flow_style=False, sort_keys=False)) yaml.dump(y, fw, allow_unicode=True) - os.system("rm %s" % (tmp_yamlfile)) + os.remove(tmp_yamlfile) print("Change %s successfully." % (yamlfile)) @@ -52,7 +53,7 @@ def change_engine_type(yamlfile: str, engine_type): task (str): asr or tts """ tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml" - os.system("cp %s %s" % (yamlfile, tmp_yamlfile)) + shutil.copyfile(yamlfile, tmp_yamlfile) speech_task = engine_type.split("_")[0] with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw: @@ -65,7 +66,7 @@ def change_engine_type(yamlfile: str, engine_type): y['engine_list'] = engine_list print(yaml.dump(y, default_flow_style=False, sort_keys=False)) yaml.dump(y, fw, allow_unicode=True) - os.system("rm %s" % (tmp_yamlfile)) + os.remove(tmp_yamlfile) print("Change %s successfully." % (yamlfile)) diff --git a/tests/unit/server/online/tts/check_server/change_yaml.py b/tests/unit/server/online/tts/check_server/change_yaml.py index 01351df0..b04ad0a8 100644 --- a/tests/unit/server/online/tts/check_server/change_yaml.py +++ b/tests/unit/server/online/tts/check_server/change_yaml.py @@ -1,6 +1,7 @@ #!/usr/bin/python import argparse import os +import shutil import yaml @@ -13,7 +14,7 @@ def change_value(args): target_value = args.target_value tmp_yamlfile = yamlfile.split(".yaml")[0] + "_tmp.yaml" - os.system("cp %s %s" % (yamlfile, tmp_yamlfile)) + shutil.copyfile(yamlfile, tmp_yamlfile) with open(tmp_yamlfile) as f, open(yamlfile, "w+", encoding="utf-8") as fw: y = yaml.safe_load(f) @@ -51,7 +52,7 @@ def change_value(args): print(yaml.dump(y, default_flow_style=False, sort_keys=False)) yaml.dump(y, fw, allow_unicode=True) - os.system("rm %s" % (tmp_yamlfile)) + os.remove(tmp_yamlfile) print(f"Change key: {target_key} to value: {target_value} successfully.") diff --git a/tests/unit/server/online/tts/test_server/test_http_client.py b/tests/unit/server/online/tts/test_server/test_http_client.py index 7fdb4e00..3174e85e 100644 --- a/tests/unit/server/online/tts/test_server/test_http_client.py +++ b/tests/unit/server/online/tts/test_server/test_http_client.py @@ -75,8 +75,8 @@ if __name__ == "__main__": args = parser.parse_args() - os.system("rm -rf %s" % (args.output_dir)) - os.mkdir(args.output_dir) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) first_response_list = [] final_response_list = []