diff --git a/paddlespeech/server/__init__.py b/paddlespeech/server/__init__.py index 97043fd7..bbe1b34e 100644 --- a/paddlespeech/server/__init__.py +++ b/paddlespeech/server/__init__.py @@ -11,3 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import _locale + +from .base_commands import BaseCommand +from .base_commands import HelpCommand +from paddlespeech.server.bin.paddlespeech_client import TTSClientExecutor +from paddlespeech.server.bin.paddlespeech_server import ServerExecutor + +_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) diff --git a/paddlespeech/server/base_commands.py b/paddlespeech/server/base_commands.py new file mode 100644 index 00000000..b32fe99d --- /dev/null +++ b/paddlespeech/server/base_commands.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List + +from .entry import commands +from .util import cli_register +from .util import get_command + +__all__ = [ + 'BaseCommand', + 'HelpCommand', +] + + +@cli_register(name='paddleserver') +class BaseCommand: + def execute(self, argv: List[str]) -> bool: + help = get_command('paddleserver.help') + return help().execute(argv) + + +@cli_register(name='paddleserver.help', description='Show help for commands.') +class HelpCommand: + def execute(self, argv: List[str]) -> bool: + msg = 'Usage:\n' + msg += ' paddleserver \n\n' + msg += 'Commands:\n' + for command, detail in commands['paddleserver'].items(): + if command.startswith('_'): + continue + + if '_description' not in detail: + continue + msg += ' {:<15} {}\n'.format(command, + detail['_description']) + + print(msg) + return True diff --git a/paddlespeech/server/bin/__init__.py b/paddlespeech/server/bin/__init__.py new file mode 100644 index 00000000..3d37168b --- /dev/null +++ b/paddlespeech/server/bin/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .paddlespeech_server import ServerExecutor diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py new file mode 100644 index 00000000..f05d5216 --- /dev/null +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
@cli_register(name='paddleserver.ttsclient', description='visit tts service')
class TTSClientExecutor():
    """Command-line client that asks the TTS service to synthesize a sentence.

    The executor builds its own ``argparse`` parser, posts the request to
    ``http://<server_ip>:<port>/paddlespeech/tts`` and stores the returned
    audio as a ``.wav`` or ``.pcm`` file.
    """

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--text',
            type=str,
            default="你好,欢迎使用语音合成服务",
            help='A sentence to be synthesized')
        self.parser.add_argument(
            '--spk_id', type=int, default=0, help='Speaker id')
        self.parser.add_argument(
            '--speed', type=float, default=1.0, help='Audio speed')
        self.parser.add_argument(
            '--volume', type=float, default=1.0, help='Audio volume')
        self.parser.add_argument(
            '--sample_rate',
            type=int,
            default=0,
            help='Sampling rate, the default is the same as the model')
        self.parser.add_argument(
            '--output',
            type=str,
            default="./out.wav",
            help='Synthesized audio file')

    def tts_client(self, args):
        """Send one synthesis request and save the returned audio.

        Args:
            args: Parsed command-line namespace; reads ``server_ip``,
                ``port``, ``text``, ``spk_id``, ``speed``, ``volume``,
                ``sample_rate`` and ``output``.

        Returns:
            tuple: ``(number_of_samples, sample_rate)`` of the decoded audio.
        """
        url = 'http://' + args.server_ip + ":" + str(
            args.port) + '/paddlespeech/tts'
        request = {
            "text": args.text,
            "spk_id": args.spk_id,
            "speed": args.speed,
            "volume": args.volume,
            "sample_rate": args.sample_rate,
            "save_path": args.output
        }

        response = requests.post(url, json.dumps(request))
        response_dict = response.json()
        wav_base64 = response_dict["result"]["audio"]

        # The service returns the wav file base64-encoded; decode it in memory.
        audio_data_byte = base64.b64decode(wav_base64)
        samples, sample_rate = soundfile.read(
            io.BytesIO(audio_data_byte), dtype='float32')

        outfile = args.output
        if outfile.endswith(".wav"):
            soundfile.write(outfile, samples, sample_rate)
        elif outfile.endswith(".pcm"):
            # wav2pcm works on files, so go through a temporary wav and make
            # sure it is removed even when the conversion fails.
            temp_wav = str(random.getrandbits(128)) + ".wav"
            try:
                soundfile.write(temp_wav, samples, sample_rate)
                wav2pcm(temp_wav, outfile, data_type=np.int16)
            finally:
                if os.path.exists(temp_wav):
                    os.remove(temp_wav)
        else:
            print("The format for saving audio only supports wav or pcm")

        return len(samples), sample_rate

    def execute(self, argv: List[str]) -> bool:
        """Parse ``argv``, run the request and report the outcome.

        Args:
            argv: Command-line arguments following the sub-command name.

        Returns:
            bool: ``True`` when the audio was synthesized and saved,
            ``False`` otherwise. The command dispatcher turns this into the
            process exit status, so a value must always be returned.
        """
        args = self.parser.parse_args(argv)
        st = time.time()
        try:
            self.tts_client(args)
        except Exception as e:
            # Best-effort CLI: report the failure instead of a traceback, but
            # let KeyboardInterrupt/SystemExit propagate.
            print("Failed to synthesized audio.")
            print(e)
            return False

        time_consume = time.time() - st
        print("Save synthesized audio successfully on %s." % (args.output))
        print("Inference time: %f" % (time_consume))
        return True
