From 29da318379d33e5d0a049f0fa77765e9bf4e88b2 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Tue, 7 Dec 2021 15:14:28 +0800
Subject: [PATCH] Add audio classification cli.

---
 paddlespeech/cli/__init__.py     |   1 +
 paddlespeech/cli/asr/infer.py    |   4 +-
 paddlespeech/cli/cls/__init.__py |   0
 paddlespeech/cli/cls/__init__.py |  14 ++
 paddlespeech/cli/cls/infer.py    | 252 +++++++++++++++++++++++++++++++
 paddlespeech/cli/utils.py        |  16 +-
 requirements.txt                 |   1 +
 7 files changed, 283 insertions(+), 5 deletions(-)
 delete mode 100644 paddlespeech/cli/cls/__init.__py
 create mode 100644 paddlespeech/cli/cls/__init__.py
 create mode 100644 paddlespeech/cli/cls/infer.py

diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py
index 7e0329041..0aedab576 100644
--- a/paddlespeech/cli/__init__.py
+++ b/paddlespeech/cli/__init__.py
@@ -14,3 +14,4 @@
 from .asr import ASRExecutor
 from .base_commands import BaseCommand
 from .base_commands import HelpCommand
+from .cls import CLSExecutor
diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index b40516e95..3ff5be12e 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -361,7 +361,7 @@ class ASRExecutor(BaseExecutor):
             audio, audio_sample_rate = soundfile.read(
                 audio_file, dtype="int16", always_2d=True)
         except Exception as e:
