Merge pull request #2056 from lym0302/develop

[server] fix_model_init
pull/2061/head
TianYuan 3 years ago committed by GitHub
commit 46ff848d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,15 +66,15 @@ class TTSServerExecutor(TTSExecutor):
return
# am
am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
self.am_res_path = self.task_resource.res_dir
if am == "fastspeech2_csmsc_onnx":
# get model info
if am_ckpt is None or phones_dict is None:
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
self.am_res_path = self.task_resource.res_dir
self.am_ckpt = os.path.join(
self.am_res_path, self.task_resource.res_dict['ckpt'][0])
# must have phones_dict in acoustic
@ -86,13 +86,19 @@ class TTSServerExecutor(TTSExecutor):
self.am_ckpt = os.path.abspath(am_ckpt[0])
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(
os.path.abspath(self.am_ckpt))
os.path.abspath(am_ckpt))
# create am sess
self.am_sess = get_sess(self.am_ckpt, am_sess_conf)
elif am == "fastspeech2_cnndecoder_csmsc_onnx":
if am_ckpt is None or am_stat is None or phones_dict is None:
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
self.am_res_path = self.task_resource.res_dir
self.am_encoder_infer = os.path.join(
self.am_res_path, self.task_resource.res_dict['ckpt'][0])
self.am_decoder = os.path.join(
@ -114,7 +120,7 @@ class TTSServerExecutor(TTSExecutor):
self.phones_dict = os.path.abspath(phones_dict)
self.am_stat = os.path.abspath(am_stat)
self.am_res_path = os.path.dirname(
os.path.abspath(self.am_ckpt))
os.path.abspath(am_ckpt[0]))
# create am sess
self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
@ -130,12 +136,13 @@ class TTSServerExecutor(TTSExecutor):
# voc model info
voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None:
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
self.voc_res_path = self.task_resource.voc_res_dir
self.voc_ckpt = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])

Loading…
Cancel
Save