|
|
@ -12,7 +12,9 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import ast
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
import sys
|
|
|
|
from typing import List
|
|
|
|
from typing import List
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Union
|
|
|
|
from typing import Union
|
|
|
@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=paddle.get_device(),
|
|
|
|
default=paddle.get_device(),
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--job_dump_result',
|
|
|
|
|
|
|
|
type=ast.literal_eval,
|
|
|
|
|
|
|
|
default=False,
|
|
|
|
|
|
|
|
help='Save job result into file.')
|
|
|
|
|
|
|
|
|
|
|
|
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
|
|
|
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
ret = ''
|
|
|
|
ret = ''
|
|
|
|
for idx in topk_idx:
|
|
|
|
for idx in topk_idx:
|
|
|
|
label, score = self._label_list[idx], result[idx]
|
|
|
|
label, score = self._label_list[idx], result[idx]
|
|
|
|
ret += f'{label}: {score}\n'
|
|
|
|
ret += f'{label} {score}\n'
|
|
|
|
return ret
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
|
|
|
|
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
|
|
|
@ -234,22 +241,36 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
label_file = parser_args.label_file
|
|
|
|
label_file = parser_args.label_file
|
|
|
|
cfg_path = parser_args.config
|
|
|
|
cfg_path = parser_args.config
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
audio_file = parser_args.input
|
|
|
|
input_file = parser_args.input
|
|
|
|
topk = parser_args.topk
|
|
|
|
topk = parser_args.topk
|
|
|
|
device = parser_args.device
|
|
|
|
device = parser_args.device
|
|
|
|
|
|
|
|
job_dump_result = parser_args.job_dump_result
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file,
|
|
|
|
if job_dump_result:
|
|
|
|
topk, device)
|
|
|
|
assert self._is_job_input(
|
|
|
|
logger.info('CLS Result:\n{}'.format(res))
|
|
|
|
input_file
|
|
|
|
|
|
|
|
), 'Input file should be a job file(*.job) when `job_dump_result` is True.'
|
|
|
|
|
|
|
|
job_output_file = os.path.abspath(input_file) + '.done'
|
|
|
|
|
|
|
|
sys.stdout = open(job_output_file, 'w')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(
|
|
|
|
|
|
|
|
self(input_file, model_type, cfg_path, ckpt_path, label_file,
|
|
|
|
|
|
|
|
topk, device))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if job_dump_result:
|
|
|
|
|
|
|
|
logger.info(f'Results had been saved to: {job_output_file}')
|
|
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.exception(e)
|
|
|
|
logger.exception(e)
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
finally:
|
|
|
|
|
|
|
|
sys.stdout.close()
|
|
|
|
|
|
|
|
|
|
|
|
@stats_wrapper
|
|
|
|
@stats_wrapper
|
|
|
|
def __call__(self,
|
|
|
|
def __call__(self,
|
|
|
|
audio_file: os.PathLike,
|
|
|
|
input_file: os.PathLike,
|
|
|
|
model: str='panns_cnn14',
|
|
|
|
model: str='panns_cnn14',
|
|
|
|
config: Optional[os.PathLike]=None,
|
|
|
|
config: Optional[os.PathLike]=None,
|
|
|
|
ckpt_path: Optional[os.PathLike]=None,
|
|
|
|
ckpt_path: Optional[os.PathLike]=None,
|
|
|
@ -259,10 +280,23 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Python API to call an executor.
|
|
|
|
Python API to call an executor.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
input_file = os.path.abspath(input_file)
|
|
|
|
paddle.set_device(device)
|
|
|
|
paddle.set_device(device)
|
|
|
|
self._init_from_path(model, config, ckpt_path, label_file)
|
|
|
|
self._init_from_path(model, config, ckpt_path, label_file)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
|
|
|
|
|
|
|
|
if self._is_job_input(input_file): # *.job
|
|
|
|
|
|
|
|
job_outputs = {}
|
|
|
|
|
|
|
|
job_contents = self._job_preprocess(input_file)
|
|
|
|
|
|
|
|
for id_, file in job_contents.items():
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
self.preprocess(file)
|
|
|
|
|
|
|
|
self.infer()
|
|
|
|
|
|
|
|
job_outputs[id_] = self.postprocess(topk).strip()
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
job_outputs[id_] = f'{e.__class__.__name__}: {e}'
|
|
|
|
|
|
|
|
res = self._job_postprecess(job_outputs)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.preprocess(input_file)
|
|
|
|
self.infer()
|
|
|
|
self.infer()
|
|
|
|
res = self.postprocess(topk) # Retrieve result of cls.
|
|
|
|
res = self.postprocess(topk) # Retrieve result of cls.
|
|
|
|
|
|
|
|
|
|
|
|