parent
35e3be9ac8
commit
da3ea7bb40
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List
|
||||
|
||||
from .entry import client_commands
|
||||
from .entry import server_commands
|
||||
from .util import cli_client_register
|
||||
from .util import cli_server_register
|
||||
from .util import get_client_command
|
||||
from .util import get_server_command
|
||||
|
||||
__all__ = [
|
||||
'ServerBaseCommand',
|
||||
'ServerHelpCommand',
|
||||
'ClientBaseCommand',
|
||||
'ClientHelpCommand',
|
||||
]
|
||||
|
||||
|
||||
@cli_server_register(name='paddlespeech_server')
class ServerBaseCommand:
    """Root `paddlespeech_server` command; it simply delegates to the help command."""

    def execute(self, argv: List[str]) -> bool:
        """Instantiate the registered help command and run it with ``argv``."""
        help_command_cls = get_server_command('paddlespeech_server.help')
        help_command = help_command_cls()
        return help_command.execute(argv)
|
||||
|
||||
|
||||
@cli_server_register(
    name='paddlespeech_server.help', description='Show help for commands.')
class ServerHelpCommand:
    """Print the usage text and the described `paddlespeech_server` sub-commands."""

    def execute(self, argv: List[str]) -> bool:
        """Print the help message to stdout and return True."""
        sections = [
            'Usage:\n',
            ' paddlespeech_server <command> <options>\n\n',
            'Commands:\n',
        ]
        registry = server_commands['paddlespeech_server']
        for name, info in registry.items():
            # Keys with a leading underscore are registry metadata, and
            # commands registered without a description are not listed.
            if name.startswith('_') or '_description' not in info:
                continue
            sections.append(' {:<15} {}\n'.format(name, info['_description']))

        print(''.join(sections))
        return True
|
||||
|
||||
|
||||
@cli_client_register(name='paddlespeech_client')
class ClientBaseCommand:
    """Root `paddlespeech_client` command; it simply delegates to the help command."""

    def execute(self, argv: List[str]) -> bool:
        """Instantiate the registered help command and run it with ``argv``."""
        help_command_cls = get_client_command('paddlespeech_client.help')
        help_command = help_command_cls()
        return help_command.execute(argv)
|
||||
|
||||
|
||||
@cli_client_register(
    name='paddlespeech_client.help', description='Show help for commands.')
class ClientHelpCommand:
    """Print the usage text and the described `paddlespeech_client` sub-commands."""

    def execute(self, argv: List[str]) -> bool:
        """Print the help message to stdout and return True."""
        sections = [
            'Usage:\n',
            ' paddlespeech_client <command> <options>\n\n',
            'Commands:\n',
        ]
        registry = client_commands['paddlespeech_client']
        for name, info in registry.items():
            # Keys with a leading underscore are registry metadata, and
            # commands registered without a description are not listed.
            if name.startswith('_') or '_description' not in info:
                continue
            sections.append(' {:<15} {}\n'.format(name, info['_description']))

        print(''.join(sections))
        return True
|
@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .paddlespeech_client import ASRClientExecutor
|
||||
from .paddlespeech_client import TTSClientExecutor
|
||||
from .paddlespeech_server import ServerExecutor
|
@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import soundfile
|
||||
|
||||
from ..util import cli_client_register
|
||||
from paddlespeech.server.utils.audio_process import wav2pcm
|
||||
from paddlespeech.server.utils.util import wav2base64
|
||||
|
||||
__all__ = ['TTSClientExecutor', 'ASRClientExecutor']
|
||||
|
||||
|
||||
@cli_client_register(
    name='paddlespeech_client.tts', description='visit tts service')
