diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py
index ead3fb053..eb023c11b 100644
--- a/paddlespeech/cli/utils.py
+++ b/paddlespeech/cli/utils.py
@@ -12,10 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import hashlib
 import logging
 import os
+import tarfile
+import zipfile
 from typing import Any
 from typing import Dict
+from typing import List
 
 from paddle.framework import load
 from paddle.utils import download
@@ -55,12 +59,81 @@ def get_command(name: str) -> Any:
     return com['_entry']
 
 
-def decompress(file: str) -> os.PathLike:
-    """
-    Extracts all files from a compressed file.
-    """
-    assert os.path.isfile(file), "File: {} not exists.".format(file)
-    return download._decompress(file)
+def _md5check(filepath: os.PathLike, md5sum: str) -> bool:
+    """Return True if the MD5 hex digest of `filepath` equals `md5sum`."""
+    logger.info("File {} md5 checking...".format(filepath))
+    md5 = hashlib.md5()
+    with open(filepath, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            md5.update(chunk)
+    calc_md5sum = md5.hexdigest()
+
+    if calc_md5sum != md5sum:
+        logger.info("File {} md5 check failed, {}(calc) != "
+                    "{}(base)".format(filepath, calc_md5sum, md5sum))
+        return False
+    else:
+        logger.info("File {} md5 check passed.".format(filepath))
+        return True
+
+
+def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
+    """
+    Derive the expected extraction path of an archive from its member names.
+
+    Returns the archive's own directory when `filepath` is neither a tar
+    nor a zip archive.
+    """
+    file_dir = os.path.dirname(filepath)
+
+    # Context managers close the archive even if listing its members fails.
+    if tarfile.is_tarfile(filepath):
+        with tarfile.open(filepath, "r:*") as files:
+            file_list = files.getnames()
+    elif zipfile.is_zipfile(filepath):
+        with zipfile.ZipFile(filepath, 'r') as files:
+            file_list = files.namelist()
+    else:
+        return file_dir
+
+    if _is_a_single_file(file_list):
+        rootpath = file_list[0]
+    elif _is_a_single_dir(file_list):
+        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
+    else:
+        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+
+    return os.path.join(file_dir, rootpath)
+
+
+def _is_a_single_file(file_list: List[os.PathLike]) -> bool:
+    """Return True if the archive holds exactly one top-level file."""
+    # str.find() returns -1 when the separator is absent, so the result
+    # must be compared with 0; it can never be smaller than -1.
+    if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
+        return True
+    return False
+
+
+def _is_a_single_dir(file_list: List[os.PathLike]) -> bool:
+    """Return True if all members share the same first path component."""
+    # An empty archive has no top-level directory at all.
+    if not file_list:
+        return False
+
+    new_file_list = []
+    for file_path in file_list:
+        if '/' in file_path:
+            file_path = file_path.replace('/', os.sep)
+        elif '\\' in file_path:
+            file_path = file_path.replace('\\', os.sep)
+        new_file_list.append(file_path)
+
+    file_name = new_file_list[0].split(os.sep)[0]
+    for i in range(1, len(new_file_list)):
+        if file_name != new_file_list[i].split(os.sep)[0]:
+            return False
+    return True
 
 
 def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
@@ -73,11 +146,16 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
     assert 'url' in archive and 'md5' in archive, \
         'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
 
-    if False:
-        # TODO: File match md5 and uncompressed_path exist, so skip downloading and decompressing...
-        pass
+    filepath = os.path.join(path, os.path.basename(archive['url']))
+    if os.path.isfile(filepath) and _md5check(filepath, archive['md5']):
+        uncompress_path = _get_uncompress_path(filepath)
+        if not os.path.isdir(uncompress_path):
+            download._decompress(filepath)
     else:
-        return download.get_path_from_url(archive['url'], path, archive['md5'])
+        uncompress_path = download.get_path_from_url(archive['url'], path,
+                                                     archive['md5'])
+
+    return uncompress_path
 
 
 def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
diff --git a/setup.py b/setup.py
index fbb3a24a4..b728dd469 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@ requirements = {
         "nara_wpe",
         "nltk",
         "pandas",
+        "paddleaudio",
         "paddlespeech_ctcdecoders",
         "paddlespeech_feat",
         "praatio~=4.1",