Update batch input and stdin input.

pull/1460/head
KP 3 years ago
parent 1818b058aa
commit 05288fe1c3

@ -14,7 +14,7 @@
import argparse import argparse
import ast import ast
import os import os
import sys from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -79,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True) prog='paddlespeech.cls', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to classify.') '--input', type=str, default=None, help='Audio file to classify.')
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
@ -221,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = '' ret = ''
for idx in topk_idx: for idx in topk_idx:
label, score = self._label_list[idx], result[idx] label, score = self._label_list[idx], result[idx]
ret += f'{label} {score}\n' ret += f'{label} {score} '
return ret return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]: def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@ -241,36 +241,34 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file label_file = parser_args.label_file
cfg_path = parser_args.config cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
input_file = parser_args.input
topk = parser_args.topk topk = parser_args.topk
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
if job_dump_result: task_results = OrderedDict()
assert self._is_job_input( has_exceptions = False
input_file
), 'Input file should be a job file(*.job) when `job_dump_result` is True.'
job_output_file = os.path.abspath(input_file) + '.done'
sys.stdout = open(job_output_file, 'w')
print( for id_, input_ in task_source.items():
self(input_file, model_type, cfg_path, ckpt_path, label_file, try:
topk, device)) res = self(input_, model_type, cfg_path, ckpt_path, label_file,
topk, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
if job_dump_result: self.process_task_results(parser_args.input, task_results,
logger.info(f'Results had been saved to: {job_output_file}') job_dump_result)
return True if has_exceptions:
except Exception as e:
logger.exception(e)
return False return False
finally: else:
sys.stdout.close() return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
input_file: os.PathLike, audio_file: os.PathLike,
model: str='panns_cnn14', model: str='panns_cnn14',
config: Optional[os.PathLike]=None, config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None,
@ -280,24 +278,11 @@ class CLSExecutor(BaseExecutor):
""" """
Python API to call an executor. Python API to call an executor.
""" """
input_file = os.path.abspath(input_file) audio_file = os.path.abspath(os.path.expanduser(audio_file))
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file) self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file)
if self._is_job_input(input_file): # *.job self.infer()
job_outputs = {} res = self.postprocess(topk) # Retrieve result of cls.
job_contents = self._job_preprocess(input_file)
for id_, file in job_contents.items():
try:
self.preprocess(file)
self.infer()
job_outputs[id_] = self.postprocess(topk).strip()
except Exception as e:
job_outputs[id_] = f'{e.__class__.__name__}: {e}'
res = self._job_postprecess(job_outputs)
else:
self.preprocess(input_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
return res return res

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import sys
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from collections import OrderedDict
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
@ -21,6 +23,8 @@ from typing import Union
import paddle import paddle
from .log import logger
class BaseExecutor(ABC): class BaseExecutor(ABC):
""" """
@ -28,8 +32,8 @@ class BaseExecutor(ABC):
""" """
def __init__(self): def __init__(self):
self._inputs = dict() self._inputs = OrderedDict()
self._outputs = dict() self._outputs = OrderedDict()
@abstractmethod @abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
@ -102,6 +106,61 @@ class BaseExecutor(ABC):
""" """
pass pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
                    ) -> Dict[str, Union[str, os.PathLike]]:
    """
    Get task input source from command line input.

    Three kinds of input are supported:
      * a job file (as decided by `self._is_job_input`): its contents are
        returned via `self._get_job_contents`;
      * `None`: tasks are read from stdin, one per line, either as
        `<input>` (the id is the 1-based line number) or `<id> <input>`;
      * anything else: a single task with id `'1'`.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.

    Returns:
        Dict[str, Union[str, os.PathLike]]: An ordered dict mapping task ids
            (always strings) to inputs.
    """
    if self._is_job_input(input_):
        return self._get_job_contents(input_)

    ret = OrderedDict()
    if input_ is not None:
        # Use a string id so that keys have the same type as the ids
        # produced by the stdin and job-file branches.
        ret['1'] = input_
        return ret

    # Take input from stdin.
    for line_no, line in enumerate(sys.stdin, start=1):
        # Split on any run of whitespace so that tabs or repeated spaces
        # between the id and the input are accepted.
        fields = line.split()
        if len(fields) == 1:
            ret[str(line_no)] = fields[0]
        elif len(fields) == 2:
            id_, info = fields
            ret[id_] = info
        # Blank lines and lines with more than two fields carry no valid
        # input info and are ignored.
    return ret
def process_task_results(self,
                         input_: Union[str, os.PathLike, None],
                         results: Dict[str, os.PathLike],
                         job_dump_result: bool=False):
    """
    Print task results and optionally dump them next to the job file.

    The formatted results are always printed to stdout. When `input_` is a
    job file (as decided by `self._is_job_input`) and `job_dump_result` is
    True, the same text is also written to `<abs path of input_>.done`.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.
        results (Dict[str, os.PathLike]): Task outputs.
        job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
    """
    raw_text = self._format_task_results(results)
    print(raw_text, end='')

    if not (job_dump_result and self._is_job_input(input_)):
        return

    job_output_file = os.path.abspath(input_) + '.done'
    # Write through a dedicated file handle instead of rebinding sys.stdout:
    # the process keeps a usable stdout afterwards, and a failing open()
    # cannot end up closing the real stdout.
    with open(job_output_file, 'w') as f:
        f.write(raw_text)
    logger.info(f'Results had been saved to: {job_output_file}')
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool: def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
""" """
Check if current input file is a job input or not. Check if current input file is a job input or not.
@ -112,9 +171,10 @@ class BaseExecutor(ABC):
Returns: Returns:
bool: return `True` for job input, `False` otherwise. bool: return `True` for job input, `False` otherwise.
""" """
return os.path.isfile(input_) and input_.endswith('.job') return input_ and os.path.isfile(input_) and input_.endswith('.job')
def _job_preprocess(self, job_input: os.PathLike) -> Dict[str, str]: def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
""" """
Read a job input file and return its contents in a dictionary. Read a job input file and return its contents in a dictionary.
@ -124,7 +184,7 @@ class BaseExecutor(ABC):
Returns: Returns:
Dict[str, str]: Contents of job input. Dict[str, str]: Contents of job input.
""" """
job_contents = {} job_contents = OrderedDict()
with open(job_input) as f: with open(job_input) as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
@ -134,17 +194,18 @@ class BaseExecutor(ABC):
job_contents[k] = v job_contents[k] = v
return job_contents return job_contents
def _job_postprecess(self, job_outputs: Dict[str, str]) -> str: def _format_task_results(
self, results: Dict[str, Union[str, os.PathLike]]) -> str:
""" """
Convert job results to string. Convert task results to raw text.
Args: Args:
job_outputs (Dict[str, str]): A dictionary with job ids and results. results (Dict[str, str]): A dictionary of task results.
Returns: Returns:
str: A string object contains job outputs. str: A string object contains task results.
""" """
ret = '' ret = ''
for k, v in job_outputs.items(): for k, v in results.items():
ret += f'{k} {v}\n' ret += f'{k} {v}\n'
return ret return ret

Loading…
Cancel
Save