From b230dfbdec77de7dcf4c00daddef75d39fbc6893 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Thu, 9 Jun 2022 16:44:15 +0800 Subject: [PATCH] Add kws cli and demo. --- demos/keyword_spotting/README.md | 79 ++++++++ demos/keyword_spotting/README_cn.md | 76 +++++++ demos/keyword_spotting/run.sh | 7 + paddlespeech/cli/base_commands.py | 3 +- paddlespeech/cli/kws/__init__.py | 14 ++ paddlespeech/cli/kws/infer.py | 219 +++++++++++++++++++++ paddlespeech/resource/model_alias.py | 6 + paddlespeech/resource/pretrained_models.py | 18 ++ paddlespeech/resource/resource.py | 3 +- 9 files changed, 422 insertions(+), 3 deletions(-) create mode 100644 demos/keyword_spotting/README.md create mode 100644 demos/keyword_spotting/README_cn.md create mode 100644 demos/keyword_spotting/run.sh create mode 100644 paddlespeech/cli/kws/__init__.py create mode 100644 paddlespeech/cli/kws/infer.py diff --git a/demos/keyword_spotting/README.md b/demos/keyword_spotting/README.md new file mode 100644 index 000000000..6544cf71e --- /dev/null +++ b/demos/keyword_spotting/README.md @@ -0,0 +1,79 @@ +([简体中文](./README_cn.md)|English) +# KWS (Keyword Spotting) + +## Introduction +KWS(Keyword Spotting) is a technique to recognize keyword from a giving speech audio. + +This demo is an implementation to recognize keyword from a specific audio file. It can be done by a single command or a few lines in python using `PaddleSpeech`. + +## Usage +### 1. Installation +see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md). + +You can choose one way from easy, meduim and hard to install paddlespeech. + +### 2. Prepare Input File +The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. + +Here are sample files for this demo that can be downloaded: +```bash +wget -c https://paddlespeech.bj.bcebos.com/kws/hey_snips.wav https://paddlespeech.bj.bcebos.com/kws/non-keyword.wav +``` + +### 3. Usage +- Command Line(Recommended) + ```bash + paddlespeech kws --input ./hey_snips.wav + paddlespeech kws --input ./non-keyword.wav + ``` + + Usage: + ```bash + paddlespeech kws --help + ``` + Arguments: + - `input`(required): Audio file to recognize. + - `threshold`:Score threshold for kws. Default: `0.8`. + - `model`: Model type of kws task. Default: `mdtc_heysnips`. + - `config`: Config of kws task. Use pretrained model when it is None. Default: `None`. + - `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`. + - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment. + - `verbose`: Show the log information. + + Output: + ```bash + # Input file: ./hey_snips.wav + Score: 1.000, Threshold: 0.8, Is keyword: True + # Input file: ./non-keyword.wav + Score: 0.000, Threshold: 0.8, Is keyword: False + ``` + +- Python API + ```python + import paddle + from paddlespeech.cli.kws import KWSExecutor + + kws_executor = KWSExecutor() + result = kws_executor( + audio_file='./hey_snips.wav', + threshold=0.8, + model='mdtc_heysnips', + config=None, + ckpt_path=None, + device=paddle.get_device()) + print('KWS Result: \n{}'.format(result)) + ``` + + Output: + ```bash + KWS Result: + Score: 1.000, Threshold: 0.8, Is keyword: True + ``` + +### 4.Pretrained Models + +Here is a list of pretrained models released by PaddleSpeech that can be used by command and python API: + +| Model | Language | Sample Rate +| :--- | :---: | :---: | +| mdtc_heysnips | en | 16k diff --git a/demos/keyword_spotting/README_cn.md b/demos/keyword_spotting/README_cn.md new file mode 100644 index 000000000..0d8f44a53 --- /dev/null +++ b/demos/keyword_spotting/README_cn.md @@ -0,0 +1,76 @@ +(简体中文|[English](./README.md)) + +# 关键词识别 +## 介绍 +关键词识别是一项用于识别一段语音内是否包含特定的关键词。 + +这个 demo 是一个从给定音频文件识别特定关键词的实现,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。 +## 使用方法 +### 1. 安装 +请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 + +你可以从 easy,medium,hard 三中方式中选择一种方式安装。 + +### 2. 准备输入 +这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 + +可以下载此 demo 的示例音频: +```bash +wget -c https://paddlespeech.bj.bcebos.com/kws/hey_snips.wav https://paddlespeech.bj.bcebos.com/kws/non-keyword.wav +``` +### 3. 使用方法 +- 命令行 (推荐使用) + ```bash + paddlespeech kws --input ./hey_snips.wav + paddlespeech kws --input ./non-keyword.wav + ``` + + 使用方法: + ```bash + paddlespeech kws --help + ``` + 参数: + - `input`(必须输入):用于识别关键词的音频文件。 + - `threshold`:用于判别是包含关键词的得分阈值,默认值:`0.8`。 + - `model`:KWS 任务的模型,默认值:`mdtc_heysnips`。 + - `config`:KWS 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 + - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。 + - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。 + - `verbose`: 如果使用,显示 logger 信息。 + + 输出: + ```bash + # 输入为 ./hey_snips.wav + Score: 1.000, Threshold: 0.8, Is keyword: True + # 输入为 ./non-keyword.wav + Score: 0.000, Threshold: 0.8, Is keyword: False + ``` + +- Python API + ```python + import paddle + from paddlespeech.cli.kws import KWSExecutor + + kws_executor = KWSExecutor() + result = kws_executor( + audio_file='./hey_snips.wav', + threshold=0.8, + model='mdtc_heysnips', + config=None, + ckpt_path=None, + device=paddle.get_device()) + print('KWS Result: \n{}'.format(result)) + ``` + + 输出: + ```bash + KWS Result: + Score: 1.000, Threshold: 0.8, Is keyword: True + ``` + +### 4.预训练模型 +以下是 PaddleSpeech 提供的可以被命令行和 python API 使用的预训练模型列表: + +| 模型 | 语言 | 采样率 +| :--- | :---: | :---: | +| mdtc_heysnips | en | 16k diff --git a/demos/keyword_spotting/run.sh b/demos/keyword_spotting/run.sh new file mode 100644 index 000000000..7f9e0ebba --- /dev/null +++ b/demos/keyword_spotting/run.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +wget -c https://paddlespeech.bj.bcebos.com/kws/hey_snips.wav https://paddlespeech.bj.bcebos.com/kws/non-keyword.wav + +# kws +paddlespeech kws --input ./hey_snips.wav +paddlespeech kws --input non-keyword.wav diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index f5e2246d8..f9e2a55f8 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -94,7 +94,7 @@ class StatsCommand: def __init__(self): self.parser = argparse.ArgumentParser( prog='paddlespeech.stats', add_help=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] + self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] self.parser.add_argument( '--task', type=str, @@ -138,6 +138,7 @@ _commands = { 'text': ['Text command.', 'TextExecutor'], 'tts': ['Text to Speech infer command.', 'TTSExecutor'], 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], + 'kws': ['Keyword Spotting infer command.', 'KWSExecutor'], } for com, info in _commands.items(): diff --git a/paddlespeech/cli/kws/__init__.py b/paddlespeech/cli/kws/__init__.py new file mode 100644 index 000000000..db7bd50eb --- /dev/null +++ b/paddlespeech/cli/kws/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import KWSExecutor diff --git a/paddlespeech/cli/kws/infer.py b/paddlespeech/cli/kws/infer.py new file mode 100644 index 000000000..e3f426f57 --- /dev/null +++ b/paddlespeech/cli/kws/infer.py @@ -0,0 +1,219 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from collections import OrderedDict +from typing import List +from typing import Optional +from typing import Union + +import paddle +import yaml + +from ..executor import BaseExecutor +from ..log import logger +from ..utils import stats_wrapper +from paddlespeech.audio import load +from paddlespeech.audio.compliance.kaldi import fbank as kaldi_fbank + +__all__ = ['KWSExecutor'] + + +class KWSExecutor(BaseExecutor): + def __init__(self): + super().__init__(task='kws') + self.parser = argparse.ArgumentParser( + prog='paddlespeech.kws', add_help=True) + self.parser.add_argument( + '--input', + type=str, + default=None, + help='Audio file to keyword spotting.') + self.parser.add_argument( + '--threshold', + type=float, + default=0.8, + help='Score threshold for keyword spotting.') + self.parser.add_argument( + '--model', + type=str, + default='mdtc_heysnips', + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], + help='Choose model type of kws task.') + self.parser.add_argument( + '--config', + type=str, + default=None, + help='Config of kws task. Use deault config when it is None.') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') + self.parser.add_argument( + '--device', + type=str, + default=paddle.get_device(), + help='Choose device to execute model inference.') + self.parser.add_argument( + '-d', + '--job_dump_result', + action='store_true', + help='Save job result into file.') + self.parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Increase logger verbosity of current task.') + + def _init_from_path(self, + model_type: str='mdtc_heysnips', + cfg_path: Optional[os.PathLike]=None, + ckpt_path: Optional[os.PathLike]=None): + """ + Init model and other resources from a specific path. + """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + + if ckpt_path is None: + tag = model_type + '-' + '16k' + self.task_resource.set_task_model(tag) + self.cfg_path = os.path.join( + self.task_resource.res_dir, + self.task_resource.res_dict['cfg_path']) + self.ckpt_path = os.path.join( + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path'] + '.pdparams') + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path) + + # config + with open(self.cfg_path, 'r') as f: + config = yaml.safe_load(f) + + # model + backbone_class = self.task_resource.get_model_class( + model_type.split('_')[0]) + model_class = self.task_resource.get_model_class( + model_type.split('_')[0] + '_for_kws') + backbone = backbone_class( + stack_num=config['stack_num'], + stack_size=config['stack_size'], + in_channels=config['in_channels'], + res_channels=config['res_channels'], + kernel_size=config['kernel_size'], + causal=True, ) + self.model = model_class( + backbone=backbone, num_keywords=config['num_keywords']) + model_dict = paddle.load(self.ckpt_path) + self.model.set_state_dict(model_dict) + self.model.eval() + + self.feature_extractor = lambda x: kaldi_fbank( + x, sr=config['sample_rate'], + frame_shift=config['frame_shift'], + frame_length=config['frame_length'], + n_mels=config['n_mels'] + ) + + def preprocess(self, audio_file: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + assert os.path.isfile(audio_file) + waveform, _ = load(audio_file) + if isinstance(audio_file, (str, os.PathLike)): + logger.info("Preprocessing audio_file:" + audio_file) + + # Feature extraction + waveform = paddle.to_tensor(waveform).unsqueeze(0) + self._inputs['feats'] = self.feature_extractor(waveform).unsqueeze(0) + + @paddle.no_grad() + def infer(self): + """ + Model inference and result stored in self.output. + """ + self._outputs['logits'] = self.model(self._inputs['feats']) + + def postprocess(self, threshold: float) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + kws_score = max(self._outputs['logits'][0, :, 0]).item() + return 'Score: {:.3f}, Threshold: {}, Is keyword: {}'.format( + kws_score, threshold, kws_score > threshold) + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model_type = parser_args.model + cfg_path = parser_args.config + ckpt_path = parser_args.ckpt_path + device = parser_args.device + threshold = parser_args.threshold + + if not parser_args.verbose: + self.disable_task_loggers() + + task_source = self.get_input_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self(input_, threshold, model_type, cfg_path, ckpt_path, + device) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + self.process_task_results(parser_args.input, task_results, + parser_args.job_dump_result) + + if has_exceptions: + return False + else: + return True + + @stats_wrapper + def __call__(self, + audio_file: os.PathLike, + threshold: float=0.8, + model: str='mdtc_heysnips', + config: Optional[os.PathLike]=None, + ckpt_path: Optional[os.PathLike]=None, + device: str=paddle.get_device()): + """ + Python API to call an executor. + """ + audio_file = os.path.abspath(os.path.expanduser(audio_file)) + paddle.set_device(device) + self._init_from_path(model, config, ckpt_path) + self.preprocess(audio_file) + self.infer() + res = self.postprocess(threshold) + + return res diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index 5309fd86f..9c76dd4b3 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -83,4 +83,10 @@ model_alias = { # ------------ Vector ------------- # --------------------------------- "ecapatdnn": ["paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn"], + + # --------------------------------- + # -------------- kws -------------- + # --------------------------------- + "mdtc": ["paddlespeech.kws.models.mdtc:MDTC"], + "mdtc_for_kws": ["paddlespeech.kws.models.mdtc:KWSModel"], } diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 44cd79e8c..439fda5fd 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -1014,3 +1014,21 @@ vector_dynamic_pretrained_models = { }, }, } + +# --------------------------------- +# ------------- KWS --------------- +# --------------------------------- +kws_dynamic_pretrained_models = { + 'mdtc_heysnips-16k': { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/kws/heysnips/kws0_mdtc_heysnips_ckpt.tar.gz', + 'md5': + 'c0de0a9520d66c3c8d6679460893578f', + 'cfg_path': + 'conf/mdtc.yaml', + 'ckpt_path': + 'ckpt/model', + }, + }, +} diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 45707eb44..70f12b64c 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -22,7 +22,7 @@ from ..utils.dynamic_import import dynamic_import from ..utils.env import MODEL_HOME from .model_alias import model_alias -task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] +task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] model_format_supported = ['dynamic', 'static', 'onnx'] inference_mode_supported = ['online', 'offline'] @@ -164,7 +164,6 @@ class CommonTaskResource: try: import_models = '{}_{}_pretrained_models'.format(self.task, self.model_format) - print(f"from .pretrained_models import {import_models}") exec('from .pretrained_models import {}'.format(import_models)) models = OrderedDict(locals()[import_models]) except Exception as e: