diff --git a/paddlespeech/cli/README.md b/paddlespeech/cli/README.md
index 56afb939..264d66f7 100644
--- a/paddlespeech/cli/README.md
+++ b/paddlespeech/cli/README.md
@@ -7,3 +7,6 @@
 
 ## ASR
 `paddlespeech asr --input ./test_audio.wav`
+
+## Multi-label Classification
+`paddlespeech cls --input ./test_audio.wav`
diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py
index 7e032904..0aedab57 100644
--- a/paddlespeech/cli/__init__.py
+++ b/paddlespeech/cli/__init__.py
@@ -14,3 +14,4 @@
 from .asr import ASRExecutor
 from .base_commands import BaseCommand
 from .base_commands import HelpCommand
+from .cls import CLSExecutor
diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index b40516e9..30e4bb9c 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor):
         """
         Init model and other resources from a specific path.
         """
+        if hasattr(self, 'model'):
+            logger.info('Model has already been initialized.')
+            return
+
         if cfg_path is None or ckpt_path is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
             tag = model_type + '_' + lang + '_' + sample_rate_str
@@ -361,7 +365,7 @@ class ASRExecutor(BaseExecutor):
             audio, audio_sample_rate = soundfile.read(
                 audio_file, dtype="int16", always_2d=True)
         except Exception as e:
-            logger.error(str(e))
+            logger.exception(e)
             logger.error(
                 "can not open the audio file, please check the audio file format is 'wav'. \n \
                  you can try to use sox to change the file format.\n \
@@ -421,7 +425,7 @@ class ASRExecutor(BaseExecutor):
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
-            print(e)
+            logger.exception(e)
             return False
 
     def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file,
diff --git a/paddlespeech/cli/cls/__init.__py b/paddlespeech/cli/cls/__init.__py
deleted file mode 100644
index e69de29b..00000000
diff --git a/paddlespeech/cli/cls/__init__.py b/paddlespeech/cli/cls/__init__.py
new file mode 100644
index 00000000..13e316f8
--- /dev/null
+++ b/paddlespeech/cli/cls/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .infer import CLSExecutor
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
new file mode 100644
index 00000000..c4206f7e
--- /dev/null
+++ b/paddlespeech/cli/cls/infer.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import paddle
+import yaml
+
+from ..executor import BaseExecutor
+from ..utils import cli_register
+from ..utils import download_and_decompress
+from ..utils import logger
+from ..utils import MODEL_HOME
+from paddleaudio import load
+from paddleaudio.features import LogMelSpectrogram
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+
+__all__ = ['CLSExecutor']
+
+pretrained_models = {
+    "panns_cnn6": {
+        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
+        'md5': '051b30c56bcb9a3dd67bc205cc12ffd2',
+        'cfg_path': 'panns.yaml',
+        'ckpt_path': 'cnn6.pdparams',
+        'label_file': 'audioset_labels.txt',
+    },
+    "panns_cnn10": {
+        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
+        'md5': '97c6f25587685379b1ebcd4c1f624927',
+        'cfg_path': 'panns.yaml',
+        'ckpt_path': 'cnn10.pdparams',
+        'label_file': 'audioset_labels.txt',
+    },
+    "panns_cnn14": {
+        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
+        'md5': 'e3b9b5614a1595001161d0ab95edee97',
+        'cfg_path': 'panns.yaml',
+        'ckpt_path': 'cnn14.pdparams',
+        'label_file': 'audioset_labels.txt',
+    },
+}
+
+model_alias = {
+    "panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
+    "panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
+    "panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
+}
+
+
+@cli_register(
+    name='paddlespeech.cls', description='Audio classification infer command.')
+class CLSExecutor(BaseExecutor):
+    def __init__(self):
+        super(CLSExecutor, self).__init__()
+
+        self.parser = argparse.ArgumentParser(
+            prog='paddlespeech.cls', add_help=True)
+        self.parser.add_argument(
+            '--input', type=str, required=True, help='Audio file to classify.')
+        self.parser.add_argument(
+            '--model',
+            type=str,
+            default='panns_cnn14',
+            help='Choose model type of cls task.')
+        self.parser.add_argument(
+            '--config',
+            type=str,
+            default=None,
+            help='Config of cls task. Use default config when it is None.')
+        self.parser.add_argument(
+            '--ckpt_path',
+            type=str,
+            default=None,
+            help='Checkpoint file of model.')
+        self.parser.add_argument(
+            '--label_file',
+            type=str,
+            default=None,
+            help='Label file of cls task.')
+        self.parser.add_argument(
+            '--topk',
+            type=int,
+            default=1,
+            help='Return topk scores of classification result.')
+        self.parser.add_argument(
+            '--device',
+            type=str,
+            default=paddle.get_device(),
+            help='Choose device to execute model inference.')
+
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return pretrained resources path of current task.
+        """
+        assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
+            tag)
+
+        res_path = os.path.join(MODEL_HOME, tag)
+        decompressed_path = download_and_decompress(pretrained_models[tag],
+                                                    res_path)
+        decompressed_path = os.path.abspath(decompressed_path)
+        logger.info(
+            'Use pretrained model stored in: {}'.format(decompressed_path))
+
+        return decompressed_path
+
+    def _init_from_path(self,
+                        model_type: str='panns_cnn14',
+                        cfg_path: Optional[os.PathLike]=None,
+                        label_file: Optional[os.PathLike]=None,
+                        ckpt_path: Optional[os.PathLike]=None):
+        """
+        Init model and other resources from a specific path.
+ """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + + if label_file is None or ckpt_path is None: + self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 + self.cfg_path = os.path.join( + self.res_path, pretrained_models[model_type]['cfg_path']) + self.label_file = os.path.join( + self.res_path, pretrained_models[model_type]['label_file']) + self.ckpt_path = os.path.join( + self.res_path, pretrained_models[model_type]['ckpt_path']) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.label_file = os.path.abspath(label_file) + self.ckpt_path = os.path.abspath(ckpt_path) + + # config + with open(self.cfg_path, 'r') as f: + self._conf = yaml.safe_load(f) + + # labels + self._label_list = [] + with open(self.label_file, 'r') as f: + for line in f: + self._label_list.append(line.strip()) + + # model + model_class = dynamic_import(model_type, model_alias) + model_dict = paddle.load(self.ckpt_path) + self.model = model_class(extract_embedding=False) + self.model.set_state_dict(model_dict) + self.model.eval() + + def preprocess(self, audio_file: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + feat_conf = self._conf['feature'] + logger.info(feat_conf) + waveform, _ = load( + file=audio_file, + sr=feat_conf['sample_rate'], + mono=True, + dtype='float32') + logger.info("Preprocessing audio_file:" + audio_file) + + # Feature extraction + feature_extractor = LogMelSpectrogram( + sr=feat_conf['sample_rate'], + n_fft=feat_conf['n_fft'], + hop_length=feat_conf['hop_length'], + window=feat_conf['window'], + win_length=feat_conf['window_length'], + f_min=feat_conf['f_min'], + f_max=feat_conf['f_max'], + n_mels=feat_conf['n_mels'], ) + feats = feature_extractor( + paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0))) + self._inputs['feats'] = paddle.transpose(feats, [0, 2, 1]).unsqueeze( + 1) # [B, N, T] -> [B, 1, T, N] + + @paddle.no_grad() + def infer(self): + """ + Model inference and result stored in self.output. + """ + self._outputs['logits'] = self.model(self._inputs['feats']) + + def _generate_topk_label(self, result: np.ndarray, topk: int) -> str: + assert topk <= len( + self._label_list), 'Value of topk is larger than number of labels.' + + topk_idx = (-result).argsort()[:topk] + ret = '' + for idx in topk_idx: + label, score = self._label_list[idx], result[idx] + ret += f'{label}: {score}\n' + return ret + + def postprocess(self, topk: int) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + return self._generate_topk_label( + result=self._outputs['logits'].squeeze(0).numpy(), topk=topk) + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model_type = parser_args.model + label_file = parser_args.label_file + cfg_path = parser_args.config + ckpt_path = parser_args.ckpt_path + audio_file = parser_args.input + topk = parser_args.topk + device = parser_args.device + + try: + res = self(model_type, cfg_path, label_file, ckpt_path, audio_file, + topk, device) + logger.info('CLS Result:\n{}'.format(res)) + return True + except Exception as e: + logger.exception(e) + return False + + def __call__(self, model_type, cfg_path, label_file, ckpt_path, audio_file, + topk, device): + """ + Python API to call an executor. 
+ """ + audio_file = os.path.abspath(audio_file) + # self._check(audio_file, sample_rate) + paddle.set_device(device) + self._init_from_path(model_type, cfg_path, label_file, ckpt_path) + self.preprocess(audio_file) + self.infer() + res = self.postprocess(topk) # Retrieve result of cls. + + return res diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index edf579f7..eb023c11 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import hashlib import logging import os +import tarfile +import zipfile from typing import Any from typing import Dict +from typing import List from paddle.framework import load from paddle.utils import download @@ -55,12 +59,69 @@ def get_command(name: str) -> Any: return com['_entry'] -def decompress(file: str) -> os.PathLike: - """ - Extracts all files from a compressed file. - """ - assert os.path.isfile(file), "File: {} not exists.".format(file) - return download._decompress(file) +def _md5check(filepath: os.PathLike, md5sum: str) -> bool: + logger.info("File {} md5 checking...".format(filepath)) + md5 = hashlib.md5() + with open(filepath, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + logger.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(filepath, calc_md5sum, md5sum)) + return False + else: + logger.info("File {} md5 check passed.".format(filepath)) + return True + + +def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: + file_dir = os.path.dirname(filepath) + + if tarfile.is_tarfile(filepath): + files = tarfile.open(filepath, "r:*") + file_list = files.getnames() + elif zipfile.is_zipfile(filepath): + files = zipfile.ZipFile(filepath, 'r') + file_list = files.namelist() + else: + return file_dir + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + + files.close() + return uncompressed_path + + +def _is_a_single_file(file_list: List[os.PathLike]) -> bool: + if len(file_list) == 1 and file_list[0].find(os.sep) < -1: + return True + return False + + +def _is_a_single_dir(file_list: List[os.PathLike]) -> bool: + new_file_list = [] + for file_path in file_list: + if '/' in file_path: + file_path = file_path.replace('/', os.sep) + elif '\\' in file_path: + file_path = file_path.replace('\\', os.sep) + new_file_list.append(file_path) + + file_name = new_file_list[0].split(os.sep)[0] + for i in range(1, len(new_file_list)): + if file_name != new_file_list[i].split(os.sep)[0]: + return False + return True def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: @@ -72,7 +133,17 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: assert 'url' in archive and 'md5' in archive, \ 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys())) - return download.get_path_from_url(archive['url'], path, archive['md5']) + + filepath = os.path.join(path, os.path.basename(archive['url'])) + if os.path.isfile(filepath) and _md5check(filepath, 
+        uncompress_path = _get_uncompress_path(filepath)
+        if not os.path.isdir(uncompress_path):
+            download._decompress(filepath)
+    else:
+        uncompress_path = download.get_path_from_url(archive['url'], path,
+                                                     archive['md5'])
+
+    return uncompress_path
 
 
 def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
@@ -128,11 +199,16 @@ class Logger(object):
             'EVAL': 22,
             'WARNING': 30,
             'ERROR': 40,
-            'CRITICAL': 50
+            'CRITICAL': 50,
+            'EXCEPTION': 100,
         }
         for key, level in log_config.items():
             logging.addLevelName(level, key)
-            self.__dict__[key.lower()] = functools.partial(self.__call__, level)
+            if key == 'EXCEPTION':
+                self.__dict__[key.lower()] = self.logger.exception
+            else:
+                self.__dict__[key.lower()] = functools.partial(self.__call__,
+                                                               level)
 
         self.format = logging.Formatter(
             fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s'
diff --git a/requirements.txt b/requirements.txt
index 658e64c0..3708bb06 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,6 +14,7 @@ loguru
 matplotlib
 nara_wpe
 nltk
+paddleaudio
 paddlespeech_ctcdecoders
 paddlespeech_feat
 pandas
diff --git a/setup.py b/setup.py
index e25c43a1..7720ba3f 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@ requirements = {
         "nara_wpe",
         "nltk",
        "pandas",
+        "paddleaudio",
         "paddlespeech_ctcdecoders",
         "paddlespeech_feat",
         "praatio~=4.1",
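
Besides the `paddlespeech cls` command shown in the README, the diff also exposes a Python API via `CLSExecutor.__call__`. A minimal usage sketch, assuming the default pretrained panns_cnn14 resources can be downloaded and using a placeholder audio path:

    import paddle
    from paddlespeech.cli import CLSExecutor

    cls_executor = CLSExecutor()
    # Positional arguments follow the new __call__ signature:
    # (model_type, cfg_path, label_file, ckpt_path, audio_file, topk, device).
    # Passing None for cfg_path/label_file/ckpt_path lets _init_from_path fetch
    # the bundled panns.yaml, audioset_labels.txt and cnn14.pdparams automatically.
    result = cls_executor('panns_cnn14', None, None, None, './test_audio.wav', 3,
                          paddle.get_device())
    print(result)  # topk "label: score" lines produced by postprocess()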