diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 0fb54868..9218cfa5 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -14,6 +14,7 @@ import argparse import os import sys +import time from collections import OrderedDict from typing import List from typing import Optional @@ -29,8 +30,10 @@ from ..download import get_path_from_url from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register +from ..utils import CLI_TIMER from ..utils import MODEL_HOME from ..utils import stats_wrapper +from ..utils import timer_register from .pretrained_models import model_alias from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer @@ -41,6 +44,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] +@timer_register @cli_register( name='paddlespeech.asr', description='Speech to text infer command.') class ASRExecutor(BaseExecutor): @@ -99,6 +103,11 @@ class ASRExecutor(BaseExecutor): default=False, help='No additional parameters required. 
Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate' ) + self.parser.add_argument( + '--rtf', + action="store_true", + default=False, + help='Show Real-time Factor(RTF).') self.parser.add_argument( '--device', type=str, @@ -407,6 +416,7 @@ class ASRExecutor(BaseExecutor): ckpt_path = parser_args.ckpt_path decode_method = parser_args.decode_method force_yes = parser_args.yes + rtf = parser_args.rtf device = parser_args.device if not parser_args.verbose: @@ -419,12 +429,15 @@ class ASRExecutor(BaseExecutor): for id_, input_ in task_source.items(): try: res = self(input_, model, lang, sample_rate, config, ckpt_path, - decode_method, force_yes, device) + decode_method, force_yes, rtf, device) task_results[id_] = res except Exception as e: has_exceptions = True task_results[id_] = f'{e.__class__.__name__}: {e}' + if rtf: + self.show_rtf(CLI_TIMER[self.__class__.__name__]) + self.process_task_results(parser_args.input, task_results, parser_args.job_dump_result) @@ -443,6 +456,7 @@ class ASRExecutor(BaseExecutor): ckpt_path: os.PathLike=None, decode_method: str='attention_rescoring', force_yes: bool=False, + rtf: bool=False, device=paddle.get_device()): """ Python API to call an executor. @@ -454,7 +468,18 @@ class ASRExecutor(BaseExecutor): self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path) self.preprocess(model, audio_file) - self.infer(model) + + if rtf: + k = self.__class__.__name__ + CLI_TIMER[k]['start'].append(time.time()) + self.infer(model) + CLI_TIMER[k]['end'].append(time.time()) + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate) + else: + self.infer(model) + res = self.postprocess() # Retrieve result of asr. 
return res diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index df0b6783..4a631c7f 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -235,3 +235,19 @@ class BaseExecutor(ABC): 'Use pretrained model stored in: {}'.format(decompressed_path)) return decompressed_path + + def show_rtf(self, info: Dict[str, List[float]]): + """ + Calculate RTF of current task and show results. + """ + num_samples = 0 + task_duration = 0.0 + wav_duration = 0.0 + + for start, end, dur in zip(info['start'], info['end'], info['extra']): + num_samples += 1 + task_duration += end - start + wav_duration += dur + + logger.info('Sample Count: {}'.format(num_samples)) + logger.info('Avg RTF: {}'.format(task_duration / wav_duration)) diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index 8e094894..82d40c8b 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -24,11 +24,11 @@ from typing import Any from typing import Dict import paddle -import paddleaudio import requests import yaml from paddle.framework import load +import paddleaudio from . import download from .entry import commands try: @@ -39,6 +39,7 @@ except ImportError: requests.adapters.DEFAULT_RETRIES = 3 __all__ = [ + 'timer_register', 'cli_register', 'get_command', 'download_and_decompress', @@ -46,6 +47,13 @@ __all__ = [ 'stats_wrapper', ] +CLI_TIMER = {} + + +def timer_register(command): + CLI_TIMER[command.__name__] = {'start': [], 'end': [], 'extra': []} + return command + def cli_register(name: str, description: str='') -> Any: def _warpper(command):