From f22abf04cc69b0cd2d9133cfcedf05e59cb445a6 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Tue, 2 Aug 2022 12:46:48 +0000 Subject: [PATCH] add tts server offline handle, test=doc --- .../server/bin/paddlespeech_client.py | 49 +++--------- paddlespeech/server/utils/audio_handler.py | 74 +++++++++++++++++++ 2 files changed, 86 insertions(+), 37 deletions(-) diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index f5dc368dd..208cb9607 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -13,26 +13,19 @@ # limitations under the License. import argparse import asyncio -import base64 -import io import json -import os -import random import sys import time import warnings from typing import List -import numpy as np import requests -import soundfile from ..executor import BaseExecutor from ..util import cli_client_register from ..util import stats_wrapper from paddlespeech.cli.log import logger from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler -from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.util import compute_delay from paddlespeech.server.utils.util import wav2base64 warnings.filterwarnings("ignore") @@ -81,23 +74,6 @@ class TTSClientExecutor(BaseExecutor): self.parser.add_argument( '--output', type=str, default=None, help='Synthesized audio file') - def postprocess(self, wav_base64: str, outfile: str) -> float: - audio_data_byte = base64.b64decode(wav_base64) - # from byte - samples, sample_rate = soundfile.read( - io.BytesIO(audio_data_byte), dtype='float32') - - # transform audio - if outfile.endswith(".wav"): - soundfile.write(outfile, samples, sample_rate) - elif outfile.endswith(".pcm"): - temp_wav = str(random.getrandbits(128)) + ".wav" - soundfile.write(temp_wav, samples, sample_rate) - wav2pcm(temp_wav, outfile, data_type=np.int16) - os.remove(temp_wav) - else: - logger.error("The 
format for saving audio only supports wav or pcm") - def execute(self, argv: List[str]) -> bool: args = self.parser.parse_args(argv) input_ = args.input @@ -147,20 +123,19 @@ class TTSClientExecutor(BaseExecutor): Python API to call an executor. """ - url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/tts' - request = { - "text": input, - "spk_id": spk_id, - "speed": speed, - "volume": volume, - "sample_rate": sample_rate, - "save_path": output - } + protocol = "http" + if protocol.lower() == "http": + from paddlespeech.server.utils.audio_handler import TTSHttpOfflineHandler + logger.info("asr http client start") + handler = TTSHttpOfflineHandler(server_ip=server_ip, port=port) + res = handler.run(input, spk_id, speed, volume, sample_rate, output) + logger.info("tts http client finished") + else: + logger.error( + f"Sorry, we have not support protocol: {protocol}, please use http" + ) + sys.exit(-1) - res = requests.post(url, json.dumps(request)) - response_dict = res.json() - if output is not None: - self.postprocess(response_dict["result"]["audio"], output) return res diff --git a/paddlespeech/server/utils/audio_handler.py b/paddlespeech/server/utils/audio_handler.py index 43b73d6eb..59ffe673d 100644 --- a/paddlespeech/server/utils/audio_handler.py +++ b/paddlespeech/server/utils/audio_handler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
class TTSHttpOfflineHandler:
    """Client-side handler for the offline TTS http service.

    Posts a synthesis request to the server and, when an output path is
    given, decodes the base64 audio found in the response and saves it
    locally as a ``.wav`` or ``.pcm`` file.
    """

    def __init__(self, server_ip=None, port=None, endpoint="/paddlespeech/tts"):
        """The TTS client http request

        Args:
            server_ip (str, optional): the http tts server ip. Defaults to None.
            port (int, optional): the http tts server port. Defaults to None.
            endpoint (str, optional): request path appended to the server
                address. Defaults to "/paddlespeech/tts".
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        if server_ip is None or port is None:
            # Without a full address no url can be built; run() needs one.
            self.url = None
        else:
            self.url = 'http://' + self.server_ip + ":" + str(
                self.port) + endpoint
            logger.info(f"endpoint: {self.url}")

    def postprocess(self, wav_base64: str, outfile: str) -> None:
        """Decode the base64 audio returned by the server and save it.

        Args:
            wav_base64 (str): base64 encoded audio taken from the response
            outfile (str): local path to write; must end with ".wav" or
                ".pcm", any other suffix is only reported through the logger
        """
        # Imported locally: the module-level import block of this file does
        # not provide os / soundfile, so relying on it raises NameError.
        import os
        import tempfile

        import soundfile

        if not outfile.endswith((".wav", ".pcm")):
            logger.error("The format for saving audio only supports wav or pcm")
            return

        audio_data_byte = base64.b64decode(wav_base64)
        # from byte
        samples, sample_rate = soundfile.read(
            io.BytesIO(audio_data_byte), dtype='float32')

        # transform audio
        if outfile.endswith(".wav"):
            soundfile.write(outfile, samples, sample_rate)
            return

        # pcm: wav2pcm works on files, so go through an intermediate wav.
        # The temporary directory is removed even when the conversion fails,
        # and nothing is dropped into the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_wav = os.path.join(tmp_dir, "tts_output.wav")
            soundfile.write(temp_wav, samples, sample_rate)
            wav2pcm(temp_wav, outfile, data_type=np.int16)

    def run(self,
            input: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            output: str=None):
        """Send one synthesis request and optionally save the audio.

        Args:
            input (str): sentence to be synthesized
            spk_id (int): speaker id
            speed (float): speed value forwarded to the server
                (accepted range is decided by the server, TODO confirm)
            volume (float): volume value forwarded to the server
                (accepted range is decided by the server, TODO confirm)
            sample_rate (int): sample rate
            output (str): sent to the server as "save_path"; when not None
                the returned audio is also written to this path locally

        Returns:
            res: response
        """
        request = {
            "text": input,
            "spk_id": spk_id,
            "speed": speed,
            "volume": volume,
            "sample_rate": sample_rate,
            "save_path": output
        }

        res = requests.post(url=self.url, data=json.dumps(request))
        response_dict = res.json()
        if output is not None:
            self.postprocess(response_dict["result"]["audio"], output)

        return res