From 7d4f320836c0e03e1d20eda5a631446d6516b9fb Mon Sep 17 00:00:00 2001 From: lym0302 Date: Tue, 21 Jun 2022 07:31:32 +0000 Subject: [PATCH] fix_model_init, test=doc --- .../engine/tts/online/onnx/tts_engine.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py index cb9155a2..f64287af 100644 --- a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py @@ -66,15 +66,15 @@ class TTSServerExecutor(TTSExecutor): return # am am_tag = am + '-' + lang - self.task_resource.set_task_model( - model_tag=am_tag, - model_type=0, # am - version=None, # default version - ) - self.am_res_path = self.task_resource.res_dir if am == "fastspeech2_csmsc_onnx": # get model info if am_ckpt is None or phones_dict is None: + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) + self.am_res_path = self.task_resource.res_dir self.am_ckpt = os.path.join( self.am_res_path, self.task_resource.res_dict['ckpt'][0]) # must have phones_dict in acoustic @@ -86,13 +86,19 @@ class TTSServerExecutor(TTSExecutor): self.am_ckpt = os.path.abspath(am_ckpt[0]) self.phones_dict = os.path.abspath(phones_dict) self.am_res_path = os.path.dirname( - os.path.abspath(self.am_ckpt)) + os.path.abspath(am_ckpt)) # create am sess self.am_sess = get_sess(self.am_ckpt, am_sess_conf) elif am == "fastspeech2_cnndecoder_csmsc_onnx": if am_ckpt is None or am_stat is None or phones_dict is None: + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) + self.am_res_path = self.task_resource.res_dir self.am_encoder_infer = os.path.join( self.am_res_path, self.task_resource.res_dict['ckpt'][0]) self.am_decoder = os.path.join( @@ -114,7 +120,7 @@ class TTSServerExecutor(TTSExecutor): self.phones_dict = os.path.abspath(phones_dict) self.am_stat = os.path.abspath(am_stat) self.am_res_path = os.path.dirname( - os.path.abspath(self.am_ckpt)) + os.path.abspath(am_ckpt[0])) # create am sess self.am_encoder_infer_sess = get_sess(self.am_encoder_infer, @@ -130,12 +136,13 @@ class TTSServerExecutor(TTSExecutor): # voc model info voc_tag = voc + '-' + lang - self.task_resource.set_task_model( - model_tag=voc_tag, - model_type=1, # vocoder - version=None, # default version - ) + if voc_ckpt is None: + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) self.voc_res_path = self.task_resource.voc_res_dir self.voc_ckpt = os.path.join( self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])