@cli_register(name='paddleserver.server', description='Start the service')
class ServerExecutor():
    """Command that initializes the configured engines and serves the API."""

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            "--config_file",
            action="store",
            help="yaml file of the app",
            default="./conf/application.yaml")

        self.parser.add_argument(
            "--log_file",
            action="store",
            help="log file",
            default="./log/paddlespeech.log")

    def init(self, config) -> bool:
        """Register the API routes and initialize every configured engine.

        Args:
            config (CfgNode): config object; ``config.engine_backend`` is
                iterated for engine names and indexed for each engine's
                config file.

        Returns:
            bool: ``True`` when every engine initialized, ``False`` as soon
            as one engine's ``init`` reports failure.
        """
        # init api
        api_list = list(config.engine_backend)
        api_router = setup_router(api_list)
        app.include_router(api_router)

        # init engine
        engine_pool = []
        for engine_name in config.engine_backend:
            engine = EngineFactory.get_engine(engine_name=engine_name)
            engine_pool.append(engine)
            if not engine.init(config_file=config.engine_backend[engine_name]):
                return False

        return True

    def execute(self, argv: List[str]) -> bool:
        """Start the service described by ``--config_file``.

        Args:
            argv: Command-line arguments following the sub-command name.

        Returns:
            bool: ``False`` when engine initialization failed; ``True`` once
            the server has run and returned. The command dispatcher maps
            this to the process exit status, so a value is always returned.
        """
        args = self.parser.parse_args(argv)
        config = get_config(args.config_file)

        if not self.init(config):
            return False

        uvicorn.run(app, host=config.host, port=config.port, debug=True)
        return True
