Merge pull request #2103 from KPatr1ck/task_resource

[CLI][Resource]Fix unnecessary download
pull/2106/head
TianYuan 3 years ago committed by GitHub
commit 72303e22df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -175,14 +175,21 @@ class TTSExecutor(BaseExecutor):
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.') logger.info('Models had been initialized.')
return return
# am # am
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
use_pretrained_am = True
else:
use_pretrained_am = False
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
skip_download=not use_pretrained_am,
version=None, # default version version=None, # default version
) )
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
self.am_config = os.path.join(self.am_res_path, self.am_config = os.path.join(self.am_res_path,
self.task_resource.res_dict['config']) self.task_resource.res_dict['config'])
@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor):
self.speaker_dict = speaker_dict self.speaker_dict = speaker_dict
# voc # voc
if voc_ckpt is None or voc_config is None or voc_stat is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version version=None, # default version
) )
if voc_ckpt is None or voc_config is None or voc_stat is None: if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_config = os.path.join( self.voc_config = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_res_path, self.task_resource.voc_res_dict['config'])

@ -60,6 +60,7 @@ class CommonTaskResource:
def set_task_model(self, def set_task_model(self,
model_tag: str, model_tag: str,
model_type: int=0, model_type: int=0,
skip_download: bool=False,
version: Optional[str]=None): version: Optional[str]=None):
"""Set model tag and version of current task. """Set model tag and version of current task.
@ -83,6 +84,7 @@ class CommonTaskResource:
self.version = version self.version = version
self.res_dict = self.pretrained_models[model_tag][version] self.res_dict = self.pretrained_models[model_tag][version]
self._format_path(self.res_dict) self._format_path(self.res_dict)
if not skip_download:
self.res_dir = self._fetch(self.res_dict, self.res_dir = self._fetch(self.res_dict,
self._get_model_dir(model_type)) self._get_model_dir(model_type))
else: else:
@ -91,6 +93,7 @@ class CommonTaskResource:
self.voc_version = version self.voc_version = version
self.voc_res_dict = self.pretrained_models[model_tag][version] self.voc_res_dict = self.pretrained_models[model_tag][version]
self._format_path(self.voc_res_dict) self._format_path(self.voc_res_dict)
if not skip_download:
self.voc_res_dir = self._fetch(self.voc_res_dict, self.voc_res_dir = self._fetch(self.voc_res_dict,
self._get_model_dir(model_type)) self._get_model_dir(model_type))

@ -105,13 +105,19 @@ class TTSServerExecutor(TTSExecutor):
logger.info('Models had been initialized.') logger.info('Models had been initialized.')
return return
# am model info # am model info
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
use_pretrained_am = True
else:
use_pretrained_am = False
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
skip_download=not use_pretrained_am,
version=None, # default version version=None, # default version
) )
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
self.am_config = os.path.join(self.am_res_path, self.am_config = os.path.join(self.am_res_path,
self.task_resource.res_dict['config']) self.task_resource.res_dict['config'])
@ -138,13 +144,19 @@ class TTSServerExecutor(TTSExecutor):
self.speaker_dict = None self.speaker_dict = None
# voc model info # voc model info
if voc_ckpt is None or voc_config is None or voc_stat is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version version=None, # default version
) )
if voc_ckpt is None or voc_config is None or voc_stat is None: if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_config = os.path.join( self.voc_config = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_res_path, self.task_resource.voc_res_dict['config'])

@ -68,13 +68,19 @@ class TTSServerExecutor(TTSExecutor):
logger.info('Models had been initialized.') logger.info('Models had been initialized.')
return return
# am # am
if am_model is None or am_params is None or phones_dict is None:
use_pretrained_am = True
else:
use_pretrained_am = False
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
skip_download=not use_pretrained_am,
version=None, # default version version=None, # default version
) )
if am_model is None or am_params is None or phones_dict is None: if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
self.am_model = os.path.join(self.am_res_path, self.am_model = os.path.join(self.am_res_path,
self.task_resource.res_dict['model']) self.task_resource.res_dict['model'])
@ -113,13 +119,19 @@ class TTSServerExecutor(TTSExecutor):
self.speaker_dict = speaker_dict self.speaker_dict = speaker_dict
# voc # voc
if voc_model is None or voc_params is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version version=None, # default version
) )
if voc_model is None or voc_params is None: if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_model = os.path.join( self.voc_model = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['model']) self.voc_res_path, self.task_resource.voc_res_dict['model'])

Loading…
Cancel
Save