class TTSClientExecutor():
    """Command-line client that posts a synthesis request to the TTS service."""

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--text',
            type=str,
            default="你好,欢迎使用语音合成服务",
            help='A sentence to be synthesized')
        self.parser.add_argument(
            '--spk_id', type=int, default=0, help='Speaker id')
        self.parser.add_argument(
            '--speed', type=float, default=1.0, help='Audio speed')
        self.parser.add_argument(
            '--volume', type=float, default=1.0, help='Audio volume')
        self.parser.add_argument(
            '--sample_rate',
            type=int,
            default=0,
            help='Sampling rate, the default is the same as the model')
        self.parser.add_argument(
            '--output',
            type=str,
            default="./out.wav",
            help='Synthesized audio file')

    def tts_client(self, args):
        """Send one synthesis request and save the returned audio.

        Args:
            args: Parsed arguments providing ``server_ip``, ``port``, ``text``,
                ``spk_id``, ``speed``, ``volume``, ``sample_rate`` and
                ``output``.

        Returns:
            tuple: ``(number_of_samples, sample_rate)`` of the decoded audio.
        """
        url = 'http://' + args.server_ip + ":" + str(
            args.port) + '/paddlespeech/tts'
        request = {
            "text": args.text,
            "spk_id": args.spk_id,
            "speed": args.speed,
            "volume": args.volume,
            "sample_rate": args.sample_rate,
            "save_path": args.output
        }

        response = requests.post(url, json.dumps(request))
        response_dict = response.json()
        print(response_dict["message"])
        wav_base64 = response_dict["result"]["audio"]

        # The audio arrives base64-encoded; decode it and read it from memory.
        audio_data_byte = base64.b64decode(wav_base64)
        samples, sample_rate = soundfile.read(
            io.BytesIO(audio_data_byte), dtype='float32')

        outfile = args.output
        if outfile.endswith(".wav"):
            soundfile.write(outfile, samples, sample_rate)
        elif outfile.endswith(".pcm"):
            # wav2pcm works on files, so go through a temporary wav file and
            # always remove it afterwards, even when the conversion fails.
            temp_wav = str(random.getrandbits(128)) + ".wav"
            try:
                soundfile.write(temp_wav, samples, sample_rate)
                wav2pcm(temp_wav, outfile, data_type=np.int16)
            finally:
                if os.path.exists(temp_wav):
                    os.remove(temp_wav)
        else:
            print("The format for saving audio only supports wav or pcm")

        return len(samples), sample_rate

    def execute(self, argv: List[str]) -> bool:
        """Parse ``argv``, run the request and report the outcome.

        Returns:
            bool: True when the request succeeded, False otherwise. The CLI
            entry point converts this value into the process exit status.
        """
        args = self.parser.parse_args(argv)
        st = time.time()
        try:
            self.tts_client(args)
        except Exception:
            print("Failed to synthesized audio.")
            return False

        time_consume = time.time() - st
        print("Save synthesized audio successfully on %s." % (args.output))
        print("Inference time: %f s." % (time_consume))
        return True
|
||||
|
||||
|
||||
@cli_client_register(
    name='paddlespeech_client.asr', description='visit asr service')
class ASRClientExecutor():
    """Command-line client that posts an audio file to the ASR service."""

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--audio_file',
            type=str,
            default="./paddlespeech/server/tests/16_audio.wav",
            help='Audio file to be recognized')
        self.parser.add_argument(
            '--sample_rate', type=int, default=16000, help='audio sample rate')

    def execute(self, argv: List[str]) -> bool:
        """Parse ``argv``, send the recognition request and print the reply.

        Returns:
            bool: True when the request succeeded, False otherwise. The CLI
            entry point converts this value into the process exit status.
        """
        args = self.parser.parse_args(argv)
        url = 'http://' + args.server_ip + ":" + str(
            args.port) + '/paddlespeech/asr'
        audio = wav2base64(args.audio_file)
        data = {
            "audio": audio,
            "audio_format": "wav",
            "sample_rate": args.sample_rate,
            "lang": "zh_cn",
        }
        time_start = time.time()
        try:
            r = requests.post(url=url, data=json.dumps(data))
            time_end = time.time()
            print('time cost', time_end - time_start, 's')
            # Show the server's JSON reply; without this the result of the
            # request is never visible to the user.
            print(r.json())
        except Exception:
            print("Failed to speech recognition.")
            return False
        return True
|
@ -0,0 +1,78 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
from typing import List
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from ..util import cli_server_register
|
||||
from paddlespeech.server.engine.engine_factory import EngineFactory
|
||||
from paddlespeech.server.restful.api import setup_router
|
||||
from paddlespeech.server.utils.config import get_config
|
||||
|
||||
__all__ = ['ServerExecutor']
|
||||
|
||||
# Module-level FastAPI application object.
app = FastAPI(
    title="PaddleSpeech Serving API",
    description="Api",
    version="0.0.1",
)
|
||||
|
||||
|
||||
@cli_server_register(
    name='paddlespeech_server.server', description='Start the service')
class ServerExecutor():
    """Command that initializes the configured engines and serves the API."""

    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            "--config_file",
            action="store",
            help="yaml file of the app",
            default="./conf/application.yaml")

        self.parser.add_argument(
            "--log_file",
            action="store",
            help="log file",
            default="./log/paddlespeech.log")

    def init(self, config) -> bool:
        """Register the API routes and initialize every configured engine.

        Args:
            config (CfgNode): config object; ``config.engine_backend`` is
                iterated for engine names and indexed for each engine's
                config file.

        Returns:
            bool: True when every engine initialized, False as soon as one
            engine fails to initialize.
        """
        # init api
        api_list = list(config.engine_backend)
        api_router = setup_router(api_list)
        app.include_router(api_router)

        # init engine
        engine_pool = []
        for engine in config.engine_backend:
            engine_pool.append(EngineFactory.get_engine(engine_name=engine))
            if not engine_pool[-1].init(
                    config_file=config.engine_backend[engine]):
                return False

        return True

    def execute(self, argv: List[str]) -> bool:
        """Parse ``argv``, initialize the system and run the server.

        Returns:
            bool: False when initialization failed; True once the server has
            been run. The CLI entry point converts this value into the
            process exit status.
        """
        args = self.parser.parse_args(argv)
        config = get_config(args.config_file)

        if not self.init(config):
            return False

        uvicorn.run(app, host=config.host, port=config.port, debug=True)
        return True
|
@ -1,7 +1,7 @@
|
||||
model: 'conformer_wenetspeech'
|
||||
lang: 'zh'
|
||||
sample_rate: 16000
|
||||
cfg_path:
|
||||
ckpt_path:
|
||||
cfg_path: # [optional]
|
||||
ckpt_path: # [optional]
|
||||
decode_method: 'attention_rescoring'
|
||||
force_yes: False
|
||||
|
@ -0,0 +1,25 @@
|
||||
# This is the parameter configuration file for ASR server.
|
||||
# These are the static models that support paddle inference.
|
||||
|
||||
##################################################################
|
||||
# ACOUSTIC MODEL SETTING #
|
||||
# am choices=['deepspeech2offline_aishell'] TODO
|
||||
##################################################################
|
||||
model_type: 'deepspeech2offline_aishell'
|
||||
am_model: # the pdmodel file of am static model [optional]
|
||||
am_params: # the pdiparams file of am static model [optional]
|
||||
lang: 'zh'
|
||||
sample_rate: 16000
|
||||
cfg_path:
|
||||
decode_method:
|
||||
force_yes:
|
||||
|
||||
am_predictor_conf:
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
switch_ir_optim: True
|
||||
|
||||
|
||||
##################################################################
|
||||
# OTHERS #
|
||||
##################################################################
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from paddlespeech.server.engine.engine_factory import EngineFactory
|
||||
|
||||
# global value
|
||||
# Engine instances keyed by engine name; filled in by init_engine_pool().
ENGINE_POOL = {}
|
||||
|
||||
|
||||
def get_engine_pool() -> dict:
    """Return the module-level engine pool dictionary."""
    return ENGINE_POOL
|
||||
|
||||
|
||||
def init_engine_pool(config) -> bool:
    """Create and initialize one engine per entry of ``config.engine_backend``.

    Each engine is stored in the module-level pool under its name before it is
    initialized. Returns False as soon as one engine fails to initialize,
    True otherwise.
    """
    backends = config.engine_backend
    for name in backends:
        engine = EngineFactory.get_engine(
            engine_name=name, engine_type=config.engine_type[name])
        ENGINE_POOL[name] = engine
        initialized = engine.init(config_file=backends[name])
        if not initialized:
            return False

    return True
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
__all__ = ['server_commands', 'client_commands']
|
||||
|
||||
|
||||
def _CommandDict():
|
||||
return defaultdict(_CommandDict)
|
||||
|
||||
|
||||
def _run_command(commands, root: str) -> int:
    """Walk ``commands`` along the command line and run the deepest match.

    Args:
        commands: Nested command registry to search.
        root: Name of the root command; it stands in for ``sys.argv[0]``.

    Returns:
        int: 0 when the command's ``execute`` returns a truthy value, else 1.
    """
    node = commands
    consumed = 0
    for token in [root] + sys.argv[1:]:
        # Underscore-prefixed keys ('_entry', '_description') hold registry
        # metadata, not sub-commands; descending into them would replace
        # `node` with a class or a string and crash below.
        if token.startswith('_') or token not in node:
            break
        consumed += 1
        node = node[token]

    # The method 'execute' of a command instance returns 'True' for a success
    # and 'False' for a failure. Convert that result into a shell exit
    # status: 0 for a success and 1 for a failure.
    return 0 if node['_entry']().execute(sys.argv[consumed:]) else 1


