From 3701fba0be50a680a8ff3d4eb4fd46209fd21905 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Thu, 9 Dec 2021 20:03:56 +0800 Subject: [PATCH] Update download logic and fix README typos. --- demos/audio_tagging/README.md | 2 +- demos/speech_recognition/README.md | 2 +- demos/speech_translation/README.md | 2 +- paddlespeech/cli/asr/infer.py | 2 +- paddlespeech/cli/cls/infer.py | 6 +- paddlespeech/cli/download.py | 57 ++---------------- paddlespeech/cli/log.py | 60 +++++++++++++++++++ paddlespeech/cli/st/infer.py | 2 +- paddlespeech/cli/tts/infer.py | 4 +- paddlespeech/cli/utils.py | 95 ++---------------------------- 10 files changed, 81 insertions(+), 151 deletions(-) create mode 100644 paddlespeech/cli/log.py diff --git a/demos/audio_tagging/README.md b/demos/audio_tagging/README.md index 5073393d4..1144cbb1f 100644 --- a/demos/audio_tagging/README.md +++ b/demos/audio_tagging/README.md @@ -22,7 +22,7 @@ wget https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespeech ### 3. Usage - Command Line(Recommended) ```bash - paddlespeech cls --input ~/cat.wav --topk 10 + paddlespeech cls --input ./cat.wav --topk 10 ``` Usage: ```bash diff --git a/demos/speech_recognition/README.md b/demos/speech_recognition/README.md index 60ee8e4d4..c91165315 100644 --- a/demos/speech_recognition/README.md +++ b/demos/speech_recognition/README.md @@ -22,7 +22,7 @@ wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech. ### 3. Usage - Command Line(Recommended) ```bash - paddlespeech asr --input ~/zh.wav + paddlespeech asr --input ./zh.wav ``` Usage: ```bash diff --git a/demos/speech_translation/README.md b/demos/speech_translation/README.md index b2f29168a..8bb322c52 100644 --- a/demos/speech_translation/README.md +++ b/demos/speech_translation/README.md @@ -22,7 +22,7 @@ wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech. ### 3. Usage - Command Line(Recommended) ```bash - paddlespeech st --input ~/en.wav + paddlespeech st --input ./en.wav ``` Usage: ```bash diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 2db239c0d..00f212932 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -27,9 +27,9 @@ import yaml from yacs.config import CfgNode from ..executor import BaseExecutor +from ..log import logger from ..utils import cli_register from ..utils import download_and_decompress -from ..utils import logger from ..utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.transform.transformation import Transformation diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index 795d59f68..37f2a9d29 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -20,14 +20,14 @@ from typing import Union import numpy as np import paddle import yaml -from paddleaudio import load -from paddleaudio.features import LogMelSpectrogram from ..executor import BaseExecutor +from ..log import logger from ..utils import cli_register from ..utils import download_and_decompress -from ..utils import logger from ..utils import MODEL_HOME +from paddleaudio import load +from paddleaudio.features import LogMelSpectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import __all__ = ['CLSExecutor'] diff --git a/paddlespeech/cli/download.py b/paddlespeech/cli/download.py index 8de9f0459..0f09b6fad 100644 --- a/paddlespeech/cli/download.py +++ b/paddlespeech/cli/download.py @@ -20,49 +20,21 @@ import os import os.path as osp import shutil import subprocess -import sys import tarfile import time import zipfile import requests +from tqdm import tqdm -try: - from tqdm import tqdm -except: +from .log import logger - class tqdm(object): - def __init__(self, total=None): - self.total = total - self.n = 0 - - def update(self, n): - self.n += n - if self.total is None: - sys.stderr.write("\r{0:.1f} bytes".format(self.n)) - else: - sys.stderr.write( - "\r{0:.1f}%".format(100 * self.n / float(self.total))) - sys.stderr.flush() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.stderr.write('\n') - - -import logging -logger = logging.getLogger(__name__) - -__all__ = ['get_weights_path_from_url'] - -WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights") +__all__ = ['get_path_from_url'] DOWNLOAD_RETRY_LIMIT = 3 -def is_url(path): +def _is_url(path): """ Whether path is URL. Args: @@ -71,25 +43,6 @@ def is_url(path): return path.startswith('http://') or path.startswith('https://') -def get_weights_path_from_url(url, md5sum=None): - """Get weights path from WEIGHT_HOME, if not exists, - download it from url. - Args: - url (str): download url - md5sum (str): md5 sum of download package - - Returns: - str: a local path to save downloaded weights. - Examples: - .. code-block:: python - from paddle.utils.download import get_weights_path_from_url - resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams' - local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url) - """ - path = get_path_from_url(url, WEIGHTS_HOME, md5sum) - return path - - def _map_path(url, root_dir): # parse path after download under root_dir fname = osp.split(url)[-1] @@ -135,7 +88,7 @@ def get_path_from_url(url, from paddle.fluid.dygraph.parallel import ParallelEnv - assert is_url(url), "downloading from {} not a url".format(url) + assert _is_url(url), "downloading from {} not a url".format(url) # parse path after download to decompress under root_dir fullpath = _map_path(url, root_dir) # Mainly used to solve the problem of downloading data from different diff --git a/paddlespeech/cli/log.py b/paddlespeech/cli/log.py new file mode 100644 index 000000000..891b71a94 --- /dev/null +++ b/paddlespeech/cli/log.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import logging + +__all__ = [ + 'logger', +] + + +class Logger(object): + def __init__(self, name: str=None): + name = 'PaddleSpeech' if not name else name + self.logger = logging.getLogger(name) + + log_config = { + 'DEBUG': 10, + 'INFO': 20, + 'TRAIN': 21, + 'EVAL': 22, + 'WARNING': 30, + 'ERROR': 40, + 'CRITICAL': 50, + 'EXCEPTION': 100, + } + for key, level in log_config.items(): + logging.addLevelName(level, key) + if key == 'EXCEPTION': + self.__dict__[key.lower()] = self.logger.exception + else: + self.__dict__[key.lower()] = functools.partial(self.__call__, + level) + + self.format = logging.Formatter( + fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s' + ) + + self.handler = logging.StreamHandler() + self.handler.setFormatter(self.format) + + self.logger.addHandler(self.handler) + self.logger.setLevel(logging.DEBUG) + self.logger.propagate = False + + def __call__(self, log_level: str, msg: str): + self.logger.log(log_level, msg) + + +logger = Logger() diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index 32f9d425a..6bb828210 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -26,9 +26,9 @@ from kaldiio import WriteHelper from yacs.config import CfgNode from ..executor import BaseExecutor +from ..log import logger from ..utils import cli_register from ..utils import download_and_decompress -from ..utils import logger from ..utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.utils.dynamic_import import dynamic_import diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 771b7d6dc..d5eac2b24 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -25,9 +25,9 @@ import yaml from yacs.config import CfgNode from ..executor import BaseExecutor +from ..log import logger from ..utils import cli_register from ..utils import download_and_decompress -from ..utils import logger from ..utils import MODEL_HOME from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.frontend import English @@ -535,7 +535,7 @@ class TTSExecutor(BaseExecutor): wav = self.voc_inference(mel) self._outputs['wav'] = wav - def postprocess(self, output: str='output.wav'): + def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]: """ Output postprocess and return results. This method get model output from self._outputs and convert it into human-readable results. diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index 6ae6e7e52..8ba780a71 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -11,15 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import hashlib -import logging import os import tarfile import zipfile from typing import Any from typing import Dict -from typing import List from paddle.framework import load @@ -31,7 +27,6 @@ __all__ = [ 'get_command', 'download_and_decompress', 'load_state_dict_from_url', - 'logger', ] @@ -59,23 +54,6 @@ def get_command(name: str) -> Any: return com['_entry'] -def _md5check(filepath: os.PathLike, md5sum: str) -> bool: - logger.info("File {} md5 checking...".format(filepath)) - md5 = hashlib.md5() - with open(filepath, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5.update(chunk) - calc_md5sum = md5.hexdigest() - - if calc_md5sum != md5sum: - logger.info("File {} md5 check failed, {}(calc) != " - "{}(base)".format(filepath, calc_md5sum, md5sum)) - return False - else: - logger.info("File {} md5 check passed.".format(filepath)) - return True - - def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: file_dir = os.path.dirname(filepath) if tarfile.is_tarfile(filepath): @@ -86,11 +64,12 @@ def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: file_list = files.namelist() else: return file_dir - if _is_a_single_file(file_list): + + if download._is_a_single_file(file_list): rootpath = file_list[0] uncompressed_path = os.path.join(file_dir, rootpath) - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0] + elif download._is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] uncompressed_path = os.path.join(file_dir, rootpath) else: rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] @@ -100,28 +79,6 @@ def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: return uncompressed_path -def _is_a_single_file(file_list: List[os.PathLike]) -> bool: - if len(file_list) == 1 and file_list[0].find(os.sep) < -1: - return True - return False - - -def _is_a_single_dir(file_list: List[os.PathLike]) -> bool: - new_file_list = [] - for file_path in file_list: - if '/' in file_path: - file_path = file_path.replace('/', os.sep) - elif '\\' in file_path: - file_path = file_path.replace('\\', os.sep) - new_file_list.append(file_path) - - file_name = new_file_list[0].split(os.sep)[0] - for i in range(1, len(new_file_list)): - if file_name != new_file_list[i].split(os.sep)[0]: - return False - return True - - def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: """ Download archieves and decompress to specific path. @@ -133,7 +90,8 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys())) filepath = os.path.join(path, os.path.basename(archive['url'])) - if os.path.isfile(filepath) and _md5check(filepath, archive['md5']): + if os.path.isfile(filepath) and download._md5check(filepath, + archive['md5']): uncompress_path = _get_uncompress_path(filepath) if not os.path.isdir(uncompress_path): download._decompress(filepath) @@ -183,44 +141,3 @@ def _get_sub_home(directory): PPSPEECH_HOME = _get_paddlespcceh_home() MODEL_HOME = _get_sub_home('models') - - -class Logger(object): - def __init__(self, name: str=None): - name = 'PaddleSpeech' if not name else name - self.logger = logging.getLogger(name) - - log_config = { - 'DEBUG': 10, - 'INFO': 20, - 'TRAIN': 21, - 'EVAL': 22, - 'WARNING': 30, - 'ERROR': 40, - 'CRITICAL': 50, - 'EXCEPTION': 100, - } - for key, level in log_config.items(): - logging.addLevelName(level, key) - if key == 'EXCEPTION': - self.__dict__[key.lower()] = self.logger.exception - else: - self.__dict__[key.lower()] = functools.partial(self.__call__, - level) - - self.format = logging.Formatter( - fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s' - ) - - self.handler = logging.StreamHandler() - self.handler.setFormatter(self.format) - - self.logger.addHandler(self.handler) - self.logger.setLevel(logging.DEBUG) - self.logger.propagate = False - - def __call__(self, log_level: str, msg: str): - self.logger.log(log_level, msg) - - -logger = Logger()