Update batch input.

pull/1460/head
KP 3 years ago
parent 1bb9a04aef
commit 7814fba07f

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import sys import sys
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True) prog='paddlespeech.asr', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to recognize.') '--input', type=str, default=None, help='Audio file to recognize.')
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
decode_method = parser_args.decode_method decode_method = parser_args.decode_method
force_yes = parser_args.yes force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try: try:
res = self(audio_file, model, lang, sample_rate, config, ckpt_path, res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device) decode_method, force_yes, device)
logger.info('ASR Result: {}'.format(res)) task_results[id_] = res
return True
except Exception as e: except Exception as e:
logger.exception(e) has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import subprocess import subprocess
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True) prog="paddlespeech.st", add_help=True)
self.parser.add_argument( self.parser.add_argument(
"--input", type=str, required=True, help="Audio file to translate.") "--input", type=str, default=None, help="Audio file to translate.")
self.parser.add_argument( self.parser.add_argument(
"--model", "--model",
type=str, type=str,
@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help="Choose device to execute model inference.") help="Choose device to execute model inference.")
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try: try:
res = self(audio_file, model, src_lang, tgt_lang, sample_rate, res = self(input_, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device) config, ckpt_path, device)
logger.info("ST Result: {}".format(res)) task_results[id_] = res
return True
except Exception as e: except Exception as e:
logger.exception(e) has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import re import re
from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True) prog='paddlespeech.text', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Input text.') '--input', type=str, default=None, help='Input text.')
self.parser.add_argument( self.parser.add_argument(
'--task', '--task',
type=str, type=str,
@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
type=str, type=str,
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
""" """
parser_args = self.parser.parse_args(argv) parser_args = self.parser.parse_args(argv)
text = parser_args.input
task = parser_args.task task = parser_args.task
model_type = parser_args.model model_type = parser_args.model
lang = parser_args.lang lang = parser_args.lang
@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab punc_vocab = parser_args.punc_vocab
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try: try:
res = self(text, task, model_type, lang, cfg_path, ckpt_path, res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device) punc_vocab, device)
logger.info('Text Result:\n{}'.format(res)) task_results[id_] = res
return True
except Exception as e: except Exception as e:
logger.exception(e) has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__( def __call__(

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True) prog='paddlespeech.tts', add_help=True)
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, required=True, help='Input text to generate.') '--input', type=str, default=None, help='Input text to generate.')
# acoustic model # acoustic model
self.parser.add_argument( self.parser.add_argument(
'--am', '--am',
@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name') '--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
text = args.input
am = args.am am = args.am
am_config = args.am_config am_config = args.am_config
am_ckpt = args.am_ckpt am_ckpt = args.am_ckpt
@ -686,12 +692,24 @@ class TTSExecutor(BaseExecutor):
voc_stat = args.voc_stat voc_stat = args.voc_stat
lang = args.lang lang = args.lang
device = args.device device = args.device
output = args.output
spk_id = args.spk_id spk_id = args.spk_id
job_dump_result = args.job_dump_result
task_source = self.get_task_source(args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
if len(task_source) > 1:
assert isinstance(args.output,
str) and args.output.endswith('.wav')
output = args.output.replace('.wav', f'_{id_}.wav')
else:
output = args.output
try: try:
res = self( res = self(
text=text, text=input_,
# acoustic model related # acoustic model related
am=am, am=am,
am_config=am_config, am_config=am_config,
@ -710,11 +728,17 @@ class TTSExecutor(BaseExecutor):
lang=lang, lang=lang,
device=device, device=device,
output=output) output=output)
logger.info('Wave file has been generated: {}'.format(res)) task_results[id_] = res
return True
except Exception as e: except Exception as e:
logger.exception(e) has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result)
if has_exceptions:
return False return False
else:
return True
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,

Loading…
Cancel
Save