Merge pull request #1959 from KPatr1ck/cli_register

[CLI] Dynamic cli commands registration.
pull/1961/head
Hui Zhang 3 years ago committed by GitHub
commit e08fd3c555
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,14 +13,7 @@
# limitations under the License. # limitations under the License.
import _locale import _locale
from .asr import ASRExecutor
from .base_commands import BaseCommand from .base_commands import BaseCommand
from .base_commands import HelpCommand from .base_commands import HelpCommand
from .cls import CLSExecutor
from .st import STExecutor
from .stats import StatsExecutor
from .text import TextExecutor
from .tts import TTSExecutor
from .vector import VectorExecutor
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -29,7 +29,6 @@ from yacs.config import CfgNode
from ..download import get_path_from_url from ..download import get_path_from_url
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import CLI_TIMER from ..utils import CLI_TIMER
from ..utils import MODEL_HOME from ..utils import MODEL_HOME
from ..utils import stats_wrapper from ..utils import stats_wrapper
@ -45,8 +44,6 @@ __all__ = ['ASRExecutor']
@timer_register @timer_register
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor): class ASRExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

@ -15,6 +15,7 @@ from typing import List
from .entry import commands from .entry import commands
from .utils import cli_register from .utils import cli_register
from .utils import explicit_command_register
from .utils import get_command from .utils import get_command
__all__ = [ __all__ = [
@ -73,3 +74,20 @@ class VersionCommand:
print(msg) print(msg)
return True return True
# Dynamic import when running specific command
_commands = {
'asr': ['Speech to text infer command.', 'ASRExecutor'],
'cls': ['Audio classification infer command.', 'CLSExecutor'],
'st': ['Speech translation infer command.', 'STExecutor'],
'text': ['Text command.', 'TextExecutor'],
'tts': ['Text to Speech infer command.', 'TTSExecutor'],
'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'],
}
for com, info in _commands.items():
explicit_command_register(
name='paddlespeech.{}'.format(com),
description=info[0],
cls='paddlespeech.cli.{}.{}'.format(com, info[1]))

@ -27,7 +27,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
@ -36,8 +35,6 @@ from .pretrained_models import pretrained_models
__all__ = ['CLSExecutor'] __all__ = ['CLSExecutor']
@cli_register(
name='paddlespeech.cls', description='Audio classification infer command.')
class CLSExecutor(BaseExecutor): class CLSExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -246,4 +243,4 @@ class CLSExecutor(BaseExecutor):
self.infer() self.infer()
res = self.postprocess(topk) # Retrieve result of cls. res = self.postprocess(topk) # Retrieve result of cls.
return res return res

@ -34,6 +34,11 @@ def _execute():
# The method 'execute' of a command instance returns 'True' for a success # The method 'execute' of a command instance returns 'True' for a success
# while 'False' for a failure. Here converts this result into a exit status # while 'False' for a failure. Here converts this result into a exit status
# in bash: 0 for a success and 1 for a failure. # in bash: 0 for a success and 1 for a failure.
if not callable(com['_entry']):
i = com['_entry'].rindex('.')
module, cls = com['_entry'][:i], com['_entry'][i + 1:]
exec("from {} import {}".format(module, cls))
com['_entry'] = locals()[cls]
status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1
return status return status

@ -28,7 +28,6 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress from ..utils import download_and_decompress
from ..utils import MODEL_HOME from ..utils import MODEL_HOME
from ..utils import stats_wrapper from ..utils import stats_wrapper
@ -42,8 +41,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ["STExecutor"] __all__ = ["STExecutor"]
@cli_register(
name="paddlespeech.st", description="Speech translation infer command.")
class STExecutor(BaseExecutor): class STExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

@ -23,7 +23,6 @@ import paddle
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
@ -33,7 +32,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TextExecutor'] __all__ = ['TextExecutor']
@cli_register(name='paddlespeech.text', description='Text infer command.')
class TextExecutor(BaseExecutor): class TextExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

@ -28,7 +28,6 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
@ -40,8 +39,6 @@ from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSExecutor'] __all__ = ['TTSExecutor']
@cli_register(
name='paddlespeech.tts', description='Text to Speech infer command.')
class TTSExecutor(BaseExecutor): class TTSExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

@ -41,6 +41,7 @@ requests.adapters.DEFAULT_RETRIES = 3
__all__ = [ __all__ = [
'timer_register', 'timer_register',
'cli_register', 'cli_register',
'explicit_command_register',
'get_command', 'get_command',
'download_and_decompress', 'download_and_decompress',
'load_state_dict_from_url', 'load_state_dict_from_url',
@ -70,6 +71,16 @@ def cli_register(name: str, description: str='') -> Any:
return _warpper return _warpper
def explicit_command_register(name: str, description: str='', cls: str=''):
items = name.split('.')
com = commands
for item in items:
com = com[item]
com['_entry'] = cls
if description:
com['_description'] = description
def get_command(name: str) -> Any: def get_command(name: str) -> Any:
items = name.split('.') items = name.split('.')
com = commands com = commands

@ -28,7 +28,6 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
@ -37,9 +36,6 @@ from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification
@cli_register(
name="paddlespeech.vector",
description="Speech to vector embedding infer command.")
class VectorExecutor(BaseExecutor): class VectorExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -476,4 +472,4 @@ class VectorExecutor(BaseExecutor):
else: else:
logger.info("The audio file format is right") logger.info("The audio file format is right")
return True return True

Loading…
Cancel
Save