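Update batch input: in the ASR, ST, Text, and TTS executors, make `--input` optional (default `None`), add a `--job_dump_result` option, and have `execute` loop over every entry returned by `get_task_source`, recording a per-task result or error string and returning False if any task raised.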

pull/1460/head
KP 2 years ago
parent 1bb9a04aef
commit 7814fba07f

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import sys
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Audio file to recognize.')
'--input', type=str, default=None, help='Audio file to recognize.')
self.parser.add_argument(
'--model',
type=str,
@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
decode_method = parser_args.decode_method
force_yes = parser_args.yes
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
logger.info('ASR Result: {}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import subprocess
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -69,7 +71,7 @@ class STExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True)
self.parser.add_argument(
"--input", type=str, required=True, help="Audio file to translate.")
"--input", type=str, default=None, help="Audio file to translate.")
self.parser.add_argument(
"--model",
type=str,
@ -107,6 +109,11 @@ class STExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help="Choose device to execute model inference.")
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -319,17 +326,29 @@ class STExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(audio_file, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
logger.info("ST Result: {}".format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import re
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Input text.')
'--input', type=str, default=None, help='Input text.')
self.parser.add_argument(
'--task',
type=str,
@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor):
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor):
"""
parser_args = self.parser.parse_args(argv)
text = parser_args.input
task = parser_args.task
model_type = parser_args.model
lang = parser_args.lang
@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab
device = parser_args.device
job_dump_result = parser_args.job_dump_result
try:
res = self(text, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
logger.info('Text Result:\n{}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from collections import OrderedDict
from typing import Any
from typing import List
from typing import Optional
@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Input text to generate.')
'--input', type=str, default=None, help='Input text to generate.')
# acoustic model
self.parser.add_argument(
'--am',
@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument(
'--job_dump_result',
type=ast.literal_eval,
default=False,
help='Save job result into file.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor):
args = self.parser.parse_args(argv)
text = args.input
am = args.am
am_config = args.am_config
am_ckpt = args.am_ckpt
@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor):
voc_stat = args.voc_stat
lang = args.lang
device = args.device
output = args.output
spk_id = args.spk_id
job_dump_result = args.job_dump_result
try:
res = self(
text=text,
# acoustic model related
am=am,
am_config=am_config,
am_ckpt=am_ckpt,
am_stat=am_stat,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
spk_id=spk_id,
# vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
logger.info('Wave file has been generated: {}'.format(res))
return True
except Exception as e:
logger.exception(e)
task_source = self.get_task_source(args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
if len(task_source) > 1:
assert isinstance(args.output,
str) and args.output.endswith('.wav')
output = args.output.replace('.wav', f'_{id_}.wav')
else:
output = args.output
try:
res = self(
text=input_,
# acoustic model related
am=am,
am_config=am_config,
am_ckpt=am_ckpt,
am_stat=am_stat,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
spk_id=spk_id,
# vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,

Loading…
Cancel
Save