def server_execute():
    """Entry point of `paddlespeech_server`; returns the exit status."""
    return _run_command(server_commands, 'paddlespeech_server')


def client_execute():
    """Entry point of `paddlespeech_client`; returns the exit status."""
    return _run_command(client_commands, 'paddlespeech_client')
|
||||
|
||||
|
||||
# Root nodes of the server and client command trees.
server_commands, client_commands = _CommandDict(), _CommandDict()
|
@ -0,0 +1,364 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tarfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import zipfile
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import paddle
|
||||
import requests
|
||||
import yaml
|
||||
from paddle.framework import load
|
||||
|
||||
import paddleaudio
|
||||
from . import download
|
||||
from .. import __version__
|
||||
from .entry import client_commands
|
||||
from .entry import server_commands
|
||||
|
||||
requests.adapters.DEFAULT_RETRIES = 3
|
||||
|
||||
__all__ = [
|
||||
'cli_server_register',
|
||||
'get_server_command',
|
||||
'cli_client_register',
|
||||
'get_client_command',
|
||||
'download_and_decompress',
|
||||
'load_state_dict_from_url',
|
||||
'stats_wrapper',
|
||||
]
|
||||
|
||||
|
||||
def cli_server_register(name: str, description: str='') -> Any:
    """Return a decorator that registers a command in the server command tree.

    ``name`` is a dot-separated path; the decorated object is stored under the
    '_entry' key of the node at that path and, when ``description`` is
    non-empty, the description under '_description'. The decorated object is
    returned unchanged.
    """

    def _register(command):
        node = server_commands
        for part in name.split('.'):
            node = node[part]
        node['_entry'] = command
        if description:
            node['_description'] = description
        return command

    return _register
|
||||
|
||||
|
||||
def get_server_command(name: str) -> Any:
    """Return the '_entry' stored at the dot-separated path ``name`` in the server tree."""
    node = server_commands
    # Follow the path components, then read the entry key of the final node.
    for key in name.split('.') + ['_entry']:
        node = node[key]
    return node
|
||||
|
||||
|
||||
def cli_client_register(name: str, description: str='') -> Any:
    """Return a decorator that registers a command in the client command tree.

    ``name`` is a dot-separated path; the decorated object is stored under the
    '_entry' key of the node at that path and, when ``description`` is
    non-empty, the description under '_description'. The decorated object is
    returned unchanged.
    """

    def _register(command):
        node = client_commands
        for part in name.split('.'):
            node = node[part]
        node['_entry'] = command
        if description:
            node['_description'] = description
        return command

    return _register
|
||||
|
||||
|
||||
def get_client_command(name: str) -> Any:
    """Return the '_entry' stored at the dot-separated path ``name`` in the client tree."""
    node = client_commands
    # Follow the path components, then read the entry key of the final node.
    for key in name.split('.') + ['_entry']:
        node = node[key]
    return node
|
||||
|
||||
|
||||
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
|
||||
file_dir = os.path.dirname(filepath)
|
||||
is_zip_file = False
|
||||
if tarfile.is_tarfile(filepath):
|
||||
files = tarfile.open(filepath, "r:*")
|
||||
file_list = files.getnames()
|
||||
elif zipfile.is_zipfile(filepath):
|
||||
files = zipfile.ZipFile(filepath, 'r')
|
||||
file_list = files.namelist()
|
||||
is_zip_file = True
|
||||
else:
|
||||
return file_dir
|
||||
|
||||
if download._is_a_single_file(file_list):
|
||||
rootpath = file_list[0]
|
||||
uncompressed_path = os.path.join(file_dir, rootpath)
|
||||
elif download._is_a_single_dir(file_list):
|
||||
if is_zip_file:
|
||||
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
|
||||
else:
|
||||
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
|
||||
uncompressed_path = os.path.join(file_dir, rootpath)
|
||||
else:
|
||||
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
|
||||
uncompressed_path = os.path.join(file_dir, rootpath)
|
||||
|
||||
files.close()
|
||||
return uncompressed_path
|
||||
|
||||
|
||||
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
    """Download an archive and decompress it into ``path``.

    Args:
        archive: Dictionary with the keys 'url' and 'md5'.
        path: Directory the archive is stored in; created when missing.

    Returns:
        Path of the decompressed content.

    Raises:
        AssertionError: If ``archive`` lacks the 'url' or 'md5' key.
    """
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(path, exist_ok=True)

    assert 'url' in archive and 'md5' in archive, \
        'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))

    filepath = os.path.join(path, os.path.basename(archive['url']))
    if os.path.isfile(filepath) and download._md5check(filepath,
                                                       archive['md5']):
        # A verified copy already exists: decompress it only if needed.
        uncompress_path = _get_uncompress_path(filepath)
        if not os.path.isdir(uncompress_path):
            download._decompress(filepath)
    else:
        # Report the download on a background thread, then fetch the archive.
        StatsWorker(
            task='download',
            version=__version__,
            extra_info={
                'download_url': archive['url'],
                'paddle_version': paddle.__version__
            }).start()
        uncompress_path = download.get_path_from_url(archive['url'], path,
                                                     archive['md5'])

    return uncompress_path
