From e9798498d686e568d4d3488952f8cd2abec9a05f Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Mon, 29 Nov 2021 18:01:39 +0800 Subject: [PATCH] Update asr inference in paddlespeech.cli. --- paddlespeech/cli/executor.py | 9 +-- paddlespeech/cli/s2t/conf/default_conf.yaml | 0 paddlespeech/cli/s2t/infer.py | 67 +++++++++++++--- paddlespeech/cli/utils.py | 86 ++++++++++++++++++--- 4 files changed, 136 insertions(+), 26 deletions(-) delete mode 100644 paddlespeech/cli/s2t/conf/default_conf.yaml diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 45472fa4b..2261e011b 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -14,7 +14,6 @@ import os from abc import ABC from abc import abstractmethod -from typing import Optional from typing import Union import paddle @@ -30,16 +29,16 @@ class BaseExecutor(ABC): self.output = None @abstractmethod - def _get_default_cfg_path(self): + def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Returns a default config file path of current task. + Download and returns pretrained resources path of current task. """ pass @abstractmethod - def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None): + def _init_from_path(self, *args, **kwargs): """ - Init model from a specific config file. + Init model and other resources from a specific path. """ pass diff --git a/paddlespeech/cli/s2t/conf/default_conf.yaml b/paddlespeech/cli/s2t/conf/default_conf.yaml deleted file mode 100644 index e69de29bb..000000000 diff --git a/paddlespeech/cli/s2t/infer.py b/paddlespeech/cli/s2t/infer.py index 682279852..6aa29addf 100644 --- a/paddlespeech/cli/s2t/infer.py +++ b/paddlespeech/cli/s2t/infer.py @@ -21,9 +21,21 @@ import paddle from ..executor import BaseExecutor from ..utils import cli_register +from ..utils import download_and_decompress +from ..utils import logger +from ..utils import MODEL_HOME __all__ = ['S2TExecutor'] +pretrained_models = { + "wenetspeech_zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', + 'md5': + '54e7a558a6e020c2f5fb224874943f97', + } +} + @cli_register( name='paddlespeech.s2t', description='Speech to text infer command.') @@ -33,11 +45,23 @@ class S2TExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog='paddlespeech.s2t', add_help=True) + self.parser.add_argument( + '--model', + type=str, + default='wenetspeech', + help='Choose model type of asr task.') + self.parser.add_argument( + '--lang', type=str, default='zh', help='Choose model language.') self.parser.add_argument( '--config', type=str, default=None, help='Config of s2t task. Use deault config when it is None.') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') self.parser.add_argument( '--input', type=str, help='Audio file to recognize.') self.parser.add_argument( @@ -46,16 +70,39 @@ class S2TExecutor(BaseExecutor): default='cpu', help='Choose device to execute model inference.') - def _get_default_cfg_path(self): + def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Returns a default config file path of current task. + Download and returns pretrained resources path of current task. """ - pass + assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format( + tag) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path - def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None): + def _init_from_path(self, + model_type: str='wenetspeech', + lang: str='zh', + cfg_path: Optional[os.PathLike]=None, + ckpt_path: Optional[os.PathLike]=None): """ - Init model from a specific config file. + Init model and other resources from a specific path. """ + if cfg_path is None or ckpt_path is None: + res_path = self._get_pretrained_path( + model_type + '_' + lang) # wenetspeech_zh + cfg_path = os.path.join(res_path, 'conf/conformer.yaml') + ckpt_path = os.path.join( + res_path, 'exp/conformer/checkpoints/wenetspeech.pdparams') + logger.info(res_path) + logger.info(cfg_path) + logger.info(ckpt_path) + + # Init body. pass def preprocess(self, input: Union[str, os.PathLike]): @@ -82,17 +129,15 @@ class S2TExecutor(BaseExecutor): parser_args = self.parser.parse_args(argv) print(parser_args) + model = parser_args.model + lang = parser_args.lang config = parser_args.config + ckpt_path = parser_args.ckpt_path audio_file = parser_args.input device = parser_args.device - if config is not None: - assert os.path.isfile(config), 'Config file is not valid.' - else: - config = self._get_default_cfg_path() - try: - self._init_from_cfg(config) + self._init_from_path(model, lang, config, ckpt_path) self.preprocess(audio_file) self.infer() res = self.postprocess() # Retrieve result of s2t. diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index c83deee89..edf579f71 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging import os from typing import Any from typing import Dict -from typing import List from paddle.framework import load from paddle.utils import download @@ -26,6 +27,7 @@ __all__ = [ 'get_command', 'download_and_decompress', 'load_state_dict_from_url', + 'logger', ] @@ -53,29 +55,27 @@ def get_command(name: str) -> Any: return com['_entry'] -def decompress(file: str): +def decompress(file: str) -> os.PathLike: """ Extracts all files from a compressed file. """ assert os.path.isfile(file), "File: {} not exists.".format(file) - download._decompress(file) + return download._decompress(file) -def download_and_decompress(archives: List[Dict[str, str]], path: str): +def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: """ Download archieves and decompress to specific path. """ if not os.path.isdir(path): os.makedirs(path) - for archive in archives: - assert 'url' in archive and 'md5' in archive, \ - 'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}' + assert 'url' in archive and 'md5' in archive, \ + 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys())) + return download.get_path_from_url(archive['url'], path, archive['md5']) - download.get_path_from_url(archive['url'], path, archive['md5']) - -def load_state_dict_from_url(url: str, path: str, md5: str=None): +def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike: """ Download and load a state dict from url """ @@ -84,3 +84,69 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None): download.get_path_from_url(url, path, md5) return load(os.path.join(path, os.path.basename(url))) + + +def _get_user_home(): + return os.path.expanduser('~') + + +def _get_paddlespcceh_home(): + if 'PPSPEECH_HOME' in os.environ: + home_path = os.environ['PPSPEECH_HOME'] + if os.path.exists(home_path): + if os.path.isdir(home_path): + return home_path + else: + raise RuntimeError( + 'The environment variable PPSPEECH_HOME {} is not a directory.'. + format(home_path)) + else: + return home_path + return os.path.join(_get_user_home(), '.paddlespeech') + + +def _get_sub_home(directory): + home = os.path.join(_get_paddlespcceh_home(), directory) + if not os.path.exists(home): + os.makedirs(home) + return home + + +PPSPEECH_HOME = _get_paddlespcceh_home() +MODEL_HOME = _get_sub_home('models') + + +class Logger(object): + def __init__(self, name: str=None): + name = 'PaddleSpeech' if not name else name + self.logger = logging.getLogger(name) + + log_config = { + 'DEBUG': 10, + 'INFO': 20, + 'TRAIN': 21, + 'EVAL': 22, + 'WARNING': 30, + 'ERROR': 40, + 'CRITICAL': 50 + } + for key, level in log_config.items(): + logging.addLevelName(level, key) + self.__dict__[key.lower()] = functools.partial(self.__call__, level) + + self.format = logging.Formatter( + fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s' + ) + + self.handler = logging.StreamHandler() + self.handler.setFormatter(self.format) + + self.logger.addHandler(self.handler) + self.logger.setLevel(logging.DEBUG) + self.logger.propagate = False + + def __call__(self, log_level: str, msg: str): + self.logger.log(log_level, msg) + + +logger = Logger()