Update recommended model to cnn14 and argument name in __call__.

pull/1085/head
KP 3 years ago
parent 0b7e0d1e2e
commit 6c1e6e7876

@ -80,7 +80,7 @@ class CLSExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
default='panns_cnn10', default='panns_cnn14',
help='Choose model type of cls task.') help='Choose model type of cls task.')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
@ -127,8 +127,8 @@ class CLSExecutor(BaseExecutor):
def _init_from_path(self, def _init_from_path(self,
model_type: str='panns_cnn14', model_type: str='panns_cnn14',
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
label_file: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None): label_file: Optional[os.PathLike]=None):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
@ -244,15 +244,15 @@ class CLSExecutor(BaseExecutor):
logger.exception(e) logger.exception(e)
return False return False
def __call__(self, model_type, cfg_path, label_file, ckpt_path, audio_file, def __call__(self, model, config, ckpt_path, label_file, audio_file, topk,
topk, device): device):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
# self._check(audio_file, sample_rate) # self._check(audio_file, sample_rate)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model_type, cfg_path, label_file, ckpt_path) self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file) self.preprocess(audio_file)
self.infer() self.infer()
res = self.postprocess(topk) # Retrieve result of cls. res = self.postprocess(topk) # Retrieve result of cls.

Loading…
Cancel
Save