|
||||
|
||||
|
||||
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> Any:
    """Download the file at ``url`` into ``path`` and load it.

    Args:
        url: Location of the file to download.
        path: Directory the file is stored in; created when missing.
        md5: Optional checksum handed to the downloader.

    Returns:
        The object returned by ``load`` for the downloaded file.
    """
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(path, exist_ok=True)

    download.get_path_from_url(url, path, md5)
    return load(os.path.join(path, os.path.basename(url)))
|
||||
|
||||
|
||||
def _get_user_home():
|
||||
return os.path.expanduser('~')
|
||||
|
||||
|
||||
def _get_paddlespcceh_home():
|
||||
if 'PPSPEECH_HOME' in os.environ:
|
||||
home_path = os.environ['PPSPEECH_HOME']
|
||||
if os.path.exists(home_path):
|
||||
if os.path.isdir(home_path):
|
||||
return home_path
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'The environment variable PPSPEECH_HOME {} is not a directory.'.
|
||||
format(home_path))
|
||||
else:
|
||||
return home_path
|
||||
return os.path.join(_get_user_home(), '.paddlespeech')
|
||||
|
||||
|
||||
def _get_sub_home(directory):
    """Return ``directory`` inside the PaddleSpeech home, creating it when absent."""
    sub_home = os.path.join(_get_paddlespcceh_home(), directory)
    if os.path.exists(sub_home):
        return sub_home
    os.makedirs(sub_home)
    return sub_home
|
||||
|
||||
|
||||
# Resolved once at import time; the two sub-directories are created if missing.
PPSPEECH_HOME = _get_paddlespcceh_home()
MODEL_HOME = _get_sub_home('models')
CONF_HOME = _get_sub_home('conf')
|
||||
|
||||
|
||||
def _md5(text: str):
|
||||
'''Calculate the md5 value of the input text.'''
|
||||
md5code = hashlib.md5(text.encode())
|
||||
return md5code.hexdigest()
|
||||
|
||||
|
||||
class ConfigCache:
    """Small persistent configuration stored as 'cache.yaml' under CONF_HOME."""

    def __init__(self):
        """Load the cache file, creating or rewriting it when it is missing or unusable."""
        self._data = {}
        self._initialize()
        self.file = os.path.join(CONF_HOME, 'cache.yaml')
        if not os.path.exists(self.file):
            self.flush()
            return

        with open(self.file, 'r') as file:
            try:
                # NOTE(review): yaml.load with FullLoader is kept as-is; the
                # file is written by flush() below. Consider yaml.safe_load
                # if this file could ever come from an untrusted source.
                cfg = yaml.load(file, Loader=yaml.FullLoader)
            except Exception:
                cfg = None

        # Rewrite the file only after it has been closed. An empty or
        # malformed file yields a non-dict value and is replaced by defaults.
        if isinstance(cfg, dict):
            self._data.update(cfg)
        else:
            self.flush()

    @property
    def cache_info(self):
        """The 'cache_info' value of the current configuration."""
        return self._data['cache_info']

    def _initialize(self):
        """Set default configuration values."""
        # Digest of the last 12 characters of a uuid1 string, a dash, and the
        # current Unix time in whole seconds.
        cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
        self._data['cache_info'] = cache_info

    def flush(self):
        """Flush the current configuration into the configuration file."""
        with open(self.file, 'w') as file:
            # Round-trip through JSON so only plain JSON types are dumped.
            cfg = json.loads(json.dumps(self._data))
            yaml.dump(cfg, file)
