|
|
|
@ -12,8 +12,10 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import argparse
|
|
|
|
|
import ast
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from typing import List
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from typing import Union
|
|
|
|
@ -130,7 +132,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
|
prog='paddlespeech.asr', add_help=True)
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--input', type=str, required=True, help='Audio file to recognize.')
|
|
|
|
|
'--input', type=str, default=None, help='Audio file to recognize.')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--model',
|
|
|
|
|
type=str,
|
|
|
|
@ -180,6 +182,11 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
type=str,
|
|
|
|
|
default=paddle.get_device(),
|
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--job_dump_result',
|
|
|
|
|
type=ast.literal_eval,
|
|
|
|
|
default=False,
|
|
|
|
|
help='Save job result into file.')
|
|
|
|
|
|
|
|
|
|
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
|
|
|
|
"""
|
|
|
|
@ -469,19 +476,31 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
sample_rate = parser_args.sample_rate
|
|
|
|
|
config = parser_args.config
|
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
|
audio_file = parser_args.input
|
|
|
|
|
decode_method = parser_args.decode_method
|
|
|
|
|
force_yes = parser_args.yes
|
|
|
|
|
device = parser_args.device
|
|
|
|
|
job_dump_result = parser_args.job_dump_result
|
|
|
|
|
|
|
|
|
|
task_source = self.get_task_source(parser_args.input)
|
|
|
|
|
task_results = OrderedDict()
|
|
|
|
|
has_exceptions = False
|
|
|
|
|
|
|
|
|
|
for id_, input_ in task_source.items():
|
|
|
|
|
try:
|
|
|
|
|
res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
|
|
|
|
|
res = self(input_, model, lang, sample_rate, config, ckpt_path,
|
|
|
|
|
decode_method, force_yes, device)
|
|
|
|
|
logger.info('ASR Result: {}'.format(res))
|
|
|
|
|
return True
|
|
|
|
|
task_results[id_] = res
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.exception(e)
|
|
|
|
|
has_exceptions = True
|
|
|
|
|
task_results[id_] = f'{e.__class__.__name__}: {e}'
|
|
|
|
|
|
|
|
|
|
self.process_task_results(parser_args.input, task_results,
|
|
|
|
|
job_dump_result)
|
|
|
|
|
|
|
|
|
|
if has_exceptions:
|
|
|
|
|
return False
|
|
|
|
|
else:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
@stats_wrapper
|
|
|
|
|
def __call__(self,
|
|
|
|
|