Merge pull request #1460 from KPatr1ck/cli_batch

[CLI][Batch]Support batch input in cli.
pull/1472/head
Hui Zhang 4 years ago committed by GitHub
commit 3151637aef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import sys
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to recognize.')
'--input', type=str, default=None, help='Audio file to recognize.')
self.parser.add_argument(
'--model',
type=str,
@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
decode_method = parser_args.decode_method
force_yes = parser_args.yes
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
logger.info('ASR Result: {}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to classify.')
'--input', type=str, default=None, help='Audio file to classify.')
self.parser.add_argument(
'--model',
type=str,
@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = ''
for idx in topk_idx:
label, score = self._label_list[idx], result[idx]
ret += f'{label}: {score}\n'
ret += f'{label} {score} '
return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file
cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
topk = parser_args.topk
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file,
topk, device)
logger.info('CLS Result:\n{}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model_type, cfg_path, ckpt_path, label_file,
topk, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,
@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
audio_file = os.path.abspath(os.path.expanduser(audio_file))
paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file)

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import paddle
from .log import logger
class BaseExecutor(ABC):
"""
@ -27,8 +32,8 @@ class BaseExecutor(ABC):
"""
def __init__(self):
self._inputs = dict()
self._outputs = dict()
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike:
@ -100,3 +105,107 @@ class BaseExecutor(ABC):
Python API to call an executor.
"""
pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
                    ) -> Dict[str, Union[str, os.PathLike]]:
    """
    Get task input source from command line input.

    Three sources are supported:
      * a job file (as decided by `self._is_job_input`): its contents are
        returned via `self._get_job_contents`;
      * `None`: tasks are read line by line from stdin;
      * anything else: treated as one single input.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.

    Returns:
        Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs.
    """
    if self._is_job_input(input_):
        return self._get_job_contents(input_)

    ret = OrderedDict()

    if input_ is not None:
        # Single input given on the command line.
        # NOTE: the key is the integer 1 (not '1'); kept unchanged so that
        # existing callers observe the same key as before.
        ret[1] = input_
        return ret

    # Take input from stdin, one task per line.
    for i, line in enumerate(sys.stdin):
        # Split on any run of whitespace so that blank lines yield no
        # fields (and are skipped) and repeated spaces/tabs between the id
        # and the input do not turn a valid line into an invalid one.
        fields = line.split()
        if len(fields) == 1:
            # Only the input is given: use the 1-based line number as id.
            ret[str(i + 1)] = fields[0]
        elif len(fields) == 2:
            id_, info = fields
            ret[id_] = info
        # Any other field count: no valid input info on this line, skip it.
    return ret
def process_task_results(self,
                         input_: Union[str, os.PathLike, None],
                         results: Dict[str, os.PathLike],
                         job_dump_result: bool=False):
    """
    Print task results and optionally dump them next to the job file.

    The formatted results are always printed to stdout. When `input_` is a
    job file (as decided by `self._is_job_input`) and `job_dump_result` is
    true, the same text is also written to `<abs path of input_>.done`.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.
        results (Dict[str, os.PathLike]): Task outputs.
        job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
    """
    raw_text = self._format_task_results(results)
    print(raw_text, end='')

    if self._is_job_input(input_) and job_dump_result:
        job_output_file = os.path.abspath(input_) + '.done'
        # Write through a dedicated file handle instead of rebinding
        # sys.stdout: rebinding left sys.stdout pointing at a closed file
        # afterwards, and closed the real stdout if open() failed.
        with open(job_output_file, 'w') as f:
            f.write(raw_text)
        logger.info(f'Results had been saved to: {job_output_file}')
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
"""
Check if current input file is a job input or not.
Args:
input_ (Union[str, os.PathLike]): Input file of current task.
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return input_ and os.path.isfile(input_) and input_.endswith('.job')
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k, v = line.split(' ')
job_contents[k] = v
return job_contents
def _format_task_results(
self, results: Dict[str, Union[str, os.PathLike]]) -> str:
"""
Convert task results to raw text.
Args:
results (Dict[str, str]): A dictionary of task results.
Returns:
str: A string object contains task results.
"""
ret = ''
for k, v in results.items():
ret += f'{k} {v}\n'
return ret

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import subprocess
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True)
self.parser.add_argument(
"--input", type=str, required=True, help="Audio file to translate.")
"--input", type=str, default=None, help="Audio file to translate.")
self.parser.add_argument(
"--model",
type=str,
@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help="Choose device to execute model inference.")
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(audio_file, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
logger.info("ST Result: {}".format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import re
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Input text.')
'--input', type=str, default=None, help='Input text.')
self.parser.add_argument(
'--task',
type=str,
@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
"""
parser_args = self.parser.parse_args(argv)
text = parser_args.input
task = parser_args.task
model_type = parser_args.model
lang = parser_args.lang
@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(text, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
logger.info('Text Result:\n{}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from collections import OrderedDict
from typing import Any
from typing import List
from typing import Optional
@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Input text to generate.')
'--input', type=str, default=None, help='Input text to generate.')
# acoustic model
self.parser.add_argument(
'--am',
@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
args = self.parser.parse_args(argv)
text = args.input
am = args.am
am_config = args.am_config
am_ckpt = args.am_ckpt
@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor):
voc_stat = args.voc_stat
lang = args.lang
device = args.device
output = args.output
spk_id = args.spk_id
job_dump_result = args.job_dump_result
try:
res = self(
text=text,
# acoustic model related
am=am,
am_config=am_config,
am_ckpt=am_ckpt,
am_stat=am_stat,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
spk_id=spk_id,
# vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
logger.info('Wave file has been generated: {}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
if len(task_source) > 1:
assert isinstance(args.output,
str) and args.output.endswith('.wav')
output = args.output.replace('.wav', f'_{id_}.wav')
else:
output = args.output
try:
res = self(
text=input_,
# acoustic model related
am=am,
am_config=am_config,
am_ckpt=am_ckpt,
am_stat=am_stat,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
spk_id=spk_id,
# vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,

Loading…
Cancel
Save