-            logger.error(str(e))
+            logger.exception(e)
             logger.error(
                 "can not open the audio file, please check the audio file format is 'wav'. \n \
                  you can try to use sox to change the file format.\n \
@@ -421,7 +421,7 @@ class ASRExecutor(BaseExecutor):
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
-            print(e)
+            logger.exception(e)
             return False
 
     def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file,
diff --git a/paddlespeech/cli/cls/__init.__py b/paddlespeech/cli/cls/__init.__py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/paddlespeech/cli/cls/__init__.py b/paddlespeech/cli/cls/__init__.py
new file mode 100644
index 000000000..13e316f8f
--- /dev/null
+++ b/paddlespeech/cli/cls/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .infer import CLSExecutor
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
new file mode 100644
index 000000000..2b6260319
--- /dev/null
+++ b/paddlespeech/cli/cls/infer.py
@@ -0,0 +1,252 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import paddle
+import yaml
+
+from ..executor import BaseExecutor
+from ..utils import cli_register
+from ..utils import download_and_decompress
+from ..utils import logger
+from ..utils import MODEL_HOME
+from paddleaudio import load
+from paddleaudio.features import LogMelSpectrogram
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+
+__all__ = ['CLSExecutor']
+
+pretrained_models = {
+    "panns_cnn6": {
+        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
+        'md5': '051b30c56bcb9a3dd67bc205cc12ffd2',
+        'cfg_path': 'panns.yaml',
+        'ckpt_path': 'cnn6.pdparams',
+        'label_file': 'audioset_labels.txt',
+    },
+    "panns_cnn10": {
+        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
+        'md5': '97c6f25587685379b1ebcd4c1f624927',
+        'cfg_path': 'panns.yaml',
+        'ckpt_path': 'cnn10.pdparams',
+        'label_file': 'audioset_labels.txt',
+    },
+    "panns_cnn14": {
+        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
+        'md5': 'e3b9b5614a1595001161d0ab95edee97',
+        'cfg_path': 'panns.yaml',
+        'ckpt_path': 'cnn14.pdparams',
+        'label_file': 'audioset_labels.txt',
+    },
+}
+
+model_alias = {
+    "panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
+    "panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
+    "panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
+}
+
+
+@cli_register(
+    name='paddlespeech.cls', description='Audio classification infer command.')
+class CLSExecutor(BaseExecutor):
+    def __init__(self):
+        super(CLSExecutor, self).__init__()
+
+        self.parser = argparse.ArgumentParser(
+            prog='paddlespeech.cls', add_help=True)
+        self.parser.add_argument(
+            '--input', type=str, required=True, help='Audio file to classify.')
+        self.parser.add_argument(
+            '--model',
+            type=str,
+            default='panns_cnn14',
+            help='Choose model type of cls task.')
+        self.parser.add_argument(
+            '--config',
+            type=str,
+            default=None,
+            help='Config of cls task. Use default config when it is None.')
+        self.parser.add_argument(
+            '--ckpt_path',
+            type=str,
+            default=None,
+            help='Checkpoint file of model.')
+        self.parser.add_argument(
+            '--label_file',
+            type=str,
+            default=None,
+            help='Label file of cls task.')
+        self.parser.add_argument(
+            '--topk',
+            type=int,
+            default=1,
+            help='Return topk scores of classification result.')
+        self.parser.add_argument(
+            '--device',
+            type=str,
+            default=paddle.get_device(),
+            help='Choose device to execute model inference.')
+
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return pretrained resources path of current task.
+        """
+        assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
+            tag)
+
+        res_path = os.path.join(MODEL_HOME, tag)
+        decompressed_path = download_and_decompress(pretrained_models[tag],
+                                                    res_path)
+        decompressed_path = os.path.abspath(decompressed_path)
+        logger.info(
+            'Use pretrained model stored in: {}'.format(decompressed_path))
+
+        return decompressed_path
+
+    def _init_from_path(self,
+                        model_type: str='panns_cnn14',
+                        cfg_path: Optional[os.PathLike]=None,
+                        label_file: Optional[os.PathLike]=None,
+                        ckpt_path: Optional[os.PathLike]=None):
+        """
+        Init model and other resources from a specific path.
+ """ + if label_file is None or ckpt_path is None: + self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 + self.cfg_path = os.path.join( + self.res_path, pretrained_models[model_type]['cfg_path']) + self.label_file = os.path.join( + self.res_path, pretrained_models[model_type]['label_file']) + self.ckpt_path = os.path.join( + self.res_path, pretrained_models[model_type]['ckpt_path']) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.label_file = os.path.abspath(label_file) + self.ckpt_path = os.path.abspath(ckpt_path) + + # config + with open(self.cfg_path, 'r') as f: + self._conf = yaml.safe_load(f) + + # labels + self._label_list = [] + with open(self.label_file, 'r') as f: + for line in f: + self._label_list.append(line.strip()) + + # model + model_class = dynamic_import(model_type, model_alias) + model_dict = paddle.load(self.ckpt_path) + self._model = model_class(extract_embedding=False) + self._model.set_state_dict(model_dict) + self._model.eval() + + def preprocess(self, audio_file: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + feat_conf = self._conf['feature'] + logger.info(feat_conf) + waveform, _ = load( + file=audio_file, + sr=feat_conf['sample_rate'], + mono=True, + dtype='float32') + logger.info("Preprocessing audio_file:" + audio_file) + + # Feature extraction + # TODO: Feature args save into cfg file. + feature_extractor = LogMelSpectrogram( + sr=feat_conf['sample_rate'], + n_fft=feat_conf['n_fft'], + hop_length=feat_conf['hop_length'], + window=feat_conf['window'], + win_length=feat_conf['window_length'], + f_min=feat_conf['f_min'], + f_max=feat_conf['f_max'], + n_mels=feat_conf['n_mels'], ) + feats = feature_extractor( + paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0))) + self._inputs['feats'] = paddle.transpose(feats, [0, 2, 1]).unsqueeze( + 1) # [B, N, T] -> [B, 1, T, N] + + @paddle.no_grad() + def infer(self): + """ + Model inference and result stored in self.output. + """ + self._outputs['logits'] = self._model(self._inputs['feats']) + + def _generate_topk_label(self, result: np.ndarray, topk: int) -> str: + assert topk <= len( + self._label_list), 'Value of topk is larger than number of labels.' + + topk_idx = (-result).argsort()[:topk] + ret = '' + for idx in topk_idx: + label, score = self._label_list[idx], result[idx] + ret += f'{label}: {score}\n' + return ret + + def postprocess(self, topk: int) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + return self._generate_topk_label( + result=self._outputs['logits'].squeeze(0).numpy(), topk=topk) + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model_type = parser_args.model + label_file = parser_args.label_file + cfg_path = parser_args.config + ckpt_path = parser_args.ckpt_path + audio_file = parser_args.input + topk = parser_args.topk + device = parser_args.device + + try: + res = self(model_type, cfg_path, label_file, ckpt_path, audio_file, + topk, device) + logger.info('CLS Result:\n{}'.format(res)) + return True + except Exception as e: + logger.exception(e) + return False + + def __call__(self, model_type, cfg_path, label_file, ckpt_path, audio_file, + topk, device): + """ + Python API to call an executor. 
+ """ + audio_file = os.path.abspath(audio_file) + # self._check(audio_file, sample_rate) + paddle.set_device(device) + self._init_from_path(model_type, cfg_path, label_file, ckpt_path) + self.preprocess(audio_file) + self.infer() + res = self.postprocess(topk) # Retrieve result of cls. + + return res diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index edf579f71..ead3fb053 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -72,7 +72,12 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: assert 'url' in archive and 'md5' in archive, \ 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys())) - return download.get_path_from_url(archive['url'], path, archive['md5']) + + if False: + # TODO: File match md5 and uncompressed_path exist, so skip downloading and decompressing... + pass + else: + return download.get_path_from_url(archive['url'], path, archive['md5']) def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike: @@ -128,11 +133,16 @@ class Logger(object): 'EVAL': 22, 'WARNING': 30, 'ERROR': 40, - 'CRITICAL': 50 + 'CRITICAL': 50, + 'EXCEPTION': 100, } for key, level in log_config.items(): logging.addLevelName(level, key) - self.__dict__[key.lower()] = functools.partial(self.__call__, level) + if key == 'EXCEPTION': + self.__dict__[key.lower()] = self.logger.exception + else: + self.__dict__[key.lower()] = functools.partial(self.__call__, + level) self.format = logging.Formatter( fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s' diff --git a/requirements.txt b/requirements.txt index 658e64c05..3708bb062 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ loguru matplotlib nara_wpe nltk +paddleaudio paddlespeech_ctcdecoders paddlespeech_feat pandas