Merge pull request #1472 from KPatr1ck/cli_batch

[CLI][Logger]Add cli logger control.
pull/1477/head
Hui Zhang 3 years ago committed by GitHub
commit 60c0877e7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
@ -183,10 +182,15 @@ class ASRExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -479,7 +483,9 @@ class ASRExecutor(BaseExecutor):
decode_method = parser_args.decode_method decode_method = parser_args.decode_method
force_yes = parser_args.yes force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
@ -495,7 +501,7 @@ class ASRExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) parser_args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List
@ -112,10 +111,15 @@ class CLSExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -243,7 +247,9 @@ class CLSExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
topk = parser_args.topk topk = parser_args.topk
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
@ -259,7 +265,7 @@ class CLSExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) parser_args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os import os
import sys import sys
from abc import ABC from abc import ABC
@ -149,10 +150,16 @@ class BaseExecutor(ABC):
job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False. job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
""" """
raw_text = self._format_task_results(results) if not self._is_job_input(input_) and len(
print(raw_text, end='') results) == 1: # Only one input sample
raw_text = list(results.values())[0]
else:
raw_text = self._format_task_results(results)
print(raw_text, end='') # Stdout
if self._is_job_input(input_) and job_dump_result: if self._is_job_input(
input_) and job_dump_result: # Dump to *.job.done
try: try:
job_output_file = os.path.abspath(input_) + '.done' job_output_file = os.path.abspath(input_) + '.done'
sys.stdout = open(job_output_file, 'w') sys.stdout = open(job_output_file, 'w')
@ -209,3 +216,13 @@ class BaseExecutor(ABC):
for k, v in results.items(): for k, v in results.items():
ret += f'{k} {v}\n' ret += f'{k} {v}\n'
return ret return ret
def disable_task_loggers(self):
"""
Disable all loggers in current task.
"""
loggers = [
logging.getLogger(name) for name in logging.root.manager.loggerDict
]
for l in loggers:
l.disabled = True

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import subprocess import subprocess
from collections import OrderedDict from collections import OrderedDict
@ -110,10 +109,15 @@ class STExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help="Choose device to execute model inference.") help="Choose device to execute model inference.")
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -327,7 +331,9 @@ class STExecutor(BaseExecutor):
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
@ -343,7 +349,7 @@ class STExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) parser_args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import re import re
from collections import OrderedDict from collections import OrderedDict
@ -122,10 +121,15 @@ class TextExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -270,7 +274,9 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab punc_vocab = parser_args.punc_vocab
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
@ -286,7 +292,7 @@ class TextExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) parser_args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import Any from typing import Any
@ -400,10 +399,15 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name') '--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -693,7 +697,9 @@ class TTSExecutor(BaseExecutor):
lang = args.lang lang = args.lang
device = args.device device = args.device
spk_id = args.spk_id spk_id = args.spk_id
job_dump_result = args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(args.input) task_source = self.get_task_source(args.input)
task_results = OrderedDict() task_results = OrderedDict()
@ -733,7 +739,8 @@ class TTSExecutor(BaseExecutor):
has_exceptions = True has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result) self.process_task_results(args.input, task_results,
args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False

Loading…
Cancel
Save