Merge pull request #1460 from KPatr1ck/cli_batch

[CLI][Batch]Support batch input in cli.
pull/1472/head
Hui Zhang 4 years ago committed by GitHub
commit 3151637aef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import sys import sys
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True) prog='paddlespeech.asr', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to recognize.') '--input', type=str, default=None, help='Audio file to recognize.')
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
decode_method = parser_args.decode_method decode_method = parser_args.decode_method
force_yes = parser_args.yes force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(audio_file, model, lang, sample_rate, config, ckpt_path, task_results = OrderedDict()
decode_method, force_yes, device) has_exceptions = False
logger.info('ASR Result: {}'.format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True) prog='paddlespeech.cls', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to classify.') '--input', type=str, default=None, help='Audio file to classify.')
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
ret = '' ret = ''
for idx in topk_idx: for idx in topk_idx:
label, score = self._label_list[idx], result[idx] label, score = self._label_list[idx], result[idx]
ret += f'{label}: {score}\n' ret += f'{label} {score} '
return ret return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]: def postprocess(self, topk: int) -> Union[str, os.PathLike]:
@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor):
label_file = parser_args.label_file label_file = parser_args.label_file
cfg_path = parser_args.config cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
topk = parser_args.topk topk = parser_args.topk
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file, task_results = OrderedDict()
topk, device) has_exceptions = False
logger.info('CLS Result:\n{}'.format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, model_type, cfg_path, ckpt_path, label_file,
topk, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(os.path.expanduser(audio_file))
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file) self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file) self.preprocess(audio_file)

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import sys
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from collections import OrderedDict
from typing import Any from typing import Any
from typing import Dict
from typing import List from typing import List
from typing import Union from typing import Union
import paddle import paddle
from .log import logger
class BaseExecutor(ABC): class BaseExecutor(ABC):
""" """
@ -27,8 +32,8 @@ class BaseExecutor(ABC):
""" """
def __init__(self): def __init__(self):
self._inputs = dict() self._inputs = OrderedDict()
self._outputs = dict() self._outputs = OrderedDict()
@abstractmethod @abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
@ -100,3 +105,107 @@ class BaseExecutor(ABC):
Python API to call an executor. Python API to call an executor.
""" """
pass pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
                    ) -> Dict[str, Union[str, os.PathLike]]:
    """
    Get task input source from command line input.

    Three kinds of input are handled:
      * a job file (as decided by `self._is_job_input`): its contents are
        returned via `self._get_job_contents`;
      * `None`: tasks are read from stdin, one per line, either as
        `<input>` or as `<id> <input>`;
      * anything else: treated as one single task.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.

    Returns:
        Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs.
    """
    if self._is_job_input(input_):
        return self._get_job_contents(input_)

    ret = OrderedDict()
    if input_ is not None:
        # NOTE(review): the single-task id is the int 1 while every other
        # id is a str; kept as-is so existing callers see the same key.
        ret[1] = input_
        return ret

    # Take input from stdin.
    for i, line in enumerate(sys.stdin):
        # Split on any run of whitespace so that blank lines yield no
        # fields (and are skipped) and tabs / repeated spaces between the
        # id and the input are accepted.
        fields = line.split()
        if len(fields) == 1:
            # No explicit id: use the 1-based line number.
            ret[str(i + 1)] = fields[0]
        elif len(fields) == 2:
            id_, info = fields
            ret[id_] = info
        # Any other field count: no valid input info on this line.
    return ret
