Support batch input in cls task.

pull/1460/head
KP 3 years ago
parent a8c3f6d479
commit 1818b058aa

@ -65,7 +65,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
config=None, # Set `config` and `ckpt_path` to None to use pretrained model. config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
label_file=None, label_file=None,
ckpt_path=None, ckpt_path=None,
audio_file='./cat.wav', input_file='./cat.wav',
topk=10, topk=10,
device=paddle.get_device()) device=paddle.get_device())
print('CLS Result: \n{}'.format(result)) print('CLS Result: \n{}'.format(result))

@ -65,7 +65,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
config=None, # Set `config` and `ckpt_path` to None to use pretrained model. config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
label_file=None, label_file=None,
ckpt_path=None, ckpt_path=None,
audio_file='./cat.wav', input_file='./cat.wav',
topk=10, topk=10,
device=paddle.get_device()) device=paddle.get_device())
print('CLS Result: \n{}'.format(result)) print('CLS Result: \n{}'.format(result))

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import sys
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = '' ret = ''
for idx in topk_idx: for idx in topk_idx:
label, score = self._label_list[idx], result[idx] label, score = self._label_list[idx], result[idx]
ret += f'{label}: {score}\n' ret += f'{label} {score}\n'
return ret return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]: def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@ -234,22 +241,36 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file label_file = parser_args.label_file
cfg_path = parser_args.config cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input input_file = parser_args.input
topk = parser_args.topk topk = parser_args.topk
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: try:
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file, if job_dump_result:
topk, device) assert self._is_job_input(
logger.info('CLS Result:\n{}'.format(res)) input_file
), 'Input file should be a job file(*.job) when `job_dump_result` is True.'
job_output_file = os.path.abspath(input_file) + '.done'
sys.stdout = open(job_output_file, 'w')
print(
self(input_file, model_type, cfg_path, ckpt_path, label_file,
topk, device))
if job_dump_result:
logger.info(f'Results had been saved to: {job_output_file}')
return True return True
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return False return False
finally:
sys.stdout.close()
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
audio_file: os.PathLike, input_file: os.PathLike,
model: str='panns_cnn14', model: str='panns_cnn14',
config: Optional[os.PathLike]=None, config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None,
@ -259,11 +280,24 @@ class CLSExecutor(BaseExecutor):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) input_file = os.path.abspath(input_file)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file) self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file)
self.infer() if self._is_job_input(input_file): # *.job
res = self.postprocess(topk) # Retrieve result of cls. job_outputs = {}
job_contents = self._job_preprocess(input_file)
for id_, file in job_contents.items():
try:
self.preprocess(file)
self.infer()
job_outputs[id_] = self.postprocess(topk).strip()
except Exception as e:
job_outputs[id_] = f'{e.__class__.__name__}: {e}'
res = self._job_postprecess(job_outputs)
else:
self.preprocess(input_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
return res return res

@ -15,6 +15,7 @@ import os
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import Any from typing import Any
from typing import Dict
from typing import List from typing import List
from typing import Union from typing import Union
@ -100,3 +101,50 @@ class BaseExecutor(ABC):
Python API to call an executor. Python API to call an executor.
""" """
pass pass
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
    """
    Check if current input file is a job input or not.

    Args:
        input_ (Union[str, os.PathLike]): Input file of current task.

    Returns:
        bool: `True` when `input_` is an existing regular file whose path
            ends with `.job`, `False` otherwise.
    """
    # Test existence first so that anything that is not an existing file
    # is rejected without looking at its name.
    if not os.path.isfile(input_):
        return False
    # Path-like objects (e.g. `pathlib.Path`) have no `endswith` method, so
    # convert to the underlying str/bytes path before checking the suffix.
    path = os.fspath(input_)
    suffix = b'.job' if isinstance(path, bytes) else '.job'
    return path.endswith(suffix)
def _job_preprocess(self, job_input: os.PathLike) -> Dict[str, str]:
    """
    Read a job input file and return its contents in a dictionary.

    Every non-blank line must hold a job id followed by whitespace and the
    input for that id. Only the first run of whitespace separates the two
    fields, so the input itself may contain spaces. Blank lines are skipped.
    When an id appears more than once, the last occurrence wins.

    Args:
        job_input (os.PathLike): The job input file.

    Returns:
        Dict[str, str]: Contents of job input, mapping job id to its input.

    Raises:
        ValueError: If a non-blank line does not contain both an id and an
            input.
    """
    job_contents = {}
    with open(job_input) as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            # Split once on any whitespace run: tolerates tabs and repeated
            # spaces, and keeps spaces that belong to the input path.
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(
                    f'Invalid job line {line_no} in {job_input}: '
                    f'expected "<id> <input>", got {line!r}')
            k, v = fields
            job_contents[k] = v
    return job_contents
def _job_postprecess(self, job_outputs: Dict[str, str]) -> str:
    """
    Serialize job results into a single string.

    Args:
        job_outputs (Dict[str, str]): Mapping from job id to its result.

    Returns:
        str: One `<id> <result>` line per entry, in the dictionary's
            iteration order, each terminated by a newline. Empty when
            `job_outputs` is empty.
    """
    rendered_lines = (
        f'{job_id} {outcome}\n' for job_id, outcome in job_outputs.items())
    return ''.join(rendered_lines)

Loading…
Cancel
Save