@ -11,15 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import hashlib
import logging
import os
import tarfile
import zipfile
from typing import Any
from typing import Dict
from typing import List
from paddle . framework import load
@ -31,7 +27,6 @@ __all__ = [
' get_command ' ,
' download_and_decompress ' ,
' load_state_dict_from_url ' ,
' logger ' ,
]
@ -59,38 +54,27 @@ def get_command(name: str) -> Any:
return com [ ' _entry ' ]
def _md5check ( filepath : os . PathLike , md5sum : str ) - > bool :
logger . info ( " File {} md5 checking... " . format ( filepath ) )
md5 = hashlib . md5 ( )
with open ( filepath , ' rb ' ) as f :
for chunk in iter ( lambda : f . read ( 4096 ) , b " " ) :
md5 . update ( chunk )
calc_md5sum = md5 . hexdigest ( )
if calc_md5sum != md5sum :
logger . info ( " File {} md5 check failed, {} (calc) != "
" {} (base) " . format ( filepath , calc_md5sum , md5sum ) )
return False
else :
logger . info ( " File {} md5 check passed. " . format ( filepath ) )
return True
def _get_uncompress_path ( filepath : os . PathLike ) - > os . PathLike :
file_dir = os . path . dirname ( filepath )
is_zip_file = False
if tarfile . is_tarfile ( filepath ) :
files = tarfile . open ( filepath , " r:* " )
file_list = files . getnames ( )
elif zipfile . is_zipfile ( filepath ) :
files = zipfile . ZipFile ( filepath , ' r ' )
file_list = files . namelist ( )
is_zip_file = True
else :
return file_dir
if _is_a_single_file ( file_list ) :
if download . _is_a_single_file ( file_list ) :
rootpath = file_list [ 0 ]
uncompressed_path = os . path . join ( file_dir , rootpath )
elif _is_a_single_dir ( file_list ) :
rootpath = os . path . splitext ( file_list [ 0 ] ) [ 0 ] . split ( os . sep ) [ 0 ]
elif download . _is_a_single_dir ( file_list ) :
if is_zip_file :
rootpath = os . path . splitext ( file_list [ 0 ] ) [ 0 ] . split ( os . sep ) [ 0 ]
else :
rootpath = os . path . splitext ( file_list [ 0 ] ) [ 0 ] . split ( os . sep ) [ - 1 ]
uncompressed_path = os . path . join ( file_dir , rootpath )
else :
rootpath = os . path . splitext ( filepath ) [ 0 ] . split ( os . sep ) [ - 1 ]
@ -100,28 +84,6 @@ def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
return uncompressed_path
def _is_a_single_file ( file_list : List [ os . PathLike ] ) - > bool :
if len ( file_list ) == 1 and file_list [ 0 ] . find ( os . sep ) < - 1 :
return True
return False
def _is_a_single_dir ( file_list : List [ os . PathLike ] ) - > bool :
new_file_list = [ ]
for file_path in file_list :
if ' / ' in file_path :
file_path = file_path . replace ( ' / ' , os . sep )
elif ' \\ ' in file_path :
file_path = file_path . replace ( ' \\ ' , os . sep )
new_file_list . append ( file_path )
file_name = new_file_list [ 0 ] . split ( os . sep ) [ 0 ]
for i in range ( 1 , len ( new_file_list ) ) :
if file_name != new_file_list [ i ] . split ( os . sep ) [ 0 ] :
return False
return True
def download_and_decompress ( archive : Dict [ str , str ] , path : str ) - > os . PathLike :
"""
Download archieves and decompress to specific path .
@ -133,7 +95,8 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
' Dictionary keys of " url " and " md5 " are required in the archive, but got: {} ' . format ( list ( archive . keys ( ) ) )
filepath = os . path . join ( path , os . path . basename ( archive [ ' url ' ] ) )
if os . path . isfile ( filepath ) and _md5check ( filepath , archive [ ' md5 ' ] ) :
if os . path . isfile ( filepath ) and download . _md5check ( filepath ,
archive [ ' md5 ' ] ) :
uncompress_path = _get_uncompress_path ( filepath )
if not os . path . isdir ( uncompress_path ) :
download . _decompress ( filepath )
@ -183,44 +146,3 @@ def _get_sub_home(directory):
PPSPEECH_HOME = _get_paddlespcceh_home ( )
MODEL_HOME = _get_sub_home ( ' models ' )
class Logger ( object ) :
def __init__ ( self , name : str = None ) :
name = ' PaddleSpeech ' if not name else name
self . logger = logging . getLogger ( name )
log_config = {
' DEBUG ' : 10 ,
' INFO ' : 20 ,
' TRAIN ' : 21 ,
' EVAL ' : 22 ,
' WARNING ' : 30 ,
' ERROR ' : 40 ,
' CRITICAL ' : 50 ,
' EXCEPTION ' : 100 ,
}
for key , level in log_config . items ( ) :
logging . addLevelName ( level , key )
if key == ' EXCEPTION ' :
self . __dict__ [ key . lower ( ) ] = self . logger . exception
else :
self . __dict__ [ key . lower ( ) ] = functools . partial ( self . __call__ ,
level )
self . format = logging . Formatter (
fmt = ' [ %(asctime)-15s ] [ %(levelname)8s ] [ %(filename)s ] [L %(lineno)d ] - %(message)s '
)
self . handler = logging . StreamHandler ( )
self . handler . setFormatter ( self . format )
self . logger . addHandler ( self . handler )
self . logger . setLevel ( logging . DEBUG )
self . logger . propagate = False
def __call__ ( self , log_level : str , msg : str ) :
self . logger . log ( log_level , msg )
logger = Logger ( )