parent
35e3be9ac8
commit
da3ea7bb40
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from .entry import client_commands
|
||||||
|
from .entry import server_commands
|
||||||
|
from .util import cli_client_register
|
||||||
|
from .util import cli_server_register
|
||||||
|
from .util import get_client_command
|
||||||
|
from .util import get_server_command
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ServerBaseCommand',
|
||||||
|
'ServerHelpCommand',
|
||||||
|
'ClientBaseCommand',
|
||||||
|
'ClientHelpCommand',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@cli_server_register(name='paddlespeech_server')
|
||||||
|
class ServerBaseCommand:
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
help = get_server_command('paddlespeech_server.help')
|
||||||
|
return help().execute(argv)
|
||||||
|
|
||||||
|
|
||||||
|
@cli_server_register(
|
||||||
|
name='paddlespeech_server.help', description='Show help for commands.')
|
||||||
|
class ServerHelpCommand:
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
msg = 'Usage:\n'
|
||||||
|
msg += ' paddlespeech_server <command> <options>\n\n'
|
||||||
|
msg += 'Commands:\n'
|
||||||
|
for command, detail in server_commands['paddlespeech_server'].items():
|
||||||
|
if command.startswith('_'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if '_description' not in detail:
|
||||||
|
continue
|
||||||
|
msg += ' {:<15} {}\n'.format(command,
|
||||||
|
detail['_description'])
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@cli_client_register(name='paddlespeech_client')
|
||||||
|
class ClientBaseCommand:
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
help = get_client_command('paddlespeech_client.help')
|
||||||
|
return help().execute(argv)
|
||||||
|
|
||||||
|
|
||||||
|
@cli_client_register(
|
||||||
|
name='paddlespeech_client.help', description='Show help for commands.')
|
||||||
|
class ClientHelpCommand:
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
msg = 'Usage:\n'
|
||||||
|
msg += ' paddlespeech_client <command> <options>\n\n'
|
||||||
|
msg += 'Commands:\n'
|
||||||
|
for command, detail in client_commands['paddlespeech_client'].items():
|
||||||
|
if command.startswith('_'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if '_description' not in detail:
|
||||||
|
continue
|
||||||
|
msg += ' {:<15} {}\n'.format(command,
|
||||||
|
detail['_description'])
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
return True
|
@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .paddlespeech_client import ASRClientExecutor
|
||||||
|
from .paddlespeech_client import TTSClientExecutor
|
||||||
|
from .paddlespeech_server import ServerExecutor
|
@ -0,0 +1,156 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from ..util import cli_client_register
|
||||||
|
from paddlespeech.server.utils.audio_process import wav2pcm
|
||||||
|
from paddlespeech.server.utils.util import wav2base64
|
||||||
|
|
||||||
|
__all__ = ['TTSClientExecutor', 'ASRClientExecutor']
|
||||||
|
|
||||||
|
|
||||||
|
@cli_client_register(
|
||||||
|
name='paddlespeech_client.tts', description='visit tts service')
|
||||||
|
class TTSClientExecutor():
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = argparse.ArgumentParser()
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--server_ip', type=str, default='127.0.0.1', help='server ip')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--port', type=int, default=8090, help='server port')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--text',
|
||||||
|
type=str,
|
||||||
|
default="你好,欢迎使用语音合成服务",
|
||||||
|
help='A sentence to be synthesized')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--spk_id', type=int, default=0, help='Speaker id')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--speed', type=float, default=1.0, help='Audio speed')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--volume', type=float, default=1.0, help='Audio volume')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--sample_rate',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Sampling rate, the default is the same as the model')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--output',
|
||||||
|
type=str,
|
||||||
|
default="./out.wav",
|
||||||
|
help='Synthesized audio file')
|
||||||
|
|
||||||
|
# Request and response
|
||||||
|
def tts_client(self, args):
|
||||||
|
""" Request and response
|
||||||
|
Args:
|
||||||
|
text: A sentence to be synthesized
|
||||||
|
outfile: Synthetic audio file
|
||||||
|
"""
|
||||||
|
url = 'http://' + args.server_ip + ":" + str(
|
||||||
|
args.port) + '/paddlespeech/tts'
|
||||||
|
request = {
|
||||||
|
"text": args.text,
|
||||||
|
"spk_id": args.spk_id,
|
||||||
|
"speed": args.speed,
|
||||||
|
"volume": args.volume,
|
||||||
|
"sample_rate": args.sample_rate,
|
||||||
|
"save_path": args.output
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, json.dumps(request))
|
||||||
|
response_dict = response.json()
|
||||||
|
print(response_dict["message"])
|
||||||
|
wav_base64 = response_dict["result"]["audio"]
|
||||||
|
|
||||||
|
audio_data_byte = base64.b64decode(wav_base64)
|
||||||
|
# from byte
|
||||||
|
samples, sample_rate = soundfile.read(
|
||||||
|
io.BytesIO(audio_data_byte), dtype='float32')
|
||||||
|
|
||||||
|
# transform audio
|
||||||
|
outfile = args.output
|
||||||
|
if outfile.endswith(".wav"):
|
||||||
|
soundfile.write(outfile, samples, sample_rate)
|
||||||
|
elif outfile.endswith(".pcm"):
|
||||||
|
temp_wav = str(random.getrandbits(128)) + ".wav"
|
||||||
|
soundfile.write(temp_wav, samples, sample_rate)
|
||||||
|
wav2pcm(temp_wav, outfile, data_type=np.int16)
|
||||||
|
os.system("rm %s" % (temp_wav))
|
||||||
|
else:
|
||||||
|
print("The format for saving audio only supports wav or pcm")
|
||||||
|
|
||||||
|
return len(samples), sample_rate
|
||||||
|
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
args = self.parser.parse_args(argv)
|
||||||
|
st = time.time()
|
||||||
|
try:
|
||||||
|
samples_length, sample_rate = self.tts_client(args)
|
||||||
|
time_consume = time.time() - st
|
||||||
|
print("Save synthesized audio successfully on %s." % (args.output))
|
||||||
|
print("Inference time: %f s." % (time_consume))
|
||||||
|
except:
|
||||||
|
print("Failed to synthesized audio.")
|
||||||
|
|
||||||
|
|
||||||
|
@cli_client_register(
|
||||||
|
name='paddlespeech_client.asr', description='visit asr service')
|
||||||
|
class ASRClientExecutor():
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = argparse.ArgumentParser()
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--server_ip', type=str, default='127.0.0.1', help='server ip')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--port', type=int, default=8090, help='server port')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--audio_file',
|
||||||
|
type=str,
|
||||||
|
default="./paddlespeech/server/tests/16_audio.wav",
|
||||||
|
help='Audio file to be recognized')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--sample_rate', type=int, default=16000, help='audio sample rate')
|
||||||
|
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
args = self.parser.parse_args(argv)
|
||||||
|
url = 'http://' + args.server_ip + ":" + str(
|
||||||
|
args.port) + '/paddlespeech/asr'
|
||||||
|
audio = wav2base64(args.audio_file)
|
||||||
|
data = {
|
||||||
|
"audio": audio,
|
||||||
|
"audio_format": "wav",
|
||||||
|
"sample_rate": args.sample_rate,
|
||||||
|
"lang": "zh_cn",
|
||||||
|
}
|
||||||
|
time_start = time.time()
|
||||||
|
try:
|
||||||
|
r = requests.post(url=url, data=json.dumps(data))
|
||||||
|
# ending Timestamp
|
||||||
|
time_end = time.time()
|
||||||
|
print('time cost', time_end - time_start, 's')
|
||||||
|
except:
|
||||||
|
print("Failed to speech recognition.")
|
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from ..util import cli_server_register
|
||||||
|
from paddlespeech.server.engine.engine_factory import EngineFactory
|
||||||
|
from paddlespeech.server.restful.api import setup_router
|
||||||
|
from paddlespeech.server.utils.config import get_config
|
||||||
|
|
||||||
|
__all__ = ['ServerExecutor']
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
|
||||||
|
|
||||||
|
|
||||||
|
@cli_server_register(
|
||||||
|
name='paddlespeech_server.server', description='Start the service')
|
||||||
|
class ServerExecutor():
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = argparse.ArgumentParser()
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--config_file",
|
||||||
|
action="store",
|
||||||
|
help="yaml file of the app",
|
||||||
|
default="./conf/application.yaml")
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--log_file",
|
||||||
|
action="store",
|
||||||
|
help="log file",
|
||||||
|
default="./log/paddlespeech.log")
|
||||||
|
|
||||||
|
def init(self, config) -> bool:
|
||||||
|
"""system initialization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (CfgNode): config object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool:
|
||||||
|
"""
|
||||||
|
# init api
|
||||||
|
api_list = list(config.engine_backend)
|
||||||
|
api_router = setup_router(api_list)
|
||||||
|
app.include_router(api_router)
|
||||||
|
|
||||||
|
# init engine
|
||||||
|
engine_pool = []
|
||||||
|
for engine in config.engine_backend:
|
||||||
|
engine_pool.append(EngineFactory.get_engine(engine_name=engine))
|
||||||
|
if not engine_pool[-1].init(
|
||||||
|
config_file=config.engine_backend[engine]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def execute(self, argv: List[str]) -> bool:
|
||||||
|
args = self.parser.parse_args(argv)
|
||||||
|
config = get_config(args.config_file)
|
||||||
|
|
||||||
|
if self.init(config):
|
||||||
|
uvicorn.run(app, host=config.host, port=config.port, debug=True)
|
@ -1,7 +1,7 @@
|
|||||||
model: 'conformer_wenetspeech'
|
model: 'conformer_wenetspeech'
|
||||||
lang: 'zh'
|
lang: 'zh'
|
||||||
sample_rate: 16000
|
sample_rate: 16000
|
||||||
cfg_path:
|
cfg_path: # [optional]
|
||||||
ckpt_path:
|
ckpt_path: # [optional]
|
||||||
decode_method: 'attention_rescoring'
|
decode_method: 'attention_rescoring'
|
||||||
force_yes: False
|
force_yes: False
|
||||||
|
@ -0,0 +1,25 @@
|
|||||||
|
# This is the parameter configuration file for ASR server.
|
||||||
|
# These are the static models that support paddle inference.
|
||||||
|
|
||||||
|
##################################################################
|
||||||
|
# ACOUSTIC MODEL SETTING #
|
||||||
|
# am choices=['deepspeech2offline_aishell'] TODO
|
||||||
|
##################################################################
|
||||||
|
model_type: 'deepspeech2offline_aishell'
|
||||||
|
am_model: # the pdmodel file of am static model [optional]
|
||||||
|
am_params: # the pdiparams file of am static model [optional]
|
||||||
|
lang: 'zh'
|
||||||
|
sample_rate: 16000
|
||||||
|
cfg_path:
|
||||||
|
decode_method:
|
||||||
|
force_yes:
|
||||||
|
|
||||||
|
am_predictor_conf:
|
||||||
|
use_gpu: True
|
||||||
|
enable_mkldnn: True
|
||||||
|
switch_ir_optim: True
|
||||||
|
|
||||||
|
|
||||||
|
##################################################################
|
||||||
|
# OTHERS #
|
||||||
|
##################################################################
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,36 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from paddlespeech.server.engine.engine_factory import EngineFactory
|
||||||
|
|
||||||
|
# global value
|
||||||
|
ENGINE_POOL = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine_pool() -> dict:
|
||||||
|
""" Get engine pool
|
||||||
|
"""
|
||||||
|
global ENGINE_POOL
|
||||||
|
return ENGINE_POOL
|
||||||
|
|
||||||
|
|
||||||
|
def init_engine_pool(config) -> bool:
|
||||||
|
""" Init engine pool
|
||||||
|
"""
|
||||||
|
global ENGINE_POOL
|
||||||
|
for engine in config.engine_backend:
|
||||||
|
ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine])
|
||||||
|
if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
__all__ = ['server_commands', 'client_commands']
|
||||||
|
|
||||||
|
|
||||||
|
def _CommandDict():
|
||||||
|
return defaultdict(_CommandDict)
|
||||||
|
|
||||||
|
|
||||||
|
def server_execute():
|
||||||
|
com = server_commands
|
||||||
|
idx = 0
|
||||||
|
for _argv in (['paddlespeech_server'] + sys.argv[1:]):
|
||||||
|
if _argv not in com:
|
||||||
|
break
|
||||||
|
idx += 1
|
||||||
|
com = com[_argv]
|
||||||
|
|
||||||
|
# The method 'execute' of a command instance returns 'True' for a success
|
||||||
|
# while 'False' for a failure. Here converts this result into a exit status
|
||||||
|
# in bash: 0 for a success and 1 for a failure.
|
||||||
|
status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
def client_execute():
|
||||||
|
com = client_commands
|
||||||
|
idx = 0
|
||||||
|
for _argv in (['paddlespeech_client'] + sys.argv[1:]):
|
||||||
|
if _argv not in com:
|
||||||
|
break
|
||||||
|
idx += 1
|
||||||
|
com = com[_argv]
|
||||||
|
|
||||||
|
# The method 'execute' of a command instance returns 'True' for a success
|
||||||
|
# while 'False' for a failure. Here converts this result into a exit status
|
||||||
|
# in bash: 0 for a success and 1 for a failure.
|
||||||
|
status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
server_commands = _CommandDict()
|
||||||
|
client_commands = _CommandDict()
|
@ -0,0 +1,364 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import hashlib
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tarfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import zipfile
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import requests
|
||||||
|
import yaml
|
||||||
|
from paddle.framework import load
|
||||||
|
|
||||||
|
import paddleaudio
|
||||||
|
from . import download
|
||||||
|
from .. import __version__
|
||||||
|
from .entry import client_commands
|
||||||
|
from .entry import server_commands
|
||||||
|
|
||||||
|
requests.adapters.DEFAULT_RETRIES = 3
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'cli_server_register',
|
||||||
|
'get_server_command',
|
||||||
|
'cli_client_register',
|
||||||
|
'get_client_command',
|
||||||
|
'download_and_decompress',
|
||||||
|
'load_state_dict_from_url',
|
||||||
|
'stats_wrapper',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def cli_server_register(name: str, description: str='') -> Any:
|
||||||
|
def _warpper(command):
|
||||||
|
items = name.split('.')
|
||||||
|
|
||||||
|
com = server_commands
|
||||||
|
for item in items:
|
||||||
|
com = com[item]
|
||||||
|
com['_entry'] = command
|
||||||
|
if description:
|
||||||
|
com['_description'] = description
|
||||||
|
return command
|
||||||
|
|
||||||
|
return _warpper
|
||||||
|
|
||||||
|
|
||||||
|
def get_server_command(name: str) -> Any:
|
||||||
|
items = name.split('.')
|
||||||
|
com = server_commands
|
||||||
|
for item in items:
|
||||||
|
com = com[item]
|
||||||
|
|
||||||
|
return com['_entry']
|
||||||
|
|
||||||
|
|
||||||
|
def cli_client_register(name: str, description: str='') -> Any:
|
||||||
|
def _warpper(command):
|
||||||
|
items = name.split('.')
|
||||||
|
|
||||||
|
com = client_commands
|
||||||
|
for item in items:
|
||||||
|
com = com[item]
|
||||||
|
com['_entry'] = command
|
||||||
|
if description:
|
||||||
|
com['_description'] = description
|
||||||
|
return command
|
||||||
|
|
||||||
|
return _warpper
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_command(name: str) -> Any:
|
||||||
|
items = name.split('.')
|
||||||
|
com = client_commands
|
||||||
|
for item in items:
|
||||||
|
com = com[item]
|
||||||
|
|
||||||
|
return com['_entry']
|
||||||
|
|
||||||
|
|
||||||
|
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
|
||||||
|
file_dir = os.path.dirname(filepath)
|
||||||
|
is_zip_file = False
|
||||||
|
if tarfile.is_tarfile(filepath):
|
||||||
|
files = tarfile.open(filepath, "r:*")
|
||||||
|
file_list = files.getnames()
|
||||||
|
elif zipfile.is_zipfile(filepath):
|
||||||
|
files = zipfile.ZipFile(filepath, 'r')
|
||||||
|
file_list = files.namelist()
|
||||||
|
is_zip_file = True
|
||||||
|
else:
|
||||||
|
return file_dir
|
||||||
|
|
||||||
|
if download._is_a_single_file(file_list):
|
||||||
|
rootpath = file_list[0]
|
||||||
|
uncompressed_path = os.path.join(file_dir, rootpath)
|
||||||
|
elif download._is_a_single_dir(file_list):
|
||||||
|
if is_zip_file:
|
||||||
|
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
|
||||||
|
else:
|
||||||
|
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
|
||||||
|
uncompressed_path = os.path.join(file_dir, rootpath)
|
||||||
|
else:
|
||||||
|
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
|
||||||
|
uncompressed_path = os.path.join(file_dir, rootpath)
|
||||||
|
|
||||||
|
files.close()
|
||||||
|
return uncompressed_path
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
|
||||||
|
"""
|
||||||
|
Download archieves and decompress to specific path.
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
assert 'url' in archive and 'md5' in archive, \
|
||||||
|
'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
|
||||||
|
|
||||||
|
filepath = os.path.join(path, os.path.basename(archive['url']))
|
||||||
|
if os.path.isfile(filepath) and download._md5check(filepath,
|
||||||
|
archive['md5']):
|
||||||
|
uncompress_path = _get_uncompress_path(filepath)
|
||||||
|
if not os.path.isdir(uncompress_path):
|
||||||
|
download._decompress(filepath)
|
||||||
|
else:
|
||||||
|
StatsWorker(
|
||||||
|
task='download',
|
||||||
|
version=__version__,
|
||||||
|
extra_info={
|
||||||
|
'download_url': archive['url'],
|
||||||
|
'paddle_version': paddle.__version__
|
||||||
|
}).start()
|
||||||
|
uncompress_path = download.get_path_from_url(archive['url'], path,
|
||||||
|
archive['md5'])
|
||||||
|
|
||||||
|
return uncompress_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
|
||||||
|
"""
|
||||||
|
Download and load a state dict from url
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
download.get_path_from_url(url, path, md5)
|
||||||
|
return load(os.path.join(path, os.path.basename(url)))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_home():
|
||||||
|
return os.path.expanduser('~')
|
||||||
|
|
||||||
|
|
||||||
|
def _get_paddlespcceh_home():
|
||||||
|
if 'PPSPEECH_HOME' in os.environ:
|
||||||
|
home_path = os.environ['PPSPEECH_HOME']
|
||||||
|
if os.path.exists(home_path):
|
||||||
|
if os.path.isdir(home_path):
|
||||||
|
return home_path
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
'The environment variable PPSPEECH_HOME {} is not a directory.'.
|
||||||
|
format(home_path))
|
||||||
|
else:
|
||||||
|
return home_path
|
||||||
|
return os.path.join(_get_user_home(), '.paddlespeech')
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sub_home(directory):
|
||||||
|
home = os.path.join(_get_paddlespcceh_home(), directory)
|
||||||
|
if not os.path.exists(home):
|
||||||
|
os.makedirs(home)
|
||||||
|
return home
|
||||||
|
|
||||||
|
|
||||||
|
PPSPEECH_HOME = _get_paddlespcceh_home()
|
||||||
|
MODEL_HOME = _get_sub_home('models')
|
||||||
|
CONF_HOME = _get_sub_home('conf')
|
||||||
|
|
||||||
|
|
||||||
|
def _md5(text: str):
|
||||||
|
'''Calculate the md5 value of the input text.'''
|
||||||
|
md5code = hashlib.md5(text.encode())
|
||||||
|
return md5code.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigCache:
|
||||||
|
def __init__(self):
|
||||||
|
self._data = {}
|
||||||
|
self._initialize()
|
||||||
|
self.file = os.path.join(CONF_HOME, 'cache.yaml')
|
||||||
|
if not os.path.exists(self.file):
|
||||||
|
self.flush()
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(self.file, 'r') as file:
|
||||||
|
try:
|
||||||
|
cfg = yaml.load(file, Loader=yaml.FullLoader)
|
||||||
|
self._data.update(cfg)
|
||||||
|
except:
|
||||||
|
self.flush()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_info(self):
|
||||||
|
return self._data['cache_info']
|
||||||
|
|
||||||
|
def _initialize(self):
|
||||||
|
# Set default configuration values.
|
||||||
|
cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
|
||||||
|
self._data['cache_info'] = cache_info
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
'''Flush the current configuration into the configuration file.'''
|
||||||
|
with open(self.file, 'w') as file:
|
||||||
|
cfg = json.loads(json.dumps(self._data))
|
||||||
|
yaml.dump(cfg, file)
|
||||||
|
|
||||||
|
|
||||||
|
stats_api = "http://paddlepaddle.org.cn/paddlehub/stat"
|
||||||
|
cache_info = ConfigCache().cache_info
|
||||||
|
|
||||||
|
|
||||||
|
class StatsWorker(threading.Thread):
|
||||||
|
def __init__(self,
|
||||||
|
task="asr",
|
||||||
|
model=None,
|
||||||
|
version=__version__,
|
||||||
|
extra_info={}):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self._task = task
|
||||||
|
self._model = model
|
||||||
|
self._version = version
|
||||||
|
self._extra_info = extra_info
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
params = {
|
||||||
|
'task': self._task,
|
||||||
|
'version': self._version,
|
||||||
|
'from': 'ppspeech'
|
||||||
|
}
|
||||||
|
if self._model:
|
||||||
|
params['model'] = self._model
|
||||||
|
|
||||||
|
self._extra_info.update({
|
||||||
|
'cache_info': cache_info,
|
||||||
|
})
|
||||||
|
params.update({"extra": json.dumps(self._extra_info)})
|
||||||
|
|
||||||
|
try:
|
||||||
|
requests.get(stats_api, params)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _note_one_stat(cls_name, params={}):
|
||||||
|
task = cls_name.replace('Executor', '').lower() # XXExecutor
|
||||||
|
extra_info = {
|
||||||
|
'paddle_version': paddle.__version__,
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'model' in params:
|
||||||
|
model = params['model']
|
||||||
|
else:
|
||||||
|
model = None
|
||||||
|
|
||||||
|
if 'audio_file' in params:
|
||||||
|
try:
|
||||||
|
_, sr = paddleaudio.load(params['audio_file'])
|
||||||
|
except Exception:
|
||||||
|
sr = -1
|
||||||
|
|
||||||
|
if task == 'asr':
|
||||||
|
extra_info.update({
|
||||||
|
'lang': params['lang'],
|
||||||
|
'inp_sr': sr,
|
||||||
|
'model_sr': params['sample_rate'],
|
||||||
|
})
|
||||||
|
elif task == 'st':
|
||||||
|
extra_info.update({
|
||||||
|
'lang':
|
||||||
|
params['src_lang'] + '-' + params['tgt_lang'],
|
||||||
|
'inp_sr':
|
||||||
|
sr,
|
||||||
|
'model_sr':
|
||||||
|
params['sample_rate'],
|
||||||
|
})
|
||||||
|
elif task == 'tts':
|
||||||
|
model = params['am']
|
||||||
|
extra_info.update({
|
||||||
|
'lang': params['lang'],
|
||||||
|
'vocoder': params['voc'],
|
||||||
|
})
|
||||||
|
elif task == 'cls':
|
||||||
|
extra_info.update({
|
||||||
|
'inp_sr': sr,
|
||||||
|
})
|
||||||
|
elif task == 'text':
|
||||||
|
extra_info.update({
|
||||||
|
'sub_task': params['task'],
|
||||||
|
'lang': params['lang'],
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
StatsWorker(
|
||||||
|
task=task,
|
||||||
|
model=model,
|
||||||
|
version=__version__,
|
||||||
|
extra_info=extra_info, ).start()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args(func, *args, **kwargs):
|
||||||
|
# FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
|
||||||
|
argspec = inspect.getfullargspec(func)
|
||||||
|
|
||||||
|
keys = argspec[0]
|
||||||
|
if keys[0] == 'self': # Remove self pointer.
|
||||||
|
keys = keys[1:]
|
||||||
|
|
||||||
|
default_values = argspec[3]
|
||||||
|
values = [None] * (len(keys) - len(default_values))
|
||||||
|
values.extend(list(default_values))
|
||||||
|
params = dict(zip(keys, values))
|
||||||
|
|
||||||
|
for idx, v in enumerate(args):
|
||||||
|
params[keys[idx]] = v
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
params[k] = v
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def stats_wrapper(executor_func):
|
||||||
|
def _warpper(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
_note_one_stat(
|
||||||
|
type(self).__name__, _parse_args(executor_func, *args,
|
||||||
|
**kwargs))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return executor_func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return _warpper
|
Loading…
Reference in new issue