Merge pull request #1598 from KPatr1ck/cli_rtf

[CLI]Add RTF wrapper for asr.
pull/1879/head
Hui Zhang 3 years ago committed by GitHub
commit b4387ab6bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,7 @@
import argparse import argparse
import os import os
import sys import sys
import time
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -29,8 +30,10 @@ from ..download import get_path_from_url
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register from ..utils import cli_register
from ..utils import CLI_TIMER
from ..utils import MODEL_HOME from ..utils import MODEL_HOME
from ..utils import stats_wrapper from ..utils import stats_wrapper
from ..utils import timer_register
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
@ -41,6 +44,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor'] __all__ = ['ASRExecutor']
@timer_register
@cli_register( @cli_register(
name='paddlespeech.asr', description='Speech to text infer command.') name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor): class ASRExecutor(BaseExecutor):
@ -99,6 +103,11 @@ class ASRExecutor(BaseExecutor):
default=False, default=False,
help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate' help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate'
) )
self.parser.add_argument(
'--rtf',
action="store_true",
default=False,
help='Show Real-time Factor(RTF).')
self.parser.add_argument( self.parser.add_argument(
'--device', '--device',
type=str, type=str,
@ -407,6 +416,7 @@ class ASRExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
decode_method = parser_args.decode_method decode_method = parser_args.decode_method
force_yes = parser_args.yes force_yes = parser_args.yes
rtf = parser_args.rtf
device = parser_args.device device = parser_args.device
if not parser_args.verbose: if not parser_args.verbose:
@ -419,12 +429,15 @@ class ASRExecutor(BaseExecutor):
for id_, input_ in task_source.items(): for id_, input_ in task_source.items():
try: try:
res = self(input_, model, lang, sample_rate, config, ckpt_path, res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, device) decode_method, force_yes, rtf, device)
task_results[id_] = res task_results[id_] = res
except Exception as e: except Exception as e:
has_exceptions = True has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
if rtf:
self.show_rtf(CLI_TIMER[self.__class__.__name__])
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
parser_args.job_dump_result) parser_args.job_dump_result)
@ -443,6 +456,7 @@ class ASRExecutor(BaseExecutor):
ckpt_path: os.PathLike=None, ckpt_path: os.PathLike=None,
decode_method: str='attention_rescoring', decode_method: str='attention_rescoring',
force_yes: bool=False, force_yes: bool=False,
rtf: bool=False,
device=paddle.get_device()): device=paddle.get_device()):
""" """
Python API to call an executor. Python API to call an executor.
@ -453,8 +467,18 @@ class ASRExecutor(BaseExecutor):
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, decode_method, self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path) ckpt_path)
if rtf:
k = self.__class__.__name__
CLI_TIMER[k]['start'].append(time.time())
self.preprocess(model, audio_file) self.preprocess(model, audio_file)
self.infer(model) self.infer(model)
res = self.postprocess() # Retrieve result of asr. res = self.postprocess() # Retrieve result of asr.
if rtf:
CLI_TIMER[k]['end'].append(time.time())
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate)
return res return res

@ -235,3 +235,19 @@ class BaseExecutor(ABC):
'Use pretrained model stored in: {}'.format(decompressed_path)) 'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path return decompressed_path
def show_rtf(self, info: Dict[str, List[float]]):
"""
Calculate rft of current task and show results.
"""
num_samples = 0
task_duration = 0.0
wav_duration = 0.0
for start, end, dur in zip(info['start'], info['end'], info['extra']):
num_samples += 1
task_duration += end - start
wav_duration += dur
logger.info('Sample Count: {}'.format(num_samples))
logger.info('Avg RTF: {}'.format(task_duration / wav_duration))

@ -24,11 +24,11 @@ from typing import Any
from typing import Dict from typing import Dict
import paddle import paddle
import paddleaudio
import requests import requests
import yaml import yaml
from paddle.framework import load from paddle.framework import load
import paddleaudio
from . import download from . import download
from .entry import commands from .entry import commands
try: try:
@ -39,6 +39,7 @@ except ImportError:
requests.adapters.DEFAULT_RETRIES = 3 requests.adapters.DEFAULT_RETRIES = 3
__all__ = [ __all__ = [
'timer_register',
'cli_register', 'cli_register',
'get_command', 'get_command',
'download_and_decompress', 'download_and_decompress',
@ -46,6 +47,13 @@ __all__ = [
'stats_wrapper', 'stats_wrapper',
] ]
CLI_TIMER = {}
def timer_register(command):
CLI_TIMER[command.__name__] = {'start': [], 'end': [], 'extra': []}
return command
def cli_register(name: str, description: str='') -> Any: def cli_register(name: str, description: str='') -> Any:
def _warpper(command): def _warpper(command):

Loading…
Cancel
Save