From 3b75f63fa9f8893ae6d0bf0bfabeccef3ebba972 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Wed, 26 Jan 2022 16:50:22 +0800 Subject: [PATCH 1/2] add error code, test=server --- .../engine/tts/python/tts_engine.py | 35 ++++++++---- .../speechserving/restful/response.py | 2 +- .../speechserving/restful/tts_api.py | 32 ++++++++--- speechserving/speechserving/utils/errors.py | 57 +++++++++++++++++++ .../speechserving/utils/exception.py | 30 ++++++++++ 5 files changed, 137 insertions(+), 19 deletions(-) create mode 100644 speechserving/speechserving/utils/exception.py diff --git a/speechserving/speechserving/engine/tts/python/tts_engine.py b/speechserving/speechserving/engine/tts/python/tts_engine.py index d790aa31..4f0e9906 100644 --- a/speechserving/speechserving/engine/tts/python/tts_engine.py +++ b/speechserving/speechserving/engine/tts/python/tts_engine.py @@ -22,6 +22,8 @@ from engine.base_engine import BaseEngine from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from utils.errors import ErrorCode +from utils.exception import ServerBaseException __all__ = ['TTSEngine'] @@ -128,16 +130,27 @@ class TTSEngine(BaseEngine): lang = self.conf_dict["lang"] - self.executor.infer( - text=sentence, lang=lang, am=self.conf_dict["am"], spk_id=spk_id) - - target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), - original_fs=self.executor.am_config.fs, - target_fs=sample_rate, - volume=volume, - speed=speed, - audio_path=save_path, - audio_format=audio_format) + try: + self.executor.infer( + text=sentence, + lang=lang, + am=self.conf_dict["am"], + spk_id=spk_id) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts infer failed.") + + try: + target_sample_rate, wav_base64 = self.postprocess( + wav=self.executor._outputs['wav'].numpy(), + original_fs=self.executor.am_config.fs, + target_fs=sample_rate, + volume=volume, + speed=speed, + audio_path=save_path, + audio_format=audio_format) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts postprocess failed.") return lang, target_sample_rate, wav_base64 diff --git a/speechserving/speechserving/restful/response.py b/speechserving/speechserving/restful/response.py index 684a37f9..db24f531 100644 --- a/speechserving/speechserving/restful/response.py +++ b/speechserving/speechserving/restful/response.py @@ -68,7 +68,7 @@ class TTSResponse(BaseModel): response example { "success": true, - "code": 0, + "code": 200, "message": { "description": "success" }, diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index c78eaf63..69930f24 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from engine.tts.python.tts_engine import TTSEngine from fastapi import APIRouter from .request import TTSRequest from .response import TTSResponse +from utils.errors import ErrorCode +from utils.errors import ErrorMsg +from utils.errors import failed_response +from utils.exception import ServerBaseException router = APIRouter() @@ -28,7 +34,7 @@ def help(): json: [description] """ json_body = { - "success": true, + "success": "True", "code": 0, "message": { "global": "success" @@ -62,19 +68,31 @@ def tts(request_body: TTSRequest): save_path = item_dict['save_path'] audio_format = item_dict['audio_format'] + # Check parameters + if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ + sample_rate not in [0, 16000, 8000] or \ + audio_format not in ["pcm", "wav"]: + return failed_response(ErrorCode.SERVER_PARAM_ERR) + # single tts_engine = TTSEngine() - #tts_engine.init() - lang, target_sample_rate, wav_base64 = tts_engine.run( - sentence, spk_id, speed, volume, sample_rate, save_path, audio_format) - #tts_engine.postprocess() + # run + try: + lang, target_sample_rate, wav_base64 = tts_engine.run( + sentence, spk_id, speed, volume, sample_rate, save_path, + audio_format) + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() json_body = { "success": True, - "code": 0, + "code": 200, "message": { - "description": "success" + "description": "success." }, "result": { "lang": lang, diff --git a/speechserving/speechserving/utils/errors.py b/speechserving/speechserving/utils/errors.py index e69de29b..aa858cb0 100644 --- a/speechserving/speechserving/utils/errors.py +++ b/speechserving/speechserving/utils/errors.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from enum import IntEnum + +from fastapi import Response + + +class ErrorCode(IntEnum): + SERVER_OK = 200 # success. + + SERVER_PARAM_ERR = 400 # Input parameters are not valid. + SERVER_TASK_NOT_EXIST = 404 # Task is not exist. + + SERVER_INTERNAL_ERR = 500 # Internal error. + SERVER_NETWORK_ERR = 502 # Network exception. + SERVER_UNKOWN_ERR = 509 # Unknown error occurred. + + +ErrorMsg = { + ErrorCode.SERVER_OK: "success.", + ErrorCode.SERVER_PARAM_ERR: "Input parameters are not valid.", + ErrorCode.SERVER_TASK_NOT_EXIST: "Task is not exist.", + ErrorCode.SERVER_INTERNAL_ERR: "Internal error.", + ErrorCode.SERVER_NETWORK_ERR: "Network exception.", + ErrorCode.SERVER_UNKOWN_ERR: "Unknown error occurred." +} + + +def failed_response(code, msg=""): + """Interface call failure response + + Args: + code (int): error code number + msg (str, optional): Interface call failure information. Defaults to "". + + Returns: + Response (json): failure json information. + """ + + if not msg: + msg = ErrorMsg.get(code, "Unknown error occurred.") + + res = {"success": False, "code": int(code), "message": {"global": msg}} + + return Response(content=json.dumps(res), media_type="application/json") diff --git a/speechserving/speechserving/utils/exception.py b/speechserving/speechserving/utils/exception.py new file mode 100644 index 00000000..03a6deee --- /dev/null +++ b/speechserving/speechserving/utils/exception.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import traceback + +from utils.errors import ErrorMsg + + +class ServerBaseException(Exception): + """ Server Base exception + """ + + def __init__(self, error_code, msg=None): + #if msg: + #log.error(msg) + msg = msg if msg else ErrorMsg.get(error_code, "") + super(ServerBaseException, self).__init__(error_code, msg) + self.error_code = error_code + self.msg = msg + traceback.print_exc() From c8a98ecfffd6bcd14a246010928f0bbe29492ba3 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Fri, 28 Jan 2022 14:21:17 +0800 Subject: [PATCH 2/2] add postproces, test=doc --- .../engine/tts/python/tts_engine.py | 41 ++++++++++------ .../speechserving/restful/request.py | 4 +- .../speechserving/restful/tts_api.py | 49 +++++++++---------- .../speechserving/utils/audio_types.py | 40 +++++++++++++++ 4 files changed, 89 insertions(+), 45 deletions(-) create mode 100644 speechserving/speechserving/utils/audio_types.py diff --git a/speechserving/speechserving/engine/tts/python/tts_engine.py b/speechserving/speechserving/engine/tts/python/tts_engine.py index 4f0e9906..65e35fb8 100644 --- a/speechserving/speechserving/engine/tts/python/tts_engine.py +++ b/speechserving/speechserving/engine/tts/python/tts_engine.py @@ -13,15 +13,19 @@ # limitations under the License. import argparse import base64 +import os +import random import librosa import numpy as np import soundfile as sf import yaml from engine.base_engine import BaseEngine +from ffmpeg import audio from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from utils.audio_types import wav2pcm from utils.errors import ErrorCode from utils.exception import ServerBaseException @@ -80,8 +84,7 @@ class TTSEngine(BaseEngine): target_fs: int=16000, volume: float=1.0, speed: float=1.0, - audio_path: str=None, - audio_format: str="wav"): + audio_path: str=None): """Post-processing operations, including speech, volume, sample rate, save audio file Args: @@ -104,18 +107,26 @@ class TTSEngine(BaseEngine): wav_vol = wav_tar_fs * volume # transform speed - # TODO - target_wav = wav_vol.reshape(-1, 1) - - # save audio - if audio_path is not None: - sf.write(audio_path, target_wav, target_fs) - logger.info('Wave file has been generated: {}'.format(audio_path)) + hash = random.getrandbits(128) + temp_wav = str(hash) + ".wav" + temp_speed_wav = str(hash + 1) + ".wav" + sf.write(temp_wav, wav_vol.reshape(-1, 1), target_fs) + audio.a_speed(temp_wav, speed, temp_speed_wav) + os.system("rm %s" % (temp_wav)) # wav to base64 - base64_bytes = base64.b64encode(target_wav) - base64_string = base64_bytes.decode('utf-8') - wav_base64 = base64_string + with open(temp_speed_wav, 'rb') as f: + base64_bytes = base64.b64encode(f.read()) + wav_base64 = base64_bytes.decode('utf-8') + + # save audio + if audio_path is not None and audio_path.endswith(".wav"): + os.system("mv %s %s" % (temp_speed_wav, audio_path)) + elif audio_path is not None and audio_path.endswith(".pcm"): + wav2pcm(temp_speed_wav, audio_path, data_type=np.int16) + os.system("rm %s" % (temp_speed_wav)) + else: + os.system("rm %s" % (temp_speed_wav)) return target_fs, wav_base64 @@ -125,8 +136,7 @@ class TTSEngine(BaseEngine): speed: float=1.0, volume: float=1.0, sample_rate: int=0, - save_path: str=None, - audio_format: str="wav"): + save_path: str=None): lang = self.conf_dict["lang"] @@ -147,8 +157,7 @@ class TTSEngine(BaseEngine): target_fs=sample_rate, volume=volume, speed=speed, - audio_path=save_path, - audio_format=audio_format) + audio_path=save_path) except: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") diff --git a/speechserving/speechserving/restful/request.py b/speechserving/speechserving/restful/request.py index 32f157d3..625fe235 100644 --- a/speechserving/speechserving/restful/request.py +++ b/speechserving/speechserving/restful/request.py @@ -53,8 +53,7 @@ class TTSRequest(BaseModel): "speed": 1.0, "volume": 1.0, "sample_rate": 0, - "tts_audio_path": "./tts.wav", - "audio_format": "wav" + "tts_audio_path": "./tts.wav" } """ @@ -65,4 +64,3 @@ class TTSRequest(BaseModel): volume: float = 1.0 sample_rate: int = 0 save_path: str = None - audio_format: str = "wav" diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index 69930f24..e9dcfa16 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -19,7 +19,6 @@ from fastapi import APIRouter from .request import TTSRequest from .response import TTSResponse from utils.errors import ErrorCode -from utils.errors import ErrorMsg from utils.errors import failed_response from utils.exception import ServerBaseException @@ -33,9 +32,9 @@ def help(): Returns: json: [description] """ - json_body = { + response = { "success": "True", - "code": 0, + "code": 200, "message": { "global": "success" }, @@ -45,7 +44,7 @@ def help(): "audio": "the base64 of audio" } } - return json_body + return response @router.post("/paddlespeech/tts", response_model=TTSResponse) @@ -66,12 +65,11 @@ def tts(request_body: TTSRequest): volume = item_dict['volume'] sample_rate = item_dict['sample_rate'] save_path = item_dict['save_path'] - audio_format = item_dict['audio_format'] # Check parameters if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ sample_rate not in [0, 16000, 8000] or \ - audio_format not in ["pcm", "wav"]: + (save_path is not None and save_path.endswith("pcm") == False and save_path.endswith("wav") == False): return failed_response(ErrorCode.SERVER_PARAM_ERR) # single @@ -80,29 +78,28 @@ def tts(request_body: TTSRequest): # run try: lang, target_sample_rate, wav_base64 = tts_engine.run( - sentence, spk_id, speed, volume, sample_rate, save_path, - audio_format) + sentence, spk_id, speed, volume, sample_rate, save_path) + + response = { + "success": True, + "code": 200, + "message": { + "description": "success." + }, + "result": { + "lang": lang, + "spk_id": spk_id, + "speed": speed, + "volume": volume, + "sample_rate": target_sample_rate, + "save_path": save_path, + "audio": wav_base64 + } + } except ServerBaseException as e: response = failed_response(e.error_code, e.msg) except: response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) traceback.print_exc() - json_body = { - "success": True, - "code": 200, - "message": { - "description": "success." - }, - "result": { - "lang": lang, - "spk_id": spk_id, - "speed": speed, - "volume": volume, - "sample_rate": target_sample_rate, - "save_path": save_path, - "audio": wav_base64 - } - } - - return json_body + return response diff --git a/speechserving/speechserving/utils/audio_types.py b/speechserving/speechserving/utils/audio_types.py new file mode 100644 index 00000000..eb655ddd --- /dev/null +++ b/speechserving/speechserving/utils/audio_types.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import wave + +import numpy as np + + +def wav2pcm(wavfile, pcmfile, data_type=np.int16): + f = open(wavfile, "rb") + f.seek(0) + f.read(44) + data = np.fromfile(f, dtype=data_type) + data.tofile(pcmfile) + + +def pcm2wav(pcm_file, wav_file, channels=1, bits=16, sample_rate=16000): + pcmf = open(pcm_file, 'rb') + pcmdata = pcmf.read() + pcmf.close() + + if bits % 8 != 0: + raise ValueError("bits % 8 must == 0. now bits:" + str(bits)) + + wavfile = wave.open(wav_file, 'wb') + wavfile.setnchannels(channels) + wavfile.setsampwidth(bits // 8) + wavfile.setframerate(sample_rate) + wavfile.writeframes(pcmdata) + wavfile.close()