|
||||
|
||||
|
||||
# URL that StatsWorker sends its GET requests to.
stats_api = "http://paddlepaddle.org.cn/paddlehub/stat"
# Identifier taken from the local configuration cache at import time.
cache_info = ConfigCache().cache_info
|
||||
|
||||
|
||||
class StatsWorker(threading.Thread):
    """Thread that sends one statistics record to ``stats_api``."""

    def __init__(self,
                 task="asr",
                 model=None,
                 version=__version__,
                 extra_info=None):
        """Store the fields to report.

        Args:
            task: Task name sent as 'task'.
            model: Optional model name; sent only when truthy.
            version: Version string sent as 'version'.
            extra_info: Optional dict of extra fields; it is copied, so
                neither the caller's dict nor a shared default is modified.
        """
        threading.Thread.__init__(self)
        self._task = task
        self._model = model
        self._version = version
        self._extra_info = dict(extra_info) if extra_info else {}

    def run(self):
        """Build the query parameters and issue the request."""
        params = {
            'task': self._task,
            'version': self._version,
            'from': 'ppspeech'
        }
        if self._model:
            params['model'] = self._model

        self._extra_info.update({
            'cache_info': cache_info,
        })
        params.update({"extra": json.dumps(self._extra_info)})

        try:
            requests.get(stats_api, params)
        except Exception:
            # Reporting is best-effort: a failed request is ignored on purpose.
            pass

        return
|
||||
|
||||
|
||||
def _note_one_stat(cls_name, params=None):
    """Start a StatsWorker describing one executor call.

    Args:
        cls_name: Executor class name; the task is that name with 'Executor'
            removed, lower-cased (e.g. 'ASRExecutor' -> 'asr').
        params: Mapping of the call's argument names to values. Tasks other
            than asr, st, tts, cls and text are not reported.
    """
    params = {} if params is None else params
    task = cls_name.replace('Executor', '').lower()  # XXExecutor
    extra_info = {
        'paddle_version': paddle.__version__,
    }

    model = params.get('model')

    # -1 marks an unknown input sample rate: either there is no audio file
    # or it could not be loaded.
    sr = -1
    if 'audio_file' in params:
        try:
            _, sr = paddleaudio.load(params['audio_file'])
        except Exception:
            sr = -1

    if task == 'asr':
        extra_info.update({
            'lang': params['lang'],
            'inp_sr': sr,
            'model_sr': params['sample_rate'],
        })
    elif task == 'st':
        extra_info.update({
            'lang': params['src_lang'] + '-' + params['tgt_lang'],
            'inp_sr': sr,
            'model_sr': params['sample_rate'],
        })
    elif task == 'tts':
        # For TTS the acoustic model name is reported as the model.
        model = params['am']
        extra_info.update({
            'lang': params['lang'],
            'vocoder': params['voc'],
        })
    elif task == 'cls':
        extra_info.update({
            'inp_sr': sr,
        })
    elif task == 'text':
        extra_info.update({
            'sub_task': params['task'],
            'lang': params['lang'],
        })
    else:
        return

    StatsWorker(
        task=task,
        model=model,
        version=__version__,
        extra_info=extra_info, ).start()
|
||||
|
||||
|
||||
def _parse_args(func, *args, **kwargs):
|
||||
# FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
|
||||
argspec = inspect.getfullargspec(func)
|
||||
|
||||
keys = argspec[0]
|
||||
if keys[0] == 'self': # Remove self pointer.
|
||||
keys = keys[1:]
|
||||
|
||||
default_values = argspec[3]
|
||||
values = [None] * (len(keys) - len(default_values))
|
||||
values.extend(list(default_values))
|
||||
params = dict(zip(keys, values))
|
||||
|
||||
for idx, v in enumerate(args):
|
||||
params[keys[idx]] = v
|
||||
for k, v in kwargs.items():
|
||||
params[k] = v
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def stats_wrapper(executor_func):
    """Decorate an executor method so each call is reported before it runs.

    Reporting is best-effort: any exception raised while collecting or
    sending the statistics is ignored and the wrapped method still runs.
    The wrapper keeps the wrapped method's name and docstring.
    """
    import functools

    @functools.wraps(executor_func)
    def _wrapper(self, *args, **kwargs):
        try:
            _note_one_stat(
                type(self).__name__, _parse_args(executor_func, *args,
                                                 **kwargs))
        except Exception:
            # Statistics must never break the actual call.
            pass
        return executor_func(self, *args, **kwargs)

    return _wrapper
|
Loading…
Reference in new issue