Optimize model init.

pull/1083/head
KP 3 years ago
parent 528c70e515
commit 61e39daccc

@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str tag = model_type + '_' + lang + '_' + sample_rate_str

@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if label_file is None or ckpt_path is None: if label_file is None or ckpt_path is None:
self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 self.res_path = self._get_pretrained_path(model_type) # panns_cnn14
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor):
# model # model
model_class = dynamic_import(model_type, model_alias) model_class = dynamic_import(model_type, model_alias)
model_dict = paddle.load(self.ckpt_path) model_dict = paddle.load(self.ckpt_path)
self._model = model_class(extract_embedding=False) self.model = model_class(extract_embedding=False)
self._model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
self._model.eval() self.model.eval()
def preprocess(self, audio_file: Union[str, os.PathLike]): def preprocess(self, audio_file: Union[str, os.PathLike]):
""" """
@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
self._outputs['logits'] = self._model(self._inputs['feats']) self._outputs['logits'] = self.model(self._inputs['feats'])
def _generate_topk_label(self, result: np.ndarray, topk: int) -> str: def _generate_topk_label(self, result: np.ndarray, topk: int) -> str:
assert topk <= len( assert topk <= len(

Loading…
Cancel
Save