def process_task_results(self,
                         input_: Union[str, os.PathLike, None],
                         results: Dict[str, os.PathLike],
                         job_dump_result: bool=False):
    """
    Print task results and optionally dump them next to the job file.

    The formatted results are always printed to stdout. When `input_` is a
    job file and `job_dump_result` is true, the same text is also written
    to `<absolute path of input_>.done`.

    Args:
        input_ (Union[str, os.PathLike, None]): Input from command line.
        results (Dict[str, os.PathLike]): Task outputs.
        job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
    """
    raw_text = self._format_task_results(results)
    print(raw_text, end='')

    if self._is_job_input(input_) and job_dump_result:
        job_output_file = os.path.abspath(input_) + '.done'
        # Write through a dedicated file handle instead of rebinding
        # sys.stdout: the process-wide stdout must stay open and untouched,
        # even when opening or writing the dump file fails.
        with open(job_output_file, 'w') as f:
            f.write(raw_text)
        logger.info(f'Results had been saved to: {job_output_file}')
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
"""
Check if current input file is a job input or not.
Args:
input_ (Union[str, os.PathLike]): Input file of current task.
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return input_ and os.path.isfile(input_) and input_.endswith('.job')
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k, v = line.split(' ')
job_contents[k] = v
return job_contents
def _format_task_results(
self, results: Dict[str, Union[str, os.PathLike]]) -> str:
"""
Convert task results to raw text.
Args:
results (Dict[str, str]): A dictionary of task results.
Returns:
str: A string object contains task results.
"""
ret = ''
for k, v in results.items():
ret += f'{k} {v}\n'
return ret

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import subprocess import subprocess
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True) prog="paddlespeech.st", add_help=True)
self.parser.add_argument( self.parser.add_argument(
"--input", type=str, required=True, help="Audio file to translate.") "--input", type=str, default=None, help="Audio file to translate.")
self.parser.add_argument( self.parser.add_argument(
"--model", "--model",
type=str, type=str,
@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help="Choose device to execute model inference.") help="Choose device to execute model inference.")
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(audio_file, model, src_lang, tgt_lang, sample_rate, task_results = OrderedDict()
config, ckpt_path, device) has_exceptions = False
logger.info("ST Result: {}".format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import re import re
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True) prog='paddlespeech.text', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Input text.') '--input', type=str, default=None, help='Input text.')
self.parser.add_argument( self.parser.add_argument(
'--task', '--task',
type=str, type=str,
@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
""" """
parser_args = self.parser.parse_args(argv) parser_args = self.parser.parse_args(argv)
text = parser_args.input
task = parser_args.task task = parser_args.task
model_type = parser_args.model model_type = parser_args.model
lang = parser_args.lang lang = parser_args.lang
@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab punc_vocab = parser_args.punc_vocab
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
try: task_source = self.get_task_source(parser_args.input)
res = self(text, task, model_type, lang, cfg_path, ckpt_path, task_results = OrderedDict()
punc_vocab, device) has_exceptions = False
logger.info('Text Result:\n{}'.format(res))
return True for id_, input_ in task_source.items():
except Exception as e: try:
logger.exception(e) res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__( def __call__(

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True) prog='paddlespeech.tts', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Input text to generate.') '--input', type=str, default=None, help='Input text to generate.')
# acoustic model # acoustic model
self.parser.add_argument( self.parser.add_argument(
'--am', '--am',
@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name') '--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
text = args.input
am = args.am am = args.am
am_config = args.am_config am_config = args.am_config
am_ckpt = args.am_ckpt am_ckpt = args.am_ckpt
@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor):
voc_stat = args.voc_stat voc_stat = args.voc_stat
lang = args.lang lang = args.lang
device = args.device device = args.device
output = args.output
spk_id = args.spk_id spk_id = args.spk_id
job_dump_result = args.job_dump_result
try: task_source = self.get_task_source(args.input)
res = self( task_results = OrderedDict()
text=text, has_exceptions = False
# acoustic model related
am=am, for id_, input_ in task_source.items():
am_config=am_config, if len(task_source) > 1:
am_ckpt=am_ckpt, assert isinstance(args.output,
am_stat=am_stat, str) and args.output.endswith('.wav')
phones_dict=phones_dict, output = args.output.replace('.wav', f'_{id_}.wav')
tones_dict=tones_dict, else:
speaker_dict=speaker_dict, output = args.output
spk_id=spk_id,
# vocoder related try:
voc=voc, res = self(
voc_config=voc_config, text=input_,
voc_ckpt=voc_ckpt, # acoustic model related
voc_stat=voc_stat, am=am,
# other am_config=am_config,
lang=lang, am_ckpt=am_ckpt,
device=device, am_stat=am_stat,
output=output) phones_dict=phones_dict,
logger.info('Wave file has been generated: {}'.format(res)) tones_dict=tones_dict,
return True speaker_dict=speaker_dict,
except Exception as e: spk_id=spk_id,
logger.exception(e) # vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

Loading…
Cancel
Save