You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/paddlespeech/cli/utils.py

331 lines
8.9 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import hashlib
import inspect
import json
import os
import tarfile
import threading
import time
import uuid
import zipfile
from typing import Any
from typing import Dict

import paddle
import requests
import soundfile as sf
import yaml
from paddle.framework import load

from . import download
from ..utils.env import CONF_HOME
from .entry import commands
try:
    from .. import __version__
except ImportError:
    # for develop branch: presumably the parent package exposes no
    # __version__ when run from a source checkout, so use a placeholder.
    __version__ = "0.0.0"
# NOTE(review): requests' HTTPAdapter appears to bind its ``max_retries``
# default when the class is defined, so reassigning this module attribute
# after import may not change retry behaviour of requests.get() — verify
# against the installed requests version before relying on it.
requests.adapters.DEFAULT_RETRIES = 3
# Names exported by ``from <this module> import *``.
__all__ = [
    'timer_register',
    'cli_register',
    'explicit_command_register',
    'get_command',
    'download_and_decompress',
    'load_state_dict_from_url',
    'stats_wrapper',
]
# Timing records keyed by command name; entries are created by
# timer_register below.
CLI_TIMER = {}


def timer_register(command):
    """Decorator that opens an empty timing record for ``command``.

    A new record holding fresh, empty ``'start'``, ``'end'`` and ``'extra'``
    lists is stored in ``CLI_TIMER`` under ``command.__name__`` (replacing
    any earlier record with that name). ``command`` is returned unchanged,
    so the decorator can be stacked with others.
    """
    CLI_TIMER[command.__name__] = {
        field: [] for field in ('start', 'end', 'extra')
    }
    return command
def cli_register(name: str, description: str='') -> Any:
    """Build a decorator that registers a command under a dotted ``name``.

    ``name`` is split on ``'.'`` and each part is used as a key to descend
    into the nested ``commands`` mapping. The decorated object is stored at
    the final node under ``'_entry'``; a non-empty ``description`` is stored
    under ``'_description'``. The decorated object is returned unchanged.
    """

    def _register(command):
        # NOTE(review): relies on ``commands`` creating missing nodes on
        # lookup (presumably a defaultdict-style tree from .entry — confirm).
        node = commands
        for part in name.split('.'):
            node = node[part]
        node['_entry'] = command
        if description:
            node['_description'] = description
        return command

    return _register
def explicit_command_register(name: str, description: str='', cls: str=''):
    """Register ``cls`` as the entry of the dotted command ``name``.

    Unlike ``cli_register`` this is a plain call, not a decorator. The
    ``commands`` tree is walked one ``'.'``-separated part at a time, and
    ``cls`` is stored at the final node under ``'_entry'``. A non-empty
    ``description`` is stored under ``'_description'``.
    """
    target = commands
    for key in name.split('.'):
        target = target[key]
    target['_entry'] = cls
    if not description:
        return
    target['_description'] = description
