add cli, test=doc

pull/1466/head
lym0302 2 years ago
parent 0d1f90adc4
commit 80b83b7434

@ -11,3 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import _locale
from .base_commands import BaseCommand
from .base_commands import HelpCommand
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
"""
import _locale
from .base_commands import BaseCommand
from .base_commands import HelpCommand
from paddlespeech.server.bin.paddlespeech_client import TTSClientExecutor
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -0,0 +1,49 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from .entry import commands
from .util import cli_register
from .util import get_command
__all__ = [
'BaseCommand',
'HelpCommand',
]
@cli_register(name='paddleserver')
class BaseCommand:
def execute(self, argv: List[str]) -> bool:
help = get_command('paddleserver.help')
return help().execute(argv)
@cli_register(name='paddleserver.help', description='Show help for commands.')
class HelpCommand:
def execute(self, argv: List[str]) -> bool:
msg = 'Usage:\n'
msg += ' paddleserver <command> <options>\n\n'
msg += 'Commands:\n'
for command, detail in commands['paddleserver'].items():
if command.startswith('_'):
continue
if '_description' not in detail:
continue
msg += ' {:<15} {}\n'.format(command,
detail['_description'])
print(msg)
return True

@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .paddlespeech_server import ServerExecutor
#from .paddlespeech_client import ClientExecutor

@ -0,0 +1,112 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import io
import json
import os
import random
import time
from typing import List
import numpy as np
import requests
import soundfile
from paddlespeech.server.util import cli_register
from paddlespeech.server.utils.audio_process import wav2pcm
@cli_register(name='paddleserver.ttsclient', description='visit tts service')
class TTSClientExecutor():
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser()
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8090, help='server port')
self.parser.add_argument(
'--text',
type=str,
default="你好,欢迎使用语音合成服务",
help='A sentence to be synthesized')
self.parser.add_argument(
'--spk_id', type=int, default=0, help='Speaker id')
self.parser.add_argument(
'--speed', type=float, default=1.0, help='Audio speed')
self.parser.add_argument(
'--volume', type=float, default=1.0, help='Audio volume')
self.parser.add_argument(
'--sample_rate',
type=int,
default=0,
help='Sampling rate, the default is the same as the model')
self.parser.add_argument(
'--output',
type=str,
default="./out.wav",
help='Synthesized audio file')
# Request and response
def tts_client(self, args):
""" Request and response
Args:
text: A sentence to be synthesized
outfile: Synthetic audio file
"""
url = 'http://' + args.server_ip + ":" + str(
args.port) + '/paddlespeech/tts'
request = {
"text": args.text,
"spk_id": args.spk_id,
"speed": args.speed,
"volume": args.volume,
"sample_rate": args.sample_rate,
"save_path": args.output
}
response = requests.post(url, json.dumps(request))
response_dict = response.json()
wav_base64 = response_dict["result"]["audio"]
audio_data_byte = base64.b64decode(wav_base64)
# from byte
samples, sample_rate = soundfile.read(
io.BytesIO(audio_data_byte), dtype='float32')
# transform audio
outfile = args.output
if outfile.endswith(".wav"):
soundfile.write(outfile, samples, sample_rate)
elif outfile.endswith(".pcm"):
temp_wav = str(random.getrandbits(128)) + ".wav"
soundfile.write(temp_wav, samples, sample_rate)
wav2pcm(temp_wav, outfile, data_type=np.int16)
os.system("rm %s" % (temp_wav))
else:
print("The format for saving audio only supports wav or pcm")
return len(samples), sample_rate
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
st = time.time()
try:
samples_length, sample_rate = self.tts_client(args)
time_consume = time.time() - st
print("Save synthesized audio successfully on %s." % (args.output))
print("Inference time: %f" % (time_consume))
except:
print("Failed to synthesized audio.")

@ -0,0 +1,75 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List
import uvicorn
from fastapi import FastAPI
from paddlespeech.server.engine.engine_factory import EngineFactory
from paddlespeech.server.restful.api import setup_router
from paddlespeech.server.util import cli_register
from paddlespeech.server.utils.config import get_config
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
@cli_register(name='paddleserver.server', description='Start the service')
class ServerExecutor():
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser()
self.parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
self.parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
def init(self, config) -> bool:
"""system initialization
Args:
config (CfgNode): config object
Returns:
bool:
"""
# init api
api_list = list(config.engine_backend)
api_router = setup_router(api_list)
app.include_router(api_router)
# init engine
engine_pool = []
for engine in config.engine_backend:
engine_pool.append(EngineFactory.get_engine(engine_name=engine))
if not engine_pool[-1].init(
config_file=config.engine_backend[engine]):
return False
return True
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
config = get_config(args.config_file)
if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True)

@ -0,0 +1,329 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import os
import os.path as osp
import shutil
import subprocess
import tarfile
import time
import zipfile
import requests
from tqdm import tqdm
from paddlespeech.cli.log import logger
__all__ = ['get_path_from_url']
DOWNLOAD_RETRY_LIMIT = 3
def _is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting is to avoid different environmental variables for each card
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True,
method='get'):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert _is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
def _get_download(url, fullname):
# using requests.get method
fname = osp.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -26,8 +26,6 @@ from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.config import get_config
@ -35,6 +33,8 @@ from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
__all__ = ['TTSEngine']
@ -153,7 +153,7 @@ class TTSServerExecutor(TTSExecutor):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'am') and hasattr(self, 'voc'):
if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'):
logger.info('Models had been initialized.')
return
# am
@ -341,24 +341,29 @@ class TTSEngine(BaseEngine):
def init(self, config_file: str) -> bool:
self.executor = TTSServerExecutor()
self.config_file = config_file
self.config = get_config(config_file)
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
try:
self.config = get_config(config_file)
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
except:
logger.info("Initialize TTS server engine Failed.")
return False
logger.info("Initialize TTS server engine successfully.")
return True
@ -404,7 +409,8 @@ class TTSEngine(BaseEngine):
except:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Can not install soxbindings on your system.")
"Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
# wav to base64
buf = io.BytesIO()

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -16,13 +16,14 @@ import io
import librosa
import numpy as np
import paddle
import soundfile as sf
from scipy.io import wavfile
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
@ -50,22 +51,27 @@ class TTSEngine(BaseEngine):
def init(self, config_file: str) -> bool:
self.executor = TTSServerExecutor()
self.config_file = config_file
self.config = get_config(config_file)
self.executor._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_config=self.config.voc_config,
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
try:
self.config = get_config(config_file)
paddle.set_device(self.config.device)
self.executor._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_config=self.config.voc_config,
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except:
logger.info("Initialize TTS server engine Failed.")
return False
logger.info("Initialize TTS server engine successfully.")
return True

@ -0,0 +1,42 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import defaultdict
__all__ = ['commands']
def _CommandDict():
return defaultdict(_CommandDict)
def _execute():
com = commands
print("0000000000000: ", com)
idx = 0
for _argv in (['paddleserver'] + sys.argv[1:]):
if _argv not in com:
break
idx += 1
com = com[_argv]
# The method 'execute' of a command instance returns 'True' for a success
# while 'False' for a failure. Here converts this result into a exit status
# in bash: 0 for a success and 1 for a failure.
status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1
return status
commands = _CommandDict()

@ -13,9 +13,9 @@
# limitations under the License.
import traceback
from typing import Union
from fastapi import APIRouter
#from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
from paddlespeech.server.restful.request import TTSRequest
from paddlespeech.server.restful.response import ErrorResponse
@ -72,7 +72,7 @@ def tts(request_body: TTSRequest):
# Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \
sample_rate not in [0, 16000, 8000] or \
(save_path is not None and save_path.endswith("pcm") == False and save_path.endswith("wav") == False):
(save_path is not None and not save_path.endswith("pcm") and not save_path.endswith("wav")):
return failed_response(ErrorCode.SERVER_PARAM_ERR)
# single

@ -0,0 +1,337 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import inspect
import json
import os
import tarfile
import threading
import time
import uuid
import zipfile
from typing import Any
from typing import Dict
import paddle
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .. import __version__
from .entry import commands
requests.adapters.DEFAULT_RETRIES = 3
__all__ = [
'cli_register',
'get_command',
'download_and_decompress',
'load_state_dict_from_url',
'stats_wrapper',
]
def cli_register(name: str, description: str='') -> Any:
def _warpper(command):
items = name.split('.')
com = commands
for item in items:
com = com[item]
com['_entry'] = command
if description:
com['_description'] = description
return command
return _warpper
def get_command(name: str) -> Any:
items = name.split('.')
com = commands
for item in items:
com = com[item]
return com['_entry']
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
file_dir = os.path.dirname(filepath)
is_zip_file = False
if tarfile.is_tarfile(filepath):
files = tarfile.open(filepath, "r:*")
file_list = files.getnames()
elif zipfile.is_zipfile(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
is_zip_file = True
else:
return file_dir
if download._is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
elif download._is_a_single_dir(file_list):
if is_zip_file:
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
else:
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.close()
return uncompressed_path
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
"""
Download archieves and decompress to specific path.
"""
if not os.path.isdir(path):
os.makedirs(path)
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
filepath = os.path.join(path, os.path.basename(archive['url']))
if os.path.isfile(filepath) and download._md5check(filepath,
archive['md5']):
uncompress_path = _get_uncompress_path(filepath)
if not os.path.isdir(uncompress_path):
download._decompress(filepath)
else:
StatsWorker(
task='download',
version=__version__,
extra_info={
'download_url': archive['url'],
'paddle_version': paddle.__version__
}).start()
uncompress_path = download.get_path_from_url(archive['url'], path,
archive['md5'])
return uncompress_path
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
"""
Download and load a state dict from url
"""
if not os.path.isdir(path):
os.makedirs(path)
download.get_path_from_url(url, path, md5)
return load(os.path.join(path, os.path.basename(url)))
def _get_user_home():
return os.path.expanduser('~')
def _get_paddlespcceh_home():
if 'PPSPEECH_HOME' in os.environ:
home_path = os.environ['PPSPEECH_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
raise RuntimeError(
'The environment variable PPSPEECH_HOME {} is not a directory.'.
format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddlespeech')
def _get_sub_home(directory):
home = os.path.join(_get_paddlespcceh_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
PPSPEECH_HOME = _get_paddlespcceh_home()
MODEL_HOME = _get_sub_home('models')
CONF_HOME = _get_sub_home('conf')
def _md5(text: str):
'''Calculate the md5 value of the input text.'''
md5code = hashlib.md5(text.encode())
return md5code.hexdigest()
class ConfigCache:
def __init__(self):
self._data = {}
self._initialize()
self.file = os.path.join(CONF_HOME, 'cache.yaml')
if not os.path.exists(self.file):
self.flush()
return
with open(self.file, 'r') as file:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg)
except:
self.flush()
@property
def cache_info(self):
return self._data['cache_info']
def _initialize(self):
# Set default configuration values.
cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
self._data['cache_info'] = cache_info
def flush(self):
'''Flush the current configuration into the configuration file.'''
with open(self.file, 'w') as file:
cfg = json.loads(json.dumps(self._data))
yaml.dump(cfg, file)
stats_api = "http://paddlepaddle.org.cn/paddlehub/stat"
cache_info = ConfigCache().cache_info
class StatsWorker(threading.Thread):
def __init__(self,
task="asr",
model=None,
version=__version__,
extra_info={}):
threading.Thread.__init__(self)
self._task = task
self._model = model
self._version = version
self._extra_info = extra_info
def run(self):
params = {
'task': self._task,
'version': self._version,
'from': 'ppspeech'
}
if self._model:
params['model'] = self._model
self._extra_info.update({
'cache_info': cache_info,
})
params.update({"extra": json.dumps(self._extra_info)})
try:
requests.get(stats_api, params)
except Exception:
pass
return
def _note_one_stat(cls_name, params={}):
task = cls_name.replace('Executor', '').lower() # XXExecutor
extra_info = {
'paddle_version': paddle.__version__,
}
if 'model' in params:
model = params['model']
else:
model = None
if 'audio_file' in params:
try:
_, sr = paddleaudio.load(params['audio_file'])
except Exception:
sr = -1
if task == 'asr':
extra_info.update({
'lang': params['lang'],
'inp_sr': sr,
'model_sr': params['sample_rate'],
})
elif task == 'st':
extra_info.update({
'lang':
params['src_lang'] + '-' + params['tgt_lang'],
'inp_sr':
sr,
'model_sr':
params['sample_rate'],
})
elif task == 'tts':
model = params['am']
extra_info.update({
'lang': params['lang'],
'vocoder': params['voc'],
})
elif task == 'cls':
extra_info.update({
'inp_sr': sr,
})
elif task == 'text':
extra_info.update({
'sub_task': params['task'],
'lang': params['lang'],
})
else:
return
StatsWorker(
task=task,
model=model,
version=__version__,
extra_info=extra_info, ).start()
def _parse_args(func, *args, **kwargs):
# FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
argspec = inspect.getfullargspec(func)
keys = argspec[0]
if keys[0] == 'self': # Remove self pointer.
keys = keys[1:]
default_values = argspec[3]
values = [None] * (len(keys) - len(default_values))
values.extend(list(default_values))
params = dict(zip(keys, values))
for idx, v in enumerate(args):
params[keys[idx]] = v
for k, v in kwargs.items():
params[k] = v
return params
def stats_wrapper(executor_func):
def _warpper(self, *args, **kwargs):
try:
_note_one_stat(
type(self).__name__, _parse_args(executor_func, *args,
**kwargs))
except Exception:
pass
return executor_func(self, *args, **kwargs)
return _warpper

@ -234,7 +234,10 @@ setup_info = dict(
'Programming Language :: Python :: 3.9',
],
entry_points={
'console_scripts': ['paddlespeech=paddlespeech.cli.entry:_execute']
'console_scripts': [
'paddlespeech=paddlespeech.cli.entry:_execute',
'paddleserver=paddlespeech.server.entry:_execute'
]
})
setup(**setup_info)

Loading…
Cancel
Save