Merge pull request #1472 from KPatr1ck/cli_batch

[CLI][Logger]Add cli logger control.
pull/1477/head
Hui Zhang 2 years ago committed by GitHub
commit 60c0877e7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import sys
from collections import OrderedDict
@ -183,10 +182,15 @@ class ASRExecutor(BaseExecutor):
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'-d',
'--job_dump_result',
type=ast.literal_eval,
default=False,
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -479,7 +483,9 @@ class ASRExecutor(BaseExecutor):
decode_method = parser_args.decode_method
force_yes = parser_args.yes
device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
@ -495,7 +501,7 @@ class ASRExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
parser_args.job_dump_result)
if has_exceptions:
return False

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from collections import OrderedDict
from typing import List
@ -112,10 +111,15 @@ class CLSExecutor(BaseExecutor):
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'-d',
'--job_dump_result',
type=ast.literal_eval,
default=False,
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -243,7 +247,9 @@ class CLSExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path
topk = parser_args.topk
device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
@ -259,7 +265,7 @@ class CLSExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
parser_args.job_dump_result)
if has_exceptions:
return False

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from abc import ABC
@ -149,10 +150,16 @@ class BaseExecutor(ABC):
job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
"""
raw_text = self._format_task_results(results)
print(raw_text, end='')
if not self._is_job_input(input_) and len(
results) == 1: # Only one input sample
raw_text = list(results.values())[0]
else:
raw_text = self._format_task_results(results)
print(raw_text, end='') # Stdout
if self._is_job_input(input_) and job_dump_result:
if self._is_job_input(
input_) and job_dump_result: # Dump to *.job.done
try:
job_output_file = os.path.abspath(input_) + '.done'
sys.stdout = open(job_output_file, 'w')
@ -209,3 +216,13 @@ class BaseExecutor(ABC):
for k, v in results.items():
ret += f'{k} {v}\n'
return ret
def disable_task_loggers(self):
"""
Disable all loggers in current task.
"""
loggers = [
logging.getLogger(name) for name in logging.root.manager.loggerDict
]
for l in loggers:
l.disabled = True

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import subprocess
from collections import OrderedDict
@ -110,10 +109,15 @@ class STExecutor(BaseExecutor):
default=paddle.get_device(),
help="Choose device to execute model inference.")
self.parser.add_argument(
'-d',
'--job_dump_result',
type=ast.literal_eval,
default=False,
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -327,7 +331,9 @@ class STExecutor(BaseExecutor):
config = parser_args.config
ckpt_path = parser_args.ckpt_path
device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
@ -343,7 +349,7 @@ class STExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
parser_args.job_dump_result)
if has_exceptions:
return False

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import re
from collections import OrderedDict
@ -122,10 +121,15 @@ class TextExecutor(BaseExecutor):
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'-d',
'--job_dump_result',
type=ast.literal_eval,
default=False,
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -270,7 +274,9 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab
device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
@ -286,7 +292,7 @@ class TextExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
parser_args.job_dump_result)
if has_exceptions:
return False

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from collections import OrderedDict
from typing import Any
@ -400,10 +399,15 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument(
'-d',
'--job_dump_result',
type=ast.literal_eval,
default=False,
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -693,7 +697,9 @@ class TTSExecutor(BaseExecutor):
lang = args.lang
device = args.device
spk_id = args.spk_id
job_dump_result = args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(args.input)
task_results = OrderedDict()
@ -733,7 +739,8 @@ class TTSExecutor(BaseExecutor):
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result)
self.process_task_results(args.input, task_results,
args.job_dump_result)
if has_exceptions:
return False

Loading…
Cancel
Save