Update batch input and stdin input.

pull/1460/head
KP 2 years ago
parent 1818b058aa
commit 05288fe1c3

@ -14,7 +14,7 @@
import argparse
import ast
import os
import sys
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -79,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to classify.')
'--input', type=str, default=None, help='Audio file to classify.')
self.parser.add_argument(
'--model',
type=str,
@ -221,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = ''
for idx in topk_idx:
label, score = self._label_list[idx], result[idx]
ret += f'{label} {score}\n'
ret += f'{label} {score} '
return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@ -241,36 +241,34 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file
cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path
input_file = parser_args.input
topk = parser_args.topk
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
if job_dump_result:
assert self._is_job_input(
input_file
), 'Input file should be a job file(*.job) when `job_dump_result` is True.'
job_output_file = os.path.abspath(input_file) + '.done'
sys.stdout = open(job_output_file, 'w')
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
print(
self(input_file, model_type, cfg_path, ckpt_path, label_file,
topk, device))
for id_, input_ in task_source.items():
try:
res = self(input_, model_type, cfg_path, ckpt_path, label_file,
topk, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
if job_dump_result:
logger.info(f'Results had been saved to: {job_output_file}')
self.process_task_results(parser_args.input, task_results,
job_dump_result)
return True
except Exception as e:
logger.exception(e)
if has_exceptions:
return False
finally:
sys.stdout.close()
else:
return True
@stats_wrapper
def __call__(self,
input_file: os.PathLike,
audio_file: os.PathLike,
model: str='panns_cnn14',
config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
@ -280,24 +278,11 @@ class CLSExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
input_file = os.path.abspath(input_file)
audio_file = os.path.abspath(os.path.expanduser(audio_file))
paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file)
if self._is_job_input(input_file): # *.job
job_outputs = {}
job_contents = self._job_preprocess(input_file)
for id_, file in job_contents.items():
try:
self.preprocess(file)
self.infer()
job_outputs[id_] = self.postprocess(topk).strip()
except Exception as e:
job_outputs[id_] = f'{e.__class__.__name__}: {e}'
res = self._job_postprecess(job_outputs)
else:
self.preprocess(input_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
self.preprocess(audio_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
return res

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
from typing import Any
from typing import Dict
from typing import List
@ -21,6 +23,8 @@ from typing import Union
import paddle
from .log import logger
class BaseExecutor(ABC):
"""
@ -28,8 +32,8 @@ class BaseExecutor(ABC):
"""
def __init__(self):
self._inputs = dict()
self._outputs = dict()
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike:
@ -102,6 +106,61 @@ class BaseExecutor(ABC):
"""
pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
                    ) -> Dict[str, Union[str, os.PathLike]]:
    """
    Get task input source from command line input.

    Three sources are supported:
      * a job file (as decided by `self._is_job_input`): its contents are
        returned as read by `self._get_job_contents`;
      * a single input given on the command line: returned under the id '1';
      * `None`: tasks are read from stdin, one per line. A line is either
        `<input>` (the id is the 1-based line number) or `<id> <input>`.
        Blank lines and lines with more than two fields are ignored.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.

    Returns:
        Dict[str, Union[str, os.PathLike]]: An ordered dict with ids and inputs.
    """
    if self._is_job_input(input_):
        return self._get_job_contents(input_)

    ret = OrderedDict()
    if input_ is not None:
        # Use a string id so every source yields the same key type.
        ret['1'] = input_
        return ret

    # Take input from stdin.
    for i, line in enumerate(sys.stdin):
        # Split on any run of whitespace so tabs / repeated spaces between
        # the id and the input are accepted, and blank lines yield no fields
        # instead of an empty input.
        fields = line.split()
        if len(fields) == 1:
            ret[str(i + 1)] = fields[0]
        elif len(fields) == 2:
            id_, info = fields
            ret[id_] = info
        # Otherwise: no valid input info on this line, skip it.
    return ret
def process_task_results(self,
                         input_: Union[str, os.PathLike, None],
                         results: Dict[str, os.PathLike],
                         job_dump_result: bool=False):
    """
    Print task results and optionally dump them next to the job file.

    The results are formatted with `self._format_task_results` and always
    printed to stdout. When `input_` is a job file (as decided by
    `self._is_job_input`) and `job_dump_result` is True, the same text is
    also written to `<absolute path of input_>.done`.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.
        results (Dict[str, os.PathLike]): Task outputs.
        job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
    """
    raw_text = self._format_task_results(results)
    print(raw_text, end='')

    if self._is_job_input(input_) and job_dump_result:
        job_output_file = os.path.abspath(input_) + '.done'
        # Write through a dedicated handle instead of rebinding sys.stdout:
        # rebinding and then closing it leaves the process with a closed
        # stdout, and a failing open() would close the real stdout.
        with open(job_output_file, 'w') as f:
            f.write(raw_text)
        logger.info(f'Results had been saved to: {job_output_file}')
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
"""
Check if current input file is a job input or not.
@ -112,9 +171,10 @@ class BaseExecutor(ABC):
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return os.path.isfile(input_) and input_.endswith('.job')
return input_ and os.path.isfile(input_) and input_.endswith('.job')
def _job_preprocess(self, job_input: os.PathLike) -> Dict[str, str]:
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
@ -124,7 +184,7 @@ class BaseExecutor(ABC):
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = {}
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
@ -134,17 +194,18 @@ class BaseExecutor(ABC):
job_contents[k] = v
return job_contents
def _job_postprecess(self, job_outputs: Dict[str, str]) -> str:
def _format_task_results(
self, results: Dict[str, Union[str, os.PathLike]]) -> str:
"""
Convert job results to string.
Convert task results to raw text.
Args:
job_outputs (Dict[str, str]): A dictionary with job ids and results.
results (Dict[str, str]): A dictionary of task results.
Returns:
str: A string object contains job outputs.
str: A string object contains task results.
"""
ret = ''
for k, v in job_outputs.items():
for k, v in results.items():
ret += f'{k} {v}\n'
return ret

Loading…
Cancel
Save