Support batch input in cls task.

pull/1460/head
KP 2 years ago
parent a8c3f6d479
commit 1818b058aa

@ -65,7 +65,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
label_file=None,
ckpt_path=None,
audio_file='./cat.wav',
input_file='./cat.wav',
topk=10,
device=paddle.get_device())
print('CLS Result: \n{}'.format(result))

@ -65,7 +65,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
label_file=None,
ckpt_path=None,
audio_file='./cat.wav',
input_file='./cat.wav',
topk=10,
device=paddle.get_device())
print('CLS Result: \n{}'.format(result))

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import sys
from typing import List
from typing import Optional
from typing import Union
@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = ''
for idx in topk_idx:
label, score = self._label_list[idx], result[idx]
ret += f'{label}: {score}\n'
ret += f'{label} {score}\n'
return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@ -234,22 +241,36 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file
cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
input_file = parser_args.input
topk = parser_args.topk
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file,
topk, device)
logger.info('CLS Result:\n{}'.format(res))
if job_dump_result:
assert self._is_job_input(
input_file
), 'Input file should be a job file(*.job) when `job_dump_result` is True.'
job_output_file = os.path.abspath(input_file) + '.done'
sys.stdout = open(job_output_file, 'w')
print(
self(input_file, model_type, cfg_path, ckpt_path, label_file,
topk, device))
if job_dump_result:
logger.info(f'Results had been saved to: {job_output_file}')
return True
except Exception as e:
logger.exception(e)
return False
finally:
sys.stdout.close()
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
input_file: os.PathLike,
model: str='panns_cnn14',
config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
@ -259,11 +280,24 @@ class CLSExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
input_file = os.path.abspath(input_file)
paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
if self._is_job_input(input_file): # *.job
job_outputs = {}
job_contents = self._job_preprocess(input_file)
for id_, file in job_contents.items():
try:
self.preprocess(file)
self.infer()
job_outputs[id_] = self.postprocess(topk).strip()
except Exception as e:
job_outputs[id_] = f'{e.__class__.__name__}: {e}'
res = self._job_postprecess(job_outputs)
else:
self.preprocess(input_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
return res

@ -15,6 +15,7 @@ import os
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import List
from typing import Union
@ -100,3 +101,50 @@ class BaseExecutor(ABC):
Python API to call an executor.
"""
pass
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
    """
    Check if current input file is a job input or not.

    Args:
        input_ (Union[str, os.PathLike]): Input file of current task.

    Returns:
        bool: return `True` when `input_` is an existing regular file whose
            path ends with `.job`, `False` otherwise.
    """
    # Path-like objects (e.g. `pathlib.Path`) have no `endswith`; convert them
    # to their plain path representation first so both annotated input kinds
    # are handled the same way.
    path = os.fspath(input_) if isinstance(input_, os.PathLike) else input_
    return os.path.isfile(path) and path.endswith('.job')
def _job_preprocess(self, job_input: os.PathLike) -> Dict[str, str]:
    """
    Read a job input file and return its contents in a dictionary.

    Each non-blank line is expected to hold a job id followed by whitespace
    and the input for that job. Blank lines are skipped. If the same id
    appears more than once, the last occurrence wins.

    Args:
        job_input (os.PathLike): The job input file.

    Returns:
        Dict[str, str]: Contents of job input, mapping job id to job input.

    Raises:
        ValueError: If a non-blank line does not contain both an id and an
            input.
    """
    job_contents = {}
    with open(job_input) as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            # Split on the first run of whitespace only, so tabs or repeated
            # spaces work as separators and the input part may itself
            # contain spaces (e.g. a file path with blanks).
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(
                    f'Invalid job line {line_no} in {job_input}: '
                    f'expected "<id> <input>", got {line!r}.')
            job_id, job_value = fields
            job_contents[job_id] = job_value
    return job_contents
def _job_postprecess(self, job_outputs: Dict[str, str]) -> str:
    """
    Convert job results to string.

    Every entry is rendered as `<id> <result>` followed by a newline, in the
    iteration order of `job_outputs`.

    Args:
        job_outputs (Dict[str, str]): A dictionary with job ids and results.

    Returns:
        str: A string object contains job outputs; empty when there are no
            results.
    """
    # Build the text in a single join instead of repeated `+=`, which may
    # copy the accumulated string on every iteration.
    return ''.join(f'{k} {v}\n' for k, v in job_outputs.items())

Loading…
Cancel
Save