|
|
|
@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, 'model'):
|
|
|
|
|
logger.info('Model had been initialized.')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if label_file is None or ckpt_path is None:
|
|
|
|
|
self.res_path = self._get_pretrained_path(model_type) # panns_cnn14
|
|
|
|
|
self.cfg_path = os.path.join(
|
|
|
|
@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
|
# model
|
|
|
|
|
model_class = dynamic_import(model_type, model_alias)
|
|
|
|
|
model_dict = paddle.load(self.ckpt_path)
|
|
|
|
|
self._model = model_class(extract_embedding=False)
|
|
|
|
|
self._model.set_state_dict(model_dict)
|
|
|
|
|
self._model.eval()
|
|
|
|
|
self.model = model_class(extract_embedding=False)
|
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
def preprocess(self, audio_file: Union[str, os.PathLike]):
|
|
|
|
|
"""
|
|
|
|
@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
Model inference and result stored in self.output.
|
|
|
|
|
"""
|
|
|
|
|
self._outputs['logits'] = self._model(self._inputs['feats'])
|
|
|
|
|
self._outputs['logits'] = self.model(self._inputs['feats'])
|
|
|
|
|
|
|
|
|
|
def _generate_topk_label(self, result: np.ndarray, topk: int) -> str:
|
|
|
|
|
assert topk <= len(
|
|
|
|
|