+ """ + return path.startswith('http://') or path.startswith('https://') + + +def _map_path(url, root_dir): + # parse path after download under root_dir + fname = osp.split(url)[-1] + fpath = fname + return osp.join(root_dir, fpath) + + +def _get_unique_endpoints(trainer_endpoints): + # Sorting is to avoid different environmental variables for each card + trainer_endpoints.sort() + ips = set() + unique_endpoints = set() + for endpoint in trainer_endpoints: + ip = endpoint.split(":")[0] + if ip in ips: + continue + ips.add(ip) + unique_endpoints.add(endpoint) + logger.info("unique_endpoints {}".format(unique_endpoints)) + return unique_endpoints + + +def get_path_from_url(url, + root_dir, + md5sum=None, + check_exist=True, + decompress=True, + method='get'): + """ Download from given url to root_dir. + if file or directory specified by url is exists under + root_dir, return the path directly, otherwise download + from url and decompress it, return the path. + Args: + url (str): download url + root_dir (str): root dir for downloading, it should be + WEIGHTS_HOME or DATASET_HOME + md5sum (str): md5 sum of download package + decompress (bool): decompress zip or tar file. Default is `True` + method (str): which download method to use. Support `wget` and `get`. Default is `get`. + Returns: + str: a local path to save downloaded models & weights & datasets. + """ + + from paddle.fluid.dygraph.parallel import ParallelEnv + + assert _is_url(url), "downloading from {} not a url".format(url) + # parse path after download to decompress under root_dir + fullpath = _map_path(url, root_dir) + # Mainly used to solve the problem of downloading data from different + # machines in the case of multiple machines. Different ips will download + # data, and the same ip will only download data once. 
def _wget_download(url, fullname):
    """Download ``url`` to ``fullname`` with the external ``wget`` program.

    The data is first written to ``fullname + "_tmp"`` and moved into place
    only after ``wget`` succeeded, so an interrupted download never leaves a
    partial file under the final name.

    Args:
        url (str): download url
        fullname (str): destination file path

    Returns:
        str: ``fullname``.

    Raises:
        RuntimeError: if ``wget`` cannot be started or exits with a
            non-zero status.
    """
    tmp_fullname = fullname + "_tmp"
    # Pass an argument list and avoid the shell: URLs routinely contain
    # characters such as '&', '?' or ';' that a shell would interpret, and
    # paths may contain spaces.
    command = [
        'wget', '-O', tmp_fullname, '-t', str(DOWNLOAD_RETRY_LIMIT), url
    ]
    command_text = ' '.join(command)
    failure_msg = (
        '{} failed. Please make sure `wget` is installed or {} exists'.
        format(command_text, url))

    try:
        completed = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as e:
        # Without a shell a missing executable raises instead of returning a
        # non-zero status; keep the exception type callers already handle.
        raise RuntimeError(failure_msg) from e

    if completed.returncode != 0:
        raise RuntimeError(failure_msg)

    shutil.move(tmp_fullname, fullname)

    return fullname
def _uncompress_file_zip(filepath):
    """Extract a zip archive next to itself and return the extracted path.

    Three layouts are distinguished:

    * the archive holds a single file: it is extracted into the archive's
      directory and the path of that file is returned;
    * every entry lives under one top-level directory: the archive is
      extracted into the archive's directory and that directory's path is
      returned;
    * anything else: entries are extracted into a new directory named after
      the archive (without extension) and that directory is returned.

    Args:
        filepath (str): path of the zip archive.

    Returns:
        str: path of the extracted file or directory.
    """
    file_dir = os.path.dirname(filepath)

    # The context manager closes the archive even when extraction fails.
    with zipfile.ZipFile(filepath, 'r') as files:
        file_list = files.namelist()

        if _is_a_single_file(file_list):
            rootpath = file_list[0]
            extract_dir = file_dir
        elif _is_a_single_dir(file_list):
            rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
            extract_dir = file_dir
        else:
            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
            extract_dir = os.path.join(file_dir, rootpath)
            os.makedirs(extract_dir, exist_ok=True)

        for item in file_list:
            files.extract(item, extract_dir)

    return os.path.join(file_dir, rootpath)
files.close() + + return uncompressed_path + + +def _is_a_single_file(file_list): + if len(file_list) == 1 and file_list[0].find(os.sep) < -1: + return True + return False + + +def _is_a_single_dir(file_list): + new_file_list = [] + for file_path in file_list: + if '/' in file_path: + file_path = file_path.replace('/', os.sep) + elif '\\' in file_path: + file_path = file_path.replace('\\', os.sep) + new_file_list.append(file_path) + + file_name = new_file_list[0].split(os.sep)[0] + for i in range(1, len(new_file_list)): + if file_name != new_file_list[i].split(os.sep)[0]: + return False + return True diff --git a/paddlespeech/server/bin/paddlespeech-client.py b/paddlespeech/server/engine/asr/__init__.py similarity index 100% rename from paddlespeech/server/bin/paddlespeech-client.py rename to paddlespeech/server/engine/asr/__init__.py diff --git a/paddlespeech/server/bin/paddlespeech-server.py b/paddlespeech/server/engine/asr/python/__init__.py similarity index 100% rename from paddlespeech/server/bin/paddlespeech-server.py rename to paddlespeech/server/engine/asr/python/__init__.py diff --git a/paddlespeech/server/engine/tts/__init__.py b/paddlespeech/server/engine/tts/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/engine/tts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/server/engine/tts/paddleinference/__init__.py b/paddlespeech/server/engine/tts/paddleinference/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/engine/tts/paddleinference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py index fbaf372b..7679b02f 100644 --- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -26,8 +26,6 @@ from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME -from paddlespeech.t2s.frontend import English -from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.config import get_config @@ -35,6 +33,8 @@ from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.exception import ServerBaseException from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model +from 
paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend __all__ = ['TTSEngine'] @@ -153,7 +153,7 @@ class TTSServerExecutor(TTSExecutor): """ Init model and other resources from a specific path. """ - if hasattr(self, 'am') and hasattr(self, 'voc'): + if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'): logger.info('Models had been initialized.') return # am @@ -341,24 +341,29 @@ class TTSEngine(BaseEngine): def init(self, config_file: str) -> bool: self.executor = TTSServerExecutor() - self.config_file = config_file - self.config = get_config(config_file) - - self.executor._init_from_path( - am=self.config.am, - am_model=self.config.am_model, - am_params=self.config.am_params, - am_sample_rate=self.config.am_sample_rate, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_model=self.config.voc_model, - voc_params=self.config.voc_params, - voc_sample_rate=self.config.voc_sample_rate, - lang=self.config.lang, - am_predictor_conf=self.config.am_predictor_conf, - voc_predictor_conf=self.config.voc_predictor_conf, ) + + try: + self.config = get_config(config_file) + + self.executor._init_from_path( + am=self.config.am, + am_model=self.config.am_model, + am_params=self.config.am_params, + am_sample_rate=self.config.am_sample_rate, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_model=self.config.voc_model, + voc_params=self.config.voc_params, + voc_sample_rate=self.config.voc_sample_rate, + lang=self.config.lang, + am_predictor_conf=self.config.am_predictor_conf, + voc_predictor_conf=self.config.voc_predictor_conf, ) + + except: + logger.info("Initialize TTS server engine Failed.") + return False logger.info("Initialize TTS server engine successfully.") return True @@ -404,7 +409,8 @@ class TTSEngine(BaseEngine): except: 
raise ServerBaseException( ErrorCode.SERVER_INTERNAL_ERR, - "Can not install soxbindings on your system.") + "Transform speed failed. Can not install soxbindings on your system. \ + You need to set speed value 1.0.") # wav to base64 buf = io.BytesIO() diff --git a/paddlespeech/server/engine/tts/python/__init__.py b/paddlespeech/server/engine/tts/python/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/engine/tts/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py index b323551b..e11cfb1d 100644 --- a/paddlespeech/server/engine/tts/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/python/tts_engine.py @@ -16,13 +16,14 @@ import io import librosa import numpy as np +import paddle import soundfile as sf from scipy.io import wavfile from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor -from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.exception import ServerBaseException @@ -50,22 +51,27 @@ class TTSEngine(BaseEngine): def init(self, config_file: str) -> bool: self.executor = TTSServerExecutor() - self.config_file = config_file - self.config = get_config(config_file) - - self.executor._init_from_path( - am=self.config.am, - am_config=self.config.am_config, - am_ckpt=self.config.am_ckpt, - am_stat=self.config.am_stat, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_config=self.config.voc_config, - voc_ckpt=self.config.voc_ckpt, - voc_stat=self.config.voc_stat, - lang=self.config.lang) + + try: + self.config = get_config(config_file) + paddle.set_device(self.config.device) + + self.executor._init_from_path( + am=self.config.am, + am_config=self.config.am_config, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_config=self.config.voc_config, + voc_ckpt=self.config.voc_ckpt, + voc_stat=self.config.voc_stat, + lang=self.config.lang) + 
except: + logger.info("Initialize TTS server engine Failed.") + return False logger.info("Initialize TTS server engine successfully.") return True diff --git a/paddlespeech/server/entry.py b/paddlespeech/server/entry.py new file mode 100644 index 00000000..50f37fbb --- /dev/null +++ b/paddlespeech/server/entry.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from collections import defaultdict + +__all__ = ['commands'] + + +def _CommandDict(): + return defaultdict(_CommandDict) + + +def _execute(): + com = commands + + idx = 0 + for _argv in (['paddleserver'] + sys.argv[1:]): + if _argv not in com: + break + idx += 1 + com = com[_argv] + + # The method 'execute' of a command instance returns 'True' for a success + # while 'False' for a failure. Here converts this result into a exit status + # in bash: 0 for a success and 1 for a failure. + status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 + return status + + +commands = _CommandDict() diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py index 36c4be62..d5fa1d42 100644 --- a/paddlespeech/server/restful/tts_api.py +++ b/paddlespeech/server/restful/tts_api.py @@ -13,9 +13,9 @@ # limitations under the License. 
import traceback from typing import Union + from fastapi import APIRouter -#from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine from paddlespeech.server.restful.request import TTSRequest from paddlespeech.server.restful.response import ErrorResponse @@ -72,7 +72,7 @@ def tts(request_body: TTSRequest): # Check parameters if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ sample_rate not in [0, 16000, 8000] or \ - (save_path is not None and save_path.endswith("pcm") == False and save_path.endswith("wav") == False): + (save_path is not None and not save_path.endswith("pcm") and not save_path.endswith("wav")): return failed_response(ErrorCode.SERVER_PARAM_ERR) # single diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py new file mode 100644 index 00000000..d11178df --- /dev/null +++ b/paddlespeech/server/util.py @@ -0,0 +1,337 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hashlib +import inspect +import json +import os +import tarfile +import threading +import time +import uuid +import zipfile +from typing import Any +from typing import Dict + +import paddle +import requests +import yaml +from paddle.framework import load + +import paddleaudio +from . import download +from .. 
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
    """Compute where the contents of an archive are expected to live.

    Only the member list of the archive is read; nothing is extracted.

    Args:
        filepath: Path of a tar or zip archive.

    Returns:
        A path inside the directory that holds ``filepath``. If the file is
        neither a tar nor a zip archive, that directory itself is returned.
    """
    file_dir = os.path.dirname(filepath)

    # Read the member names inside a ``with`` block so the archive handle is
    # released even if listing the members raises.
    if tarfile.is_tarfile(filepath):
        is_zip_file = False
        with tarfile.open(filepath, "r:*") as archive:
            file_list = archive.getnames()
    elif zipfile.is_zipfile(filepath):
        is_zip_file = True
        with zipfile.ZipFile(filepath, 'r') as archive:
            file_list = archive.namelist()
    else:
        return file_dir

    # NOTE(review): the exact criteria of the two ``download`` helpers are not
    # visible here; going by their names they detect an archive holding one
    # file or one top-level directory - confirm against the download module.
    if download._is_a_single_file(file_list):
        rootpath = file_list[0]
    elif download._is_a_single_dir(file_list):
        parts = os.path.splitext(file_list[0])[0].split(os.sep)
        # Zip listings use the first path component, tar listings the last.
        rootpath = parts[0] if is_zip_file else parts[-1]
    else:
        # Otherwise fall back to the archive's own file name without its
        # last extension.
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]

    return os.path.join(file_dir, rootpath)
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> Any:
    """Download the file at ``url`` into ``path`` and load it.

    Args:
        url: Location of the file to fetch. Its base name is also the name of
            the local file that gets loaded.
        path: Directory to download into; created if it does not exist.
        md5: Checksum forwarded to ``download.get_path_from_url``.

    Returns:
        Whatever ``load`` returns for the downloaded file.
    """
    # exist_ok avoids the check-then-create race when several processes
    # prepare the same directory at once.
    os.makedirs(path, exist_ok=True)

    download.get_path_from_url(url, path, md5)
    return load(os.path.join(path, os.path.basename(url)))
def _get_sub_home(directory):
    """Return the path of ``directory`` under the PaddleSpeech home, creating it if missing.

    Args:
        directory: Name of the sub-directory, joined onto the path returned
            by ``_get_paddlespcceh_home()``.

    Returns:
        str: The joined path.
    """
    home = os.path.join(_get_paddlespcceh_home(), directory)
    if not os.path.exists(home):
        # This runs at import time, so several processes may reach this point
        # together; exist_ok keeps the loser of that race from crashing with
        # FileExistsError.
        os.makedirs(home, exist_ok=True)
    return home
class StatsWorker(threading.Thread):
    """Thread that sends one usage record to ``stats_api`` on a best-effort basis.

    Args:
        task: Task name reported in the ``task`` query parameter.
        model: Optional model name; only reported when truthy.
        version: Version string reported in the ``version`` query parameter.
        extra_info: Optional mapping of additional JSON-serializable fields.
            It is copied, so the caller's mapping is never modified.
    """

    # Seconds to wait for the statistics endpoint. Without a timeout a stalled
    # connection would keep this non-daemon thread, and with it the
    # interpreter, alive indefinitely.
    _REQUEST_TIMEOUT = 10

    def __init__(self,
                 task="asr",
                 model=None,
                 version=__version__,
                 extra_info=None):
        super().__init__()
        self._task = task
        self._model = model
        self._version = version
        # A fresh dict per instance: a shared mutable default or the caller's
        # own dict must not pick up the 'cache_info' key added in run().
        self._extra_info = dict(extra_info) if extra_info else {}

    def run(self):
        """Build the query parameters and issue the GET request."""
        params = {
            'task': self._task,
            'version': self._version,
            'from': 'ppspeech'
        }
        if self._model:
            params['model'] = self._model

        extra = dict(self._extra_info)
        extra['cache_info'] = cache_info
        params['extra'] = json.dumps(extra)

        try:
            requests.get(stats_api, params, timeout=self._REQUEST_TIMEOUT)
        except Exception:
            # Reporting is deliberately best-effort: a failure here must
            # never affect the caller.
            pass

        return
def _parse_args(func, *args, **kwargs):
    """Map a call's arguments onto the positional parameter names of ``func``.

    Args:
        func: The callable whose signature is inspected.
        *args: Positional arguments of the call, excluding ``self``.
        **kwargs: Keyword arguments of the call.

    Returns:
        dict: Parameter name -> value. Every named positional parameter
        (a leading ``self`` is dropped) starts at its declared default, or
        ``None`` when it has none; values are then overridden by ``args`` in
        order and finally by ``kwargs``. Positional arguments beyond the
        named parameters (for example those collected by ``*varargs``) are
        ignored.
    """
    # FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
    argspec = inspect.getfullargspec(func)

    keys = list(argspec.args)
    if keys and keys[0] == 'self':  # Remove self pointer.
        keys = keys[1:]

    # ``defaults`` is None, not an empty tuple, when no parameter has a
    # default value.
    default_values = argspec.defaults or ()
    values = [None] * (len(keys) - len(default_values))
    values.extend(default_values)
    params = dict(zip(keys, values))

    # zip() stops at the shorter sequence, so surplus positional arguments
    # cannot index past the end of ``keys``.
    params.update(zip(keys, args))
    params.update(kwargs)

    return params