fix_model_init, test=doc

pull/2056/head
lym0302 3 years ago
parent 0ea9def0b8
commit 7d4f320836

@ -66,15 +66,15 @@ class TTSServerExecutor(TTSExecutor):
return return
# am # am
am_tag = am + '-' + lang am_tag = am + '-' + lang
if am == "fastspeech2_csmsc_onnx":
# get model info
if am_ckpt is None or phones_dict is None:
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
version=None, # default version version=None, # default version
) )
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
if am == "fastspeech2_csmsc_onnx":
# get model info
if am_ckpt is None or phones_dict is None:
self.am_ckpt = os.path.join( self.am_ckpt = os.path.join(
self.am_res_path, self.task_resource.res_dict['ckpt'][0]) self.am_res_path, self.task_resource.res_dict['ckpt'][0])
# must have phones_dict in acoustic # must have phones_dict in acoustic
@ -86,13 +86,19 @@ class TTSServerExecutor(TTSExecutor):
self.am_ckpt = os.path.abspath(am_ckpt[0]) self.am_ckpt = os.path.abspath(am_ckpt[0])
self.phones_dict = os.path.abspath(phones_dict) self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname( self.am_res_path = os.path.dirname(
os.path.abspath(self.am_ckpt)) os.path.abspath(am_ckpt))
# create am sess # create am sess
self.am_sess = get_sess(self.am_ckpt, am_sess_conf) self.am_sess = get_sess(self.am_ckpt, am_sess_conf)
elif am == "fastspeech2_cnndecoder_csmsc_onnx": elif am == "fastspeech2_cnndecoder_csmsc_onnx":
if am_ckpt is None or am_stat is None or phones_dict is None: if am_ckpt is None or am_stat is None or phones_dict is None:
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
self.am_res_path = self.task_resource.res_dir
self.am_encoder_infer = os.path.join( self.am_encoder_infer = os.path.join(
self.am_res_path, self.task_resource.res_dict['ckpt'][0]) self.am_res_path, self.task_resource.res_dict['ckpt'][0])
self.am_decoder = os.path.join( self.am_decoder = os.path.join(
@ -114,7 +120,7 @@ class TTSServerExecutor(TTSExecutor):
self.phones_dict = os.path.abspath(phones_dict) self.phones_dict = os.path.abspath(phones_dict)
self.am_stat = os.path.abspath(am_stat) self.am_stat = os.path.abspath(am_stat)
self.am_res_path = os.path.dirname( self.am_res_path = os.path.dirname(
os.path.abspath(self.am_ckpt)) os.path.abspath(am_ckpt[0]))
# create am sess # create am sess
self.am_encoder_infer_sess = get_sess(self.am_encoder_infer, self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
@ -130,12 +136,13 @@ class TTSServerExecutor(TTSExecutor):
# voc model info # voc model info
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
if voc_ckpt is None:
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
version=None, # default version version=None, # default version
) )
if voc_ckpt is None:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_ckpt = os.path.join( self.voc_ckpt = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])

Loading…
Cancel
Save