def get_command(name: str) -> Any:
    """Return the object registered under the dotted command ``name``.

    The ``commands`` tree is walked one ``'.'``-separated part at a time, and
    the ``'_entry'`` value of the final node is returned.
    """
    path = name.split('.')
    target = commands
    for key in path:
        target = target[key]
    entry = target['_entry']
    return entry
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
    """Predict the path that ``filepath``'s contents are extracted to.

    ``download_and_decompress`` uses the result to decide whether an archive
    that is already on disk still needs to be decompressed. The rules are:

    * not a tar/zip archive -> the directory containing ``filepath``;
    * archive with a single file -> that file inside the same directory;
    * archive with a single top-level directory -> that directory (for zip
      the first path component of the first entry, for tar the last
      component, in both cases with the last extension stripped);
    * anything else -> ``filepath``'s basename with its last extension
      stripped, inside the same directory.

    NOTE(review): these rules presumably mirror ``download._decompress``;
    keep the two in sync, otherwise extraction is redone or skipped wrongly.

    Args:
        filepath: Path to a (possibly) compressed file.

    Returns:
        The predicted extraction path.
    """
    file_dir = os.path.dirname(filepath)
    # Read only the member names and close the archive right away, so the
    # handle cannot leak if one of the download helpers below raises.
    if tarfile.is_tarfile(filepath):
        is_zip_file = False
        with tarfile.open(filepath, "r:*") as archive:
            file_list = archive.getnames()
    elif zipfile.is_zipfile(filepath):
        is_zip_file = True
        with zipfile.ZipFile(filepath, 'r') as archive:
            file_list = archive.namelist()
    else:
        return file_dir

    if download._is_a_single_file(file_list):
        rootpath = file_list[0]
    elif download._is_a_single_dir(file_list):
        parts = os.path.splitext(file_list[0])[0].split(os.sep)
        # The zip/tar asymmetry is intentional (see NOTE above).
        rootpath = parts[0] if is_zip_file else parts[-1]
    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
    return os.path.join(file_dir, rootpath)
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
    """
    Download an archive into ``path`` and decompress it there.

    If a file with the URL's basename already exists in ``path`` and its MD5
    matches ``archive['md5']``, nothing is downloaded. In that case the
    archive is only decompressed when its predicted extraction path (see
    ``_get_uncompress_path``) is missing. Otherwise a usage-stats record is
    started in a background ``StatsWorker``, and the archive is fetched with
    ``download.get_path_from_url``, passing the expected MD5.

    Args:
        archive: Mapping with at least the keys ``'url'`` and ``'md5'``.
        path: Directory to download into; created if it does not exist.

    Returns:
        The predicted extraction path when the archive was already present,
        otherwise what ``download.get_path_from_url`` returns (presumably
        the extracted path).

    Raises:
        AssertionError: If ``archive`` lacks ``'url'`` or ``'md5'`` (only
            while assertions are enabled, i.e. not under ``python -O``).
    """
    # Validate before touching the filesystem, so bad input leaves no
    # directories behind.
    assert 'url' in archive and 'md5' in archive, \
        'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
    # exist_ok avoids a race when another process creates the directory
    # between a separate existence check and the creation.
    os.makedirs(path, exist_ok=True)

    filepath = os.path.join(path, os.path.basename(archive['url']))
    if os.path.isfile(filepath) and download._md5check(filepath,
                                                       archive['md5']):
        # Already downloaded and verified: only extract when the predicted
        # extraction path does not exist yet.
        uncompress_path = _get_uncompress_path(filepath)
        if not os.path.isdir(uncompress_path):
            download._decompress(filepath)
    else:
        StatsWorker(
            task='download',
            version=__version__,
            extra_info={
                'download_url': archive['url'],
                'paddle_version': paddle.__version__
            }).start()
        uncompress_path = download.get_path_from_url(archive['url'], path,
                                                     archive['md5'])
    return uncompress_path
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> Any:
    """
    Download a state dict from ``url`` into ``path`` and load it.

    Args:
        url: Location of the serialized state dict; the local file is named
            after the URL's basename.
        path: Directory to download into; created if it does not exist.
        md5: Optional expected MD5 checksum, passed to the downloader.

    Returns:
        The object produced by ``paddle.framework.load`` for the downloaded
        file. This is not a path; the previous ``os.PathLike`` annotation
        was wrong.
    """
    # exist_ok avoids a race between an existence check and the creation.
    os.makedirs(path, exist_ok=True)
    download.get_path_from_url(url, path, md5)
    return load(os.path.join(path, os.path.basename(url)))
def _md5(text: str):
    """Return the hexadecimal MD5 digest of ``text`` encoded as UTF-8."""
    hasher = hashlib.md5()
    hasher.update(text.encode())
    return hasher.hexdigest()
class ConfigCache:
    """Small YAML-backed store holding the anonymous ``cache_info`` id.

    The store lives at ``os.path.join(CONF_HOME, 'cache.yaml')``. Each
    construction generates a fresh id. If the file already holds a mapping,
    its values override the fresh ones, so an id written by an earlier run
    is reused. A missing or unusable file is (re)written with the defaults.
    """

    def __init__(self):
        self._data = {}
        self._initialize()
        self.file = os.path.join(CONF_HOME, 'cache.yaml')
        if not os.path.exists(self.file):
            self.flush()
            return

        with open(self.file, 'r') as file:
            try:
                # flush() only ever writes plain JSON-compatible data, so the
                # safe loader is sufficient and avoids constructing objects.
                cfg = yaml.safe_load(file)
            except Exception:
                # Best effort: an unreadable cache is regenerated below.
                cfg = None

        # Any rewrite happens after the read handle is closed, so the file is
        # never open for reading and writing at the same time.
        if not isinstance(cfg, dict):
            # Empty, corrupted or non-mapping content: reset to defaults.
            self.flush()
            return
        self._data.update(cfg)
        if 'cache_info' not in cfg:
            # Persist the freshly generated id so later runs reuse it.
            self.flush()

    @property
    def cache_info(self):
        """The value stored under the ``'cache_info'`` key."""
        return self._data['cache_info']

    def _initialize(self):
        # Set default configuration values: md5 of the last 12 characters of
        # a uuid1 string (its node field), then "-" and the current Unix
        # time in whole seconds.
        cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
        self._data['cache_info'] = cache_info

    def flush(self):
        '''Flush the current configuration into the configuration file.'''
        with open(self.file, 'w') as file:
            # Round-trip through JSON so only plain JSON-compatible types are
            # dumped to YAML.
            cfg = json.loads(json.dumps(self._data))
            yaml.dump(cfg, file)


