@ -11,15 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import functools
import hashlib
import logging
import os
import os
import tarfile
import tarfile
import zipfile
import zipfile
from typing import Any
from typing import Any
from typing import Dict
from typing import Dict
from typing import List
from paddle . framework import load
from paddle . framework import load
@ -31,7 +27,6 @@ __all__ = [
' get_command ' ,
' get_command ' ,
' download_and_decompress ' ,
' download_and_decompress ' ,
' load_state_dict_from_url ' ,
' load_state_dict_from_url ' ,
' logger ' ,
]
]
@ -59,38 +54,27 @@ def get_command(name: str) -> Any:
return com [ ' _entry ' ]
return com [ ' _entry ' ]
# Verify a file against an expected MD5 digest.
#
# Args:
#     filepath: path of the file to hash; it is opened in binary mode.
#     md5sum: expected digest, compared against ``hashlib``'s ``hexdigest()`` text.
#
# Returns:
#     True when the computed digest equals ``md5sum``, False otherwise.
#     Start, failure and success are all reported through the module logger
#     at info level.
def _md5check ( filepath : os . PathLike , md5sum : str ) - > bool :
logger . info ( " File {} md5 checking... " . format ( filepath ) )
md5 = hashlib . md5 ( )
# Feed the file to the hash in 4096-byte reads so the whole file is never
# held in memory; iter(callable, sentinel) stops at the empty read (EOF).
with open ( filepath , ' rb ' ) as f :
for chunk in iter ( lambda : f . read ( 4096 ) , b " " ) :
md5 . update ( chunk )
calc_md5sum = md5 . hexdigest ( )
# Exact, case-sensitive string comparison.
# NOTE(review): hexdigest() yields lowercase hex, so an upper-case md5sum
# would be reported as a mismatch — confirm callers pass lowercase digests.
if calc_md5sum != md5sum :
logger . info ( " File {} md5 check failed, {} (calc) != "
" {} (base) " . format ( filepath , calc_md5sum , md5sum ) )
return False
else :
logger . info ( " File {} md5 check passed. " . format ( filepath ) )
return True
def _get_uncompress_path ( filepath : os . PathLike ) - > os . PathLike :
def _get_uncompress_path ( filepath : os . PathLike ) - > os . PathLike :
file_dir = os . path . dirname ( filepath )
file_dir = os . path . dirname ( filepath )
is_zip_file = False
if tarfile . is_tarfile ( filepath ) :
if tarfile . is_tarfile ( filepath ) :
files = tarfile . open ( filepath , " r:* " )
files = tarfile . open ( filepath , " r:* " )
file_list = files . getnames ( )
file_list = files . getnames ( )
elif zipfile . is_zipfile ( filepath ) :
elif zipfile . is_zipfile ( filepath ) :
files = zipfile . ZipFile ( filepath , ' r ' )
files = zipfile . ZipFile ( filepath , ' r ' )
file_list = files . namelist ( )
file_list = files . namelist ( )
is_zip_file = True
else :
else :
return file_dir
return file_dir
if _is_a_single_file ( file_list ) :
if download . _is_a_single_file ( file_list ) :
rootpath = file_list [ 0 ]
rootpath = file_list [ 0 ]
uncompressed_path = os . path . join ( file_dir , rootpath )
uncompressed_path = os . path . join ( file_dir , rootpath )
elif _is_a_single_dir ( file_list ) :
elif download . _is_a_single_dir ( file_list ) :
if is_zip_file :
rootpath = os . path . splitext ( file_list [ 0 ] ) [ 0 ] . split ( os . sep ) [ 0 ]
rootpath = os . path . splitext ( file_list [ 0 ] ) [ 0 ] . split ( os . sep ) [ 0 ]
else :
rootpath = os . path . splitext ( file_list [ 0 ] ) [ 0 ] . split ( os . sep ) [ - 1 ]
uncompressed_path = os . path . join ( file_dir , rootpath )
uncompressed_path = os . path . join ( file_dir , rootpath )
else :
else :
rootpath = os . path . splitext ( filepath ) [ 0 ] . split ( os . sep ) [ - 1 ]
rootpath = os . path . splitext ( filepath ) [ 0 ] . split ( os . sep ) [ - 1 ]
@ -100,28 +84,6 @@ def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
return uncompressed_path
return uncompressed_path
def _is_a_single_file ( file_list : List [ os . PathLike ] ) - > bool :
if len ( file_list ) == 1 and file_list [ 0 ] . find ( os . sep ) < - 1 :
return True
return False
def _is_a_single_dir ( file_list : List [ os . PathLike ] ) - > bool :
new_file_list = [ ]
for file_path in file_list :
if ' / ' in file_path :
file_path = file_path . replace ( ' / ' , os . sep )
elif ' \\ ' in file_path :
file_path = file_path . replace ( ' \\ ' , os . sep )
new_file_list . append ( file_path )
file_name = new_file_list [ 0 ] . split ( os . sep ) [ 0 ]
for i in range ( 1 , len ( new_file_list ) ) :
if file_name != new_file_list [ i ] . split ( os . sep ) [ 0 ] :
return False
return True
def download_and_decompress ( archive : Dict [ str , str ] , path : str ) - > os . PathLike :
def download_and_decompress ( archive : Dict [ str , str ] , path : str ) - > os . PathLike :
"""
"""
Download archieves and decompress to specific path .
Download archieves and decompress to specific path .
@ -133,7 +95,8 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
' Dictionary keys of " url " and " md5 " are required in the archive, but got: {} ' . format ( list ( archive . keys ( ) ) )
' Dictionary keys of " url " and " md5 " are required in the archive, but got: {} ' . format ( list ( archive . keys ( ) ) )
filepath = os . path . join ( path , os . path . basename ( archive [ ' url ' ] ) )
filepath = os . path . join ( path , os . path . basename ( archive [ ' url ' ] ) )
if os . path . isfile ( filepath ) and _md5check ( filepath , archive [ ' md5 ' ] ) :
if os . path . isfile ( filepath ) and download . _md5check ( filepath ,
archive [ ' md5 ' ] ) :
uncompress_path = _get_uncompress_path ( filepath )
uncompress_path = _get_uncompress_path ( filepath )
if not os . path . isdir ( uncompress_path ) :
if not os . path . isdir ( uncompress_path ) :
download . _decompress ( filepath )
download . _decompress ( filepath )
@ -183,44 +146,3 @@ def _get_sub_home(directory):
# Module-level constants, each assigned from a helper defined elsewhere in
# this file.
# NOTE(review): presumably filesystem locations (a home directory and its
# models sub-directory) — confirm against the helpers.
PPSPEECH_HOME = _get_paddlespcceh_home ( )
PPSPEECH_HOME = _get_paddlespcceh_home ( )
MODEL_HOME = _get_sub_home ( ' models ' )
MODEL_HOME = _get_sub_home ( ' models ' )
# Thin wrapper around a named ``logging.Logger``.
#
# Each instance exposes one callable attribute per entry of ``log_config``,
# named after the lower-cased key. All of them except the exception entry
# forward to ``self.logger.log`` at that entry's numeric level.
class Logger ( object ) :
# Args:
#     name: name handed to ``logging.getLogger``; when it is falsy (None or
#         empty) the literal below is used instead.
def __init__ ( self , name : str = None ) :
name = ' PaddleSpeech ' if not name else name
self . logger = logging . getLogger ( name )
# Level name -> numeric level. TRAIN (21) and EVAL (22) are extra levels
# placed between INFO (20) and WARNING (30); EXCEPTION is given 100.
log_config = {
' DEBUG ' : 10 ,
' INFO ' : 20 ,
' TRAIN ' : 21 ,
' EVAL ' : 22 ,
' WARNING ' : 30 ,
' ERROR ' : 40 ,
' CRITICAL ' : 50 ,
' EXCEPTION ' : 100 ,
}
for key , level in log_config . items ( ) :
# Register the display name for this numeric level with the logging module.
logging . addLevelName ( level , key )
# Install the per-level callable directly in the instance __dict__:
# the exception entry is bound to the wrapped logger's ``exception``
# method, every other entry is ``__call__`` pre-bound to its level.
if key == ' EXCEPTION ' :
self . __dict__ [ key . lower ( ) ] = self . logger . exception
else :
self . __dict__ [ key . lower ( ) ] = functools . partial ( self . __call__ ,
level )
# NOTE(review): records are emitted from ``__call__`` below, so the
# %(filename)s / %(lineno)d fields likely resolve to this wrapper rather
# than to the original call site — confirm.
self . format = logging . Formatter (
fmt = ' [ %(asctime)-15s ] [ %(levelname)8s ] [ %(filename)s ] [L %(lineno)d ] - %(message)s '
)
# NOTE(review): a new StreamHandler is added on every construction;
# building two instances with the same name would attach two handlers to
# the same underlying logger and duplicate output — confirm only one
# instance is ever created.
self . handler = logging . StreamHandler ( )
self . handler . setFormatter ( self . format )
self . logger . addHandler ( self . handler )
self . logger . setLevel ( logging . DEBUG )
# Keep records from also being passed to handlers of ancestor loggers.
self . logger . propagate = False
# Emit ``msg`` on the wrapped logger at ``log_level``.
# NOTE(review): ``log_level`` is annotated ``str``, but the values bound in
# ``__init__`` are the integer levels from ``log_config``.
def __call__ ( self , log_level : str , msg : str ) :
self . logger . log ( log_level , msg )
# Shared module-level Logger instance built with the default name; the
# helpers in this file log through it.
logger = Logger ( )