diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index ef769fbc..72edd9d1 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import ast import os import sys +from collections import OrderedDict from typing import List from typing import Optional from typing import Union @@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog='paddlespeech.asr', add_help=True) self.parser.add_argument( - '--input', type=str, required=True, help='Audio file to recognize.') + '--input', type=str, default=None, help='Audio file to recognize.') self.parser.add_argument( '--model', type=str, @@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor): type=str, default=paddle.get_device(), help='Choose device to execute model inference.') + self.parser.add_argument( + '--job_dump_result', + type=ast.literal_eval, + default=False, + help='Save job result into file.') def _get_pretrained_path(self, tag: str) -> os.PathLike: """ @@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor): sample_rate = parser_args.sample_rate config = parser_args.config ckpt_path = parser_args.ckpt_path - audio_file = parser_args.input decode_method = parser_args.decode_method force_yes = parser_args.yes device = parser_args.device + job_dump_result = parser_args.job_dump_result - try: - res = self(audio_file, model, lang, sample_rate, config, ckpt_path, - decode_method, force_yes, device) - logger.info('ASR Result: {}'.format(res)) - return True - except Exception as e: - logger.exception(e) + task_source = self.get_task_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self(input_, model, lang, sample_rate, config, ckpt_path, + decode_method, force_yes, device) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + self.process_task_results(parser_args.input, task_results, + job_dump_result) + + if has_exceptions: return False + else: + return True @stats_wrapper def __call__(self, diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index 1709c754..a11509ea 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import ast import os import subprocess +from collections import OrderedDict from typing import List from typing import Optional from typing import Union @@ -69,7 +71,7 @@ class STExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog="paddlespeech.st", add_help=True) self.parser.add_argument( - "--input", type=str, required=True, help="Audio file to translate.") + "--input", type=str, default=None, help="Audio file to translate.") self.parser.add_argument( "--model", type=str, @@ -107,6 +109,11 @@ class STExecutor(BaseExecutor): type=str, default=paddle.get_device(), help="Choose device to execute model inference.") + self.parser.add_argument( + '--job_dump_result', + type=ast.literal_eval, + default=False, + help='Save job result into file.') def _get_pretrained_path(self, tag: str) -> os.PathLike: """ @@ -319,17 +326,29 @@ class STExecutor(BaseExecutor): sample_rate = parser_args.sample_rate config = parser_args.config ckpt_path = parser_args.ckpt_path - audio_file = parser_args.input device = parser_args.device + job_dump_result = parser_args.job_dump_result - try: - res = self(audio_file, model, src_lang, tgt_lang, sample_rate, - config, ckpt_path, device) - logger.info("ST Result: {}".format(res)) - return True - except Exception as e: - logger.exception(e) + task_source = self.get_task_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self(input_, model, src_lang, tgt_lang, sample_rate, + config, ckpt_path, device) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + self.process_task_results(parser_args.input, task_results, + job_dump_result) + + if has_exceptions: return False + else: + return True @stats_wrapper def __call__(self, diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py index b0977c88..cc902be2 100644 --- a/paddlespeech/cli/text/infer.py +++ b/paddlespeech/cli/text/infer.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import ast import os import re +from collections import OrderedDict from typing import List from typing import Optional from typing import Union @@ -80,7 +82,7 @@ class TextExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog='paddlespeech.text', add_help=True) self.parser.add_argument( - '--input', type=str, required=True, help='Input text.') + '--input', type=str, default=None, help='Input text.') self.parser.add_argument( '--task', type=str, @@ -119,6 +121,11 @@ class TextExecutor(BaseExecutor): type=str, default=paddle.get_device(), help='Choose device to execute model inference.') + self.parser.add_argument( + '--job_dump_result', + type=ast.literal_eval, + default=False, + help='Save job result into file.') def _get_pretrained_path(self, tag: str) -> os.PathLike: """ @@ -256,7 +263,6 @@ class TextExecutor(BaseExecutor): """ parser_args = self.parser.parse_args(argv) - text = parser_args.input task = parser_args.task model_type = parser_args.model lang = parser_args.lang @@ -264,15 +270,28 @@ class TextExecutor(BaseExecutor): ckpt_path = parser_args.ckpt_path punc_vocab = parser_args.punc_vocab device = parser_args.device + job_dump_result = parser_args.job_dump_result - try: - res = self(text, task, model_type, lang, cfg_path, ckpt_path, - punc_vocab, device) - logger.info('Text Result:\n{}'.format(res)) - return True - except Exception as e: - logger.exception(e) + task_source = self.get_task_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self(input_, task, model_type, lang, cfg_path, ckpt_path, + punc_vocab, device) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + self.process_task_results(parser_args.input, task_results, + job_dump_result) + + if has_exceptions: return False + else: + return True @stats_wrapper def __call__( diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index dfd6a42f..3f650c40 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import ast import os +from collections import OrderedDict from typing import Any from typing import List from typing import Optional @@ -298,7 +300,7 @@ class TTSExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog='paddlespeech.tts', add_help=True) self.parser.add_argument( - '--input', type=str, required=True, help='Input text to generate.') + '--input', type=str, default=None, help='Input text to generate.') # acoustic model self.parser.add_argument( '--am', @@ -397,6 +399,11 @@ class TTSExecutor(BaseExecutor): self.parser.add_argument( '--output', type=str, default='output.wav', help='output file name') + self.parser.add_argument( + '--job_dump_result', + type=ast.literal_eval, + default=False, + help='Save job result into file.') def _get_pretrained_path(self, tag: str) -> os.PathLike: """ @@ -671,7 +678,6 @@ class TTSExecutor(BaseExecutor): args = self.parser.parse_args(argv) - text = args.input am = args.am am_config = args.am_config am_ckpt = args.am_ckpt @@ -686,35 +692,53 @@ class TTSExecutor(BaseExecutor): voc_stat = args.voc_stat lang = args.lang device = args.device - output = args.output spk_id = args.spk_id + job_dump_result = args.job_dump_result - try: - res = self( - text=text, - # acoustic model related - am=am, - am_config=am_config, - am_ckpt=am_ckpt, - am_stat=am_stat, - phones_dict=phones_dict, - tones_dict=tones_dict, - speaker_dict=speaker_dict, - spk_id=spk_id, - # vocoder related - voc=voc, - voc_config=voc_config, - voc_ckpt=voc_ckpt, - voc_stat=voc_stat, - # other - lang=lang, - device=device, - output=output) - logger.info('Wave file has been generated: {}'.format(res)) - return True - except Exception as e: - logger.exception(e) + task_source = self.get_task_source(args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + if len(task_source) > 1: + assert isinstance(args.output, + str) and args.output.endswith('.wav') + output = args.output.replace('.wav', f'_{id_}.wav') + else: + output = args.output + + try: + res = self( + text=input_, + # acoustic model related + am=am, + am_config=am_config, + am_ckpt=am_ckpt, + am_stat=am_stat, + phones_dict=phones_dict, + tones_dict=tones_dict, + speaker_dict=speaker_dict, + spk_id=spk_id, + # vocoder related + voc=voc, + voc_config=voc_config, + voc_ckpt=voc_ckpt, + voc_stat=voc_stat, + # other + lang=lang, + device=device, + output=output) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + self.process_task_results(args.input, task_results, job_dump_result) + + if has_exceptions: return False + else: + return True @stats_wrapper def __call__(self,