# Endpoint that receives the usage records sent by StatsWorker.
stats_api = "http://paddlepaddle.org.cn/paddlehub/stat"
# Anonymous id attached to every usage record. It is evaluated at import
# time, so importing this module reads (and may create) the cache file.
cache_info = ConfigCache().cache_info
class StatsWorker(threading.Thread):
    """Thread that sends one anonymous usage record to ``stats_api``.

    Args:
        task: Reported as the ``task`` query parameter.
        model: Reported as ``model``, but only when truthy.
        version: Reported as ``version``.
        extra_info: Optional extra fields. Together with ``cache_info`` they
            are JSON-encoded into the ``extra`` query parameter. The dict is
            copied, so neither a shared default nor the caller's dict is
            modified.
        timeout: Seconds ``requests`` waits to connect and between bytes of
            the response (not a total-time limit). The thread is not marked
            daemon, so without a bound a stalled stats server could keep the
            interpreter from exiting.
    """

    def __init__(self,
                 task="asr",
                 model=None,
                 version=__version__,
                 extra_info=None,
                 timeout=5.0):
        super().__init__()
        self._task = task
        self._model = model
        self._version = version
        self._extra_info = dict(extra_info) if extra_info else {}
        self._timeout = timeout

    def run(self):
        params = {
            'task': self._task,
            'version': self._version,
            'from': 'ppspeech'
        }
        if self._model:
            params['model'] = self._model
        self._extra_info.update({
            'cache_info': cache_info,
        })
        params.update({"extra": json.dumps(self._extra_info)})
        try:
            requests.get(stats_api, params=params, timeout=self._timeout)
        except Exception:
            # Stats are best effort: a failed report (network error, timeout,
            # ...) must never surface to the user.
            pass
def _note_one_stat(cls_name, params=None):
    """Start a background usage report for one executor call.

    Args:
        cls_name: Executor class name such as ``ASRExecutor``. The task name
            is derived by removing ``'Executor'`` and lower-casing.
        params: The call's arguments keyed by parameter name, as built by
            ``_parse_args``. ``None`` is treated as an empty mapping.

    Only the ``asr``, ``st``, ``tts``, ``cls`` and ``text`` tasks are
    reported. For any other task the function returns immediately.
    Missing task-specific keys in ``params`` raise ``KeyError``; the caller
    (``stats_wrapper``) swallows that.
    """
    task = cls_name.replace('Executor', '').lower()  # XXExecutor -> xx
    if task not in ('asr', 'st', 'tts', 'cls', 'text'):
        # Nothing is reported for this task, so skip all further work
        # (including reading the audio header).
        return
    params = {} if params is None else params

    extra_info = {
        'paddle_version': paddle.__version__,
    }
    model = params.get('model')

    # Input sample rate; -1 when no audio file is given or it is unreadable.
    # Initialised up front so the asr/st/cls branches never see an unbound
    # name.
    sr = -1
    if 'audio_file' in params:
        try:
            # Only the header is needed, so avoid decoding the whole file.
            sr = sf.info(params['audio_file']).samplerate
        except Exception:
            sr = -1

    if task == 'asr':
        extra_info.update({
            'lang': params['lang'],
            'inp_sr': sr,
            'model_sr': params['sample_rate'],
        })
    elif task == 'st':
        extra_info.update({
            'lang': params['src_lang'] + '-' + params['tgt_lang'],
            'inp_sr': sr,
            'model_sr': params['sample_rate'],
        })
    elif task == 'tts':
        # For TTS the acoustic model name is reported as the model.
        model = params['am']
        extra_info.update({
            'lang': params['lang'],
            'vocoder': params['voc'],
        })
    elif task == 'cls':
        extra_info.update({
            'inp_sr': sr,
        })
    else:  # task == 'text'
        extra_info.update({
            'sub_task': params['task'],
            'lang': params['lang'],
        })

    StatsWorker(
        task=task,
        model=model,
        version=__version__,
        extra_info=extra_info, ).start()
def _parse_args(func, *args, **kwargs):
    """Map the arguments of a call to ``func`` onto its parameter names.

    Every positional-or-keyword parameter of ``func`` (except a leading
    ``self``) first gets its default value, or ``None`` if it has none.
    Then the positional ``args`` are assigned in order and ``kwargs`` are
    copied over as-is.

    Args:
        func: The (unbound) function whose signature is inspected.
        *args: Positional arguments of the call, without ``self``.
            Arguments beyond the named parameters have no name and are
            ignored.
        **kwargs: Keyword arguments of the call.

    Returns:
        A dict mapping parameter names to their effective values.
    """
    # FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
    argspec = inspect.getfullargspec(func)
    keys = list(argspec.args)
    if keys and keys[0] == 'self':  # Remove self pointer.
        keys = keys[1:]
    # ``defaults`` is None (not an empty tuple) when no parameter has one.
    default_values = list(argspec.defaults or ())
    values = [None] * (len(keys) - len(default_values)) + default_values
    params = dict(zip(keys, values))
    # zip stops at the shorter sequence, so surplus positionals are skipped
    # instead of raising IndexError.
    params.update(zip(keys, args))
    params.update(kwargs)
    return params
def stats_wrapper(executor_func):
    """Decorate an executor method so each call first reports usage stats.

    Stats collection is best effort. Any exception raised while parsing the
    arguments or starting the report is swallowed, so it can never break
    the executor call itself. ``functools.wraps`` keeps the wrapped
    method's name, docstring and ``__wrapped__`` reference for
    introspection.
    """

    @functools.wraps(executor_func)
    def _wrapper(self, *args, **kwargs):
        try:
            _note_one_stat(
                type(self).__name__,
                _parse_args(executor_func, *args, **kwargs))
        except Exception:
            # Never let stats collection interfere with the actual call.
            pass
        return executor_func(self, *args, **kwargs)

    return _wrapper