diff --git a/audio/paddleaudio/datasets/__init__.py b/audio/paddleaudio/datasets/__init__.py index ebd4af98..f95fad30 100644 --- a/audio/paddleaudio/datasets/__init__.py +++ b/audio/paddleaudio/datasets/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .esc50 import ESC50 from .gtzan import GTZAN +from .hey_snips import HeySnips from .rirs_noises import OpenRIRNoise from .tess import TESS from .urban_sound import UrbanSound8K diff --git a/audio/paddleaudio/datasets/dataset.py b/audio/paddleaudio/datasets/dataset.py index 06e2df6d..488187a6 100644 --- a/audio/paddleaudio/datasets/dataset.py +++ b/audio/paddleaudio/datasets/dataset.py @@ -17,6 +17,8 @@ import numpy as np import paddle from ..backends import load as load_audio +from ..compliance.kaldi import fbank as kaldi_fbank +from ..compliance.kaldi import mfcc as kaldi_mfcc from ..compliance.librosa import melspectrogram from ..compliance.librosa import mfcc @@ -24,6 +26,8 @@ feat_funcs = { 'raw': None, 'melspectrogram': melspectrogram, 'mfcc': mfcc, + 'kaldi_fbank': kaldi_fbank, + 'kaldi_mfcc': kaldi_mfcc, } @@ -73,16 +77,24 @@ class AudioClassificationDataset(paddle.io.Dataset): feat_func = feat_funcs[self.feat_type] record = {} - record['feat'] = feat_func( - waveform, sample_rate, - **self.feat_config) if feat_func else waveform + if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']: + waveform = paddle.to_tensor(waveform).unsqueeze(0) # (C, T) + record['feat'] = feat_func( + waveform=waveform, sr=self.sample_rate, **self.feat_config) + else: + record['feat'] = feat_func( + waveform, sample_rate, + **self.feat_config) if feat_func else waveform record['label'] = label return record def __getitem__(self, idx): record = self._convert_to_record(idx) - return np.array(record['feat']).transpose(), np.array( - record['label'], dtype=np.int64) + if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']: + return self.keys[idx], record['feat'], record['label'] + else: + return 
np.array(record['feat']).transpose(), np.array( + record['label'], dtype=np.int64) def __len__(self): return len(self.files) diff --git a/audio/paddleaudio/datasets/hey_snips.py b/audio/paddleaudio/datasets/hey_snips.py new file mode 100644 index 00000000..7a67b843 --- /dev/null +++ b/audio/paddleaudio/datasets/hey_snips.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import json +import os +from typing import List +from typing import Tuple + +from .dataset import AudioClassificationDataset + +__all__ = ['HeySnips'] + + +class HeySnips(AudioClassificationDataset): + meta_info = collections.namedtuple('META_INFO', + ('key', 'label', 'duration', 'wav')) + + def __init__(self, + data_dir: os.PathLike, + mode: str='train', + feat_type: str='kaldi_fbank', + sample_rate: int=16000, + **kwargs): + self.data_dir = data_dir + files, labels = self._get_data(mode) + super(HeySnips, self).__init__( + files=files, + labels=labels, + feat_type=feat_type, + sample_rate=sample_rate, + **kwargs) + + def _get_meta_info(self, mode) -> List[collections.namedtuple]: + ret = [] + with open(os.path.join(self.data_dir, '{}.json'.format(mode)), + 'r') as f: + data = json.load(f) + for item in data: + sample = collections.OrderedDict() + if item['duration'] > 0: + sample['key'] = item['id'] + sample['label'] = 0 if item['is_hotword'] == 1 else -1 + sample['duration'] = item['duration'] + 
sample['wav'] = os.path.join(self.data_dir, + item['audio_file_path']) + ret.append(self.meta_info(*sample.values())) + return ret + + def _get_data(self, mode: str) -> Tuple[List[str], List[int]]: + meta_info = self._get_meta_info(mode) + + files = [] + labels = [] + self.keys = [] + self.durations = [] + for sample in meta_info: + key, target, duration, wav = sample + files.append(wav) + labels.append(int(target)) + self.keys.append(key) + self.durations.append(float(duration)) + + return files, labels diff --git a/examples/hey_snips/README.md b/examples/hey_snips/README.md new file mode 100644 index 00000000..ba263906 --- /dev/null +++ b/examples/hey_snips/README.md @@ -0,0 +1,8 @@ + +## Metrics + +We measure false reject rates (FRRs) at a fixed number of false alarms per hour: + +|Model|False Alarm| False Reject Rate| +|--|--|--| +|MDTC| 1| 0.003559 | diff --git a/examples/hey_snips/kws0/README.md b/examples/hey_snips/kws0/README.md new file mode 100644 index 00000000..be8d142b --- /dev/null +++ b/examples/hey_snips/kws0/README.md @@ -0,0 +1,22 @@ +# MDTC Keyword Spotting with HeySnips Dataset + +## Dataset + +Before running scripts, you **MUST** follow these instructions to download the dataset: https://github.com/sonos/keyword-spotting-research-datasets + +After you download and decompress the dataset archive, you should **REPLACE** the value of `data_dir` in `conf/*.yaml` to complete dataset config. + +## Get Started + +In this section, we will train the [MDTC](https://arxiv.org/pdf/2102.13552.pdf) model and evaluate it on the "Hey Snips" dataset. + +```sh +CUDA_VISIBLE_DEVICES=0,1 ./run.sh conf/mdtc.yaml +``` + +This script contains training and scoring steps. You can just set the `CUDA_VISIBLE_DEVICES` environment var to run on a single GPU or multiple GPUs. + +The vars `stage` and `stop_stage` in `./run.sh` control the running steps: +- stage 1: Training from scratch. +- stage 2: Evaluating model on test dataset and computing detection error tradeoff(DET) of all trigger thresholds. 
+- stage 3: Plotting the DET curve for visualization. diff --git a/examples/hey_snips/kws0/conf/mdtc.yaml b/examples/hey_snips/kws0/conf/mdtc.yaml new file mode 100644 index 00000000..3ce9f9d0 --- /dev/null +++ b/examples/hey_snips/kws0/conf/mdtc.yaml @@ -0,0 +1,39 @@ +data: + data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter' + dataset: 'paddleaudio.datasets:HeySnips' + +model: + num_keywords: 1 + backbone: 'paddlespeech.kws.models:MDTC' + config: + stack_num: 3 + stack_size: 4 + in_channels: 80 + res_channels: 32 + kernel_size: 5 + +feature: + feat_type: 'kaldi_fbank' + sample_rate: 16000 + frame_shift: 10 + frame_length: 25 + n_mels: 80 + +training: + epochs: 100 + num_workers: 16 + batch_size: 100 + checkpoint_dir: './checkpoint' + save_freq: 10 + log_freq: 10 + learning_rate: 0.001 + weight_decay: 0.00005 + grad_clip: 5.0 + +scoring: + batch_size: 100 + num_workers: 16 + checkpoint: './checkpoint/epoch_100/model.pdparams' + score_file: './scores.txt' + stats_file: './stats.0.txt' + img_file: './det.png' \ No newline at end of file diff --git a/examples/hey_snips/kws0/local/plot.sh b/examples/hey_snips/kws0/local/plot.sh new file mode 100755 index 00000000..5869e50b --- /dev/null +++ b/examples/hey_snips/kws0/local/plot.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips diff --git a/examples/hey_snips/kws0/local/score.sh b/examples/hey_snips/kws0/local/score.sh new file mode 100755 index 00000000..ed21d08c --- /dev/null +++ b/examples/hey_snips/kws0/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +python3 ${BIN_DIR}/score.py --cfg_path=$1 + +python3 ${BIN_DIR}/compute_det.py --cfg_path=$1 diff --git a/examples/hey_snips/kws0/local/train.sh b/examples/hey_snips/kws0/local/train.sh new file mode 100755 index 00000000..8d0181b8 --- /dev/null +++ b/examples/hey_snips/kws0/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +ngpu=$1 +cfg_path=$2 + +if [ ${ngpu} -gt 0 ]; then + python3 -m 
paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \ + --cfg_path ${cfg_path} +else + echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning." + python3 ${BIN_DIR}/train.py \ + --cfg_path ${cfg_path} +fi diff --git a/examples/hey_snips/kws0/path.sh b/examples/hey_snips/kws0/path.sh new file mode 100755 index 00000000..54a430d4 --- /dev/null +++ b/examples/hey_snips/kws0/path.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=mdtc +export BIN_DIR=${MAIN_ROOT}/paddlespeech/kws/exps/${MODEL} \ No newline at end of file diff --git a/examples/hey_snips/kws0/run.sh b/examples/hey_snips/kws0/run.sh new file mode 100755 index 00000000..2cc09a4f --- /dev/null +++ b/examples/hey_snips/kws0/run.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +source path.sh + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + +if [ $# != 1 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path" + exit -1 +fi + +stage=1 +stop_stage=3 + +cfg_path=$1 + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + ./local/train.sh ${ngpu} ${cfg_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + ./local/score.sh ${cfg_path} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + ./local/plot.sh ${cfg_path} || exit -1 +fi \ No newline at end of file diff --git a/paddlespeech/kws/__init__.py b/paddlespeech/kws/__init__.py new file mode 100644 index 00000000..9c6e278e --- /dev/null +++ b/paddlespeech/kws/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .models.mdtc import MDTC diff --git a/paddlespeech/kws/exps/mdtc/collate.py b/paddlespeech/kws/exps/mdtc/collate.py new file mode 100644 index 00000000..dcc81123 --- /dev/null +++ b/paddlespeech/kws/exps/mdtc/collate.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import paddle + + +def collate_features(batch): + # (key, feat, label) + collate_start = time.time() + keys = [] + feats = [] + labels = [] + lengths = [] + for sample in batch: + keys.append(sample[0]) + feats.append(sample[1]) + labels.append(sample[2]) + lengths.append(sample[1].shape[0]) + + max_length = max(lengths) + for i in range(len(feats)): + feats[i] = paddle.nn.functional.pad( + feats[i], [0, max_length - feats[i].shape[0], 0, 0], + data_format='NLC') + + return keys, paddle.stack(feats), paddle.to_tensor( + labels), paddle.to_tensor(lengths) diff --git a/paddlespeech/kws/exps/mdtc/compute_det.py b/paddlespeech/kws/exps/mdtc/compute_det.py new file mode 100644 index 00000000..817846b8 --- /dev/null +++ b/paddlespeech/kws/exps/mdtc/compute_det.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import argparse +import os + +import paddle +import yaml +from tqdm import tqdm + +from paddlespeech.s2t.utils.dynamic_import import dynamic_import + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +parser.add_argument('--keyword_index', type=int, default=0, help='keyword index') +parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score') +parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered') +args = parser.parse_args() +# yapf: enable + + +def load_label_and_score(keyword_index: int, + ds: paddle.io.Dataset, + score_file: os.PathLike): + score_table = {} # {utt_id: scores_over_frames} + with open(score_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + key = arr[0] + current_keyword = arr[1] + str_list = arr[2:] + if int(current_keyword) == keyword_index: + scores = list(map(float, str_list)) + if key not in score_table: + score_table.update({key: scores}) + keyword_table = {} # scores of keyword utt_id + filler_table = {} # scores of non-keyword utt_id + filler_duration = 0.0 + + for key, index, duration in zip(ds.keys, ds.labels, ds.durations): + assert key in score_table + if index == keyword_index: + keyword_table[key] = score_table[key] + else: + filler_table[key] = score_table[key] + filler_duration += duration + + return keyword_table, filler_table, filler_duration + + +if 
__name__ == '__main__': + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + + data_conf = config['data'] + feat_conf = config['feature'] + scoring_conf = config['scoring'] + + # Dataset + ds_class = dynamic_import(data_conf['dataset']) + test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) + + score_file = os.path.abspath(scoring_conf['score_file']) + stats_file = os.path.abspath(scoring_conf['stats_file']) + + keyword_table, filler_table, filler_duration = load_label_and_score( + args.keyword_index, test_ds, score_file) + print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) + pbar = tqdm(total=int(1.0 / args.step)) + with open(stats_file, 'w', encoding='utf8') as fout: + keyword_index = args.keyword_index + threshold = 0.0 + while threshold <= 1.0: + num_false_reject = 0 + # traverse the whole keyword_table + for key, score_list in keyword_table.items(): + # compute positive test sample, use the max score of list. 
+ score = max(score_list) + if float(score) < threshold: + num_false_reject += 1 + num_false_alarm = 0 + # transverse the all filler_table + for key, score_list in filler_table.items(): + i = 0 + while i < len(score_list): + if score_list[i] >= threshold: + num_false_alarm += 1 + i += args.window_shift + else: + i += 1 + if len(keyword_table) != 0: + false_reject_rate = num_false_reject / len(keyword_table) + num_false_alarm = max(num_false_alarm, 1e-6) + if filler_duration != 0: + false_alarm_per_hour = num_false_alarm / \ + (filler_duration / 3600.0) + fout.write('{:.6f} {:.6f} {:.6f}\n'.format( + threshold, false_alarm_per_hour, false_reject_rate)) + threshold += args.step + pbar.update(1) + + pbar.close() + print('DET saved to: {}'.format(stats_file)) diff --git a/paddlespeech/kws/exps/mdtc/plot_det_curve.py b/paddlespeech/kws/exps/mdtc/plot_det_curve.py new file mode 100644 index 00000000..ac920358 --- /dev/null +++ b/paddlespeech/kws/exps/mdtc/plot_det_curve.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from wekws(https://github.com/wenet-e2e/wekws) +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np +import yaml + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +parser.add_argument("--keyword", type=str, required=True) +args = parser.parse_args() +# yapf: enable + + +def load_stats_file(stats_file): + values = [] + with open(stats_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + threshold, fa_per_hour, frr = arr + values.append([float(fa_per_hour), float(frr) * 100]) + values.reverse() + return np.array(values) + + +def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim, + y_step): + plt.figure(dpi=200) + plt.rcParams['xtick.direction'] = 'in' + plt.rcParams['ytick.direction'] = 'in' + plt.rcParams['font.size'] = 12 + + for index, keyword in enumerate(keywords): + values = load_stats_file(stats_file) + plt.plot(values[:, 0], values[:, 1], label=keyword) + + plt.xlim([0, xlim]) + plt.ylim([0, ylim]) + plt.xticks(range(0, xlim + x_step, x_step)) + plt.yticks(range(0, ylim + y_step, y_step)) + plt.xlabel('False Alarm Per Hour') + plt.ylabel('False Rejection Rate (\\%)') + plt.grid(linestyle='--') + plt.legend(loc='best', fontsize=16) + plt.savefig(figure_file) + + +if __name__ == '__main__': + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + + scoring_conf = config['scoring'] + img_file = os.path.abspath(scoring_conf['img_file']) + stats_file = os.path.abspath(scoring_conf['stats_file']) + keywords = [args.keyword] + plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2) + + print('DET curve image saved to: {}'.format(img_file)) diff --git a/paddlespeech/kws/exps/mdtc/score.py b/paddlespeech/kws/exps/mdtc/score.py new file mode 100644 index 00000000..7fe88ea3 --- /dev/null +++ 
b/paddlespeech/kws/exps/mdtc/score.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import argparse +import os + +import paddle +import yaml +from tqdm import tqdm + +from paddlespeech.kws.exps.mdtc.collate import collate_features +from paddlespeech.kws.models.mdtc import KWSModel +from paddlespeech.s2t.utils.dynamic_import import dynamic_import + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +args = parser.parse_args() +# yapf: enable + +if __name__ == '__main__': + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + + model_conf = config['model'] + data_conf = config['data'] + feat_conf = config['feature'] + scoring_conf = config['scoring'] + + # Dataset + ds_class = dynamic_import(data_conf['dataset']) + test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) + test_sampler = paddle.io.BatchSampler( + test_ds, batch_size=scoring_conf['batch_size'], drop_last=False) + test_loader = paddle.io.DataLoader( + test_ds, + batch_sampler=test_sampler, + num_workers=scoring_conf['num_workers'], + return_list=True, + use_buffer_reader=True, + collate_fn=collate_features, ) + + # Model + backbone_class = dynamic_import(model_conf['backbone']) + backbone = 
backbone_class(**model_conf['config']) + model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) + model.set_state_dict(paddle.load(scoring_conf['checkpoint'])) + model.eval() + + with paddle.no_grad(), open( + scoring_conf['score_file'], 'w', encoding='utf8') as fout: + for batch_idx, batch in enumerate( + tqdm(test_loader, total=len(test_loader))): + keys, feats, labels, lengths = batch + logits = model(feats) + num_keywords = logits.shape[2] + for i in range(len(keys)): + key = keys[i] + score = logits[i][:lengths[i]] + for keyword_i in range(num_keywords): + keyword_scores = score[:, keyword_i] + score_frames = ' '.join( + ['{:.6f}'.format(x) for x in keyword_scores.tolist()]) + fout.write( + '{} {} {}\n'.format(key, keyword_i, score_frames)) + + print('Result saved to: {}'.format(scoring_conf['score_file'])) diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py new file mode 100644 index 00000000..99e72871 --- /dev/null +++ b/paddlespeech/kws/exps/mdtc/train.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import os + +import paddle +import yaml + +from paddleaudio.utils import logger +from paddleaudio.utils import Timer +from paddlespeech.kws.exps.mdtc.collate import collate_features +from paddlespeech.kws.models.loss import max_pooling_loss +from paddlespeech.kws.models.mdtc import KWSModel +from paddlespeech.s2t.utils.dynamic_import import dynamic_import + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +args = parser.parse_args() +# yapf: enable + +if __name__ == '__main__': + nranks = paddle.distributed.get_world_size() + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + local_rank = paddle.distributed.get_rank() + + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + + model_conf = config['model'] + data_conf = config['data'] + feat_conf = config['feature'] + training_conf = config['training'] + + # Dataset + ds_class = dynamic_import(data_conf['dataset']) + train_ds = ds_class( + data_dir=data_conf['data_dir'], mode='train', **feat_conf) + dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf) + + train_sampler = paddle.io.DistributedBatchSampler( + train_ds, + batch_size=training_conf['batch_size'], + shuffle=True, + drop_last=False) + train_loader = paddle.io.DataLoader( + train_ds, + batch_sampler=train_sampler, + num_workers=training_conf['num_workers'], + return_list=True, + use_buffer_reader=True, + collate_fn=collate_features, ) + + # Model + backbone_class = dynamic_import(model_conf['backbone']) + backbone = backbone_class(**model_conf['config']) + model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) + model = paddle.DataParallel(model) + clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip']) + optimizer = paddle.optimizer.Adam( + learning_rate=training_conf['learning_rate'], + 
weight_decay=training_conf['weight_decay'], + parameters=model.parameters(), + grad_clip=clip) + criterion = max_pooling_loss + + steps_per_epoch = len(train_sampler) + timer = Timer(steps_per_epoch * training_conf['epochs']) + timer.start() + + for epoch in range(1, training_conf['epochs'] + 1): + model.train() + + avg_loss = 0 + num_corrects = 0 + num_samples = 0 + for batch_idx, batch in enumerate(train_loader): + keys, feats, labels, lengths = batch + logits = model(feats) + loss, corrects, acc = criterion(logits, labels, lengths) + loss.backward() + optimizer.step() + if isinstance(optimizer._learning_rate, + paddle.optimizer.lr.LRScheduler): + optimizer._learning_rate.step() + optimizer.clear_grad() + + # Calculate loss + avg_loss += loss.numpy()[0] + + # Calculate metrics + num_corrects += corrects + num_samples += feats.shape[0] + + timer.count() + + if (batch_idx + 1 + ) % training_conf['log_freq'] == 0 and local_rank == 0: + lr = optimizer.get_lr() + avg_loss /= training_conf['log_freq'] + avg_acc = num_corrects / num_samples + + print_msg = 'Epoch={}/{}, Step={}/{}'.format( + epoch, training_conf['epochs'], batch_idx + 1, + steps_per_epoch) + print_msg += ' loss={:.4f}'.format(avg_loss) + print_msg += ' acc={:.4f}'.format(avg_acc) + print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format( + lr, timer.timing, timer.eta) + logger.train(print_msg) + + avg_loss = 0 + num_corrects = 0 + num_samples = 0 + + if epoch % training_conf[ + 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0: + dev_sampler = paddle.io.BatchSampler( + dev_ds, + batch_size=training_conf['batch_size'], + shuffle=False, + drop_last=False) + dev_loader = paddle.io.DataLoader( + dev_ds, + batch_sampler=dev_sampler, + num_workers=training_conf['num_workers'], + return_list=True, + use_buffer_reader=True, + collate_fn=collate_features, ) + + model.eval() + num_corrects = 0 + num_samples = 0 + with logger.processing('Evaluation on validation dataset'): + for 
batch_idx, batch in enumerate(dev_loader): + keys, feats, labels, lengths = batch + logits = model(feats) + loss, corrects, acc = criterion(logits, labels, lengths) + num_corrects += corrects + num_samples += feats.shape[0] + + eval_acc = num_corrects / num_samples + print_msg = '[Evaluation result]' + print_msg += ' dev_acc={:.4f}'.format(eval_acc) + + logger.eval(print_msg) + + # Save model + save_dir = os.path.join(training_conf['checkpoint_dir'], + 'epoch_{}'.format(epoch)) + logger.info('Saving model checkpoint to {}'.format(save_dir)) + paddle.save(model.state_dict(), + os.path.join(save_dir, 'model.pdparams')) + paddle.save(optimizer.state_dict(), + os.path.join(save_dir, 'model.pdopt')) diff --git a/paddlespeech/kws/models/__init__.py b/paddlespeech/kws/models/__init__.py new file mode 100644 index 00000000..125a0d7a --- /dev/null +++ b/paddlespeech/kws/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .mdtc import KWSModel +from .mdtc import MDTC diff --git a/paddlespeech/kws/models/loss.py b/paddlespeech/kws/models/loss.py new file mode 100644 index 00000000..64c9a32c --- /dev/null +++ b/paddlespeech/kws/models/loss.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import paddle + + +def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor: + batch_size = lengths.shape[0] + max_len = int(lengths.max().item()) + seq = paddle.arange(max_len, dtype=paddle.int64) + seq = seq.expand((batch_size, max_len)) + return seq >= lengths.unsqueeze(1) + + +def fill_mask_elements(condition: paddle.Tensor, value: float, + x: paddle.Tensor) -> paddle.Tensor: + assert condition.shape == x.shape + values = paddle.ones_like(x, dtype=x.dtype) * value + return paddle.where(condition, values, x) + + +def max_pooling_loss(logits: paddle.Tensor, + target: paddle.Tensor, + lengths: paddle.Tensor, + min_duration: int=0): + + mask = padding_mask(lengths) + num_utts = logits.shape[0] + num_keywords = logits.shape[2] + + loss = 0.0 + for i in range(num_utts): + for j in range(num_keywords): + # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p)) + if target[i] == j: + # For the keyword, do max-polling + prob = logits[i, :, j] + m = mask[i] + if min_duration > 0: + m[:min_duration] = True + prob = fill_mask_elements(m, 0.0, prob) + prob = paddle.clip(prob, 1e-8, 1.0) + max_prob = prob.max() + loss += -paddle.log(max_prob) + else: + # For other keywords or filler, do min-polling + prob = 1 - logits[i, :, j] + prob = fill_mask_elements(mask[i], 1.0, prob) + prob = paddle.clip(prob, 1e-8, 1.0) + min_prob = prob.min() + loss += -paddle.log(min_prob) + loss = loss / num_utts + + # Compute accuracy of current batch + mask = mask.unsqueeze(-1) + logits = fill_mask_elements(mask, 0.0, 
logits) + max_logits = logits.max(1) + num_correct = 0 + for i in range(num_utts): + max_p = max_logits[i].max(0).item() + idx = max_logits[i].argmax(0).item() + # Predict correct as the i'th keyword + if max_p > 0.5 and idx == target[i].item(): + num_correct += 1 + # Predict correct as the filler, filler id < 0 + if max_p < 0.5 and target[i].item() < 0: + num_correct += 1 + acc = num_correct / num_utts + # acc = 0.0 + return loss, num_correct, acc diff --git a/paddlespeech/kws/models/mdtc.py b/paddlespeech/kws/models/mdtc.py new file mode 100644 index 00000000..5d2e5de6 --- /dev/null +++ b/paddlespeech/kws/models/mdtc.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class DSDilatedConv1d(nn.Layer):
    """Depthwise-separable dilated 1-D convolution.

    A depthwise dilated convolution (``groups=in_channels``, no padding) is
    followed by batch normalisation and a 1x1 pointwise convolution that maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of channels produced by the pointwise conv.
        kernel_size: Kernel size of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        bias: Passed as ``bias_attr`` to both convolutions.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            dilation: int=1,
            stride: int=1,
            bias: bool=True, ):
        super(DSDilatedConv1d, self).__init__()
        # Number of frames consumed by the unpadded dilated kernel.
        self.receptive_fields = dilation * (kernel_size - 1)
        self.conv = nn.Conv1D(
            in_channels,
            in_channels,
            kernel_size,
            padding=0,
            dilation=dilation,
            stride=stride,
            groups=in_channels,
            bias_attr=bias, )
        self.bn = nn.BatchNorm1D(in_channels)
        self.pointwise = nn.Conv1D(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias_attr=bias)

    def forward(self, inputs: paddle.Tensor):
        """Apply depthwise conv -> batch norm -> pointwise conv."""
        outputs = self.conv(inputs)
        outputs = self.bn(outputs)
        outputs = self.pointwise(outputs)
        return outputs


class TCNBlock(nn.Layer):
    """One temporal-convolution block with an optional residual connection.

    The block runs ``DSDilatedConv1d -> BN -> ReLU -> 1x1 Conv -> BN`` and,
    when ``in_channels == res_channels``, adds the (time-cropped) input before
    the final ReLU.

    Args:
        in_channels: Number of input channels.
        res_channels: Number of output (residual) channels.
        kernel_size: Kernel size of the dilated convolution.
        dilation: Dilation of the dilated convolution.
        causal: If True the input is cropped on the left only; otherwise it is
            cropped symmetrically on both sides.
    """

    def __init__(
            self,
            in_channels: int,
            res_channels: int,
            kernel_size: int,
            dilation: int,
            causal: bool, ):
        super(TCNBlock, self).__init__()
        self.in_channels = in_channels
        self.res_channels = res_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.causal = causal
        self.receptive_fields = dilation * (kernel_size - 1)
        self.half_receptive_fields = self.receptive_fields // 2
        self.conv1 = DSDilatedConv1d(
            in_channels=in_channels,
            out_channels=res_channels,
            kernel_size=kernel_size,
            dilation=dilation, )
        self.bn1 = nn.BatchNorm1D(res_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv1D(
            in_channels=res_channels, out_channels=res_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm1D(res_channels)
        self.relu2 = nn.ReLU()

    def forward(self, inputs: paddle.Tensor):
        """Run the block on ``inputs`` (channels on axis 1, time on axis 2)."""
        outputs = self.relu1(self.bn1(self.conv1(inputs)))
        outputs = self.bn2(self.conv2(outputs))
        # Crop the input along time so it lines up with the shortened output.
        if self.causal:
            inputs = inputs[:, :, self.receptive_fields:]
        else:
            # An explicit end index is used because ``h:-h`` would yield an
            # empty tensor when the half receptive field is 0 (kernel_size=1).
            end = inputs.shape[-1] - self.half_receptive_fields
            inputs = inputs[:, :, self.half_receptive_fields:end]
        if self.in_channels == self.res_channels:
            res_out = self.relu2(outputs + inputs)
        else:
            # Channel counts differ, so no residual can be added.
            res_out = self.relu2(outputs)
        return res_out


class TCNStack(nn.Layer):
    """A sequence of ``TCNBlock`` layers with exponentially growing dilation.

    The dilation pattern ``1, 2, ..., 2**(stack_num - 1)`` is repeated
    ``stack_size`` times. The first block maps ``in_channels`` to
    ``res_channels``; all later blocks keep ``res_channels``.

    Args:
        in_channels: Number of input channels of the first block.
        stack_num: Number of distinct dilations in one repetition.
        stack_size: Number of times the dilation pattern is repeated.
        res_channels: Number of channels of every block's output.
        kernel_size: Kernel size used by every block.
        causal: Forwarded to every ``TCNBlock``.
    """

    def __init__(
            self,
            in_channels: int,
            stack_num: int,
            stack_size: int,
            res_channels: int,
            kernel_size: int,
            causal: bool, ):
        super(TCNStack, self).__init__()
        self.in_channels = in_channels
        self.stack_num = stack_num
        self.stack_size = stack_size
        self.res_channels = res_channels
        self.kernel_size = kernel_size
        self.causal = causal
        self.res_blocks = self.stack_tcn_blocks()
        # Must be computed while ``res_blocks`` is still the iterable list of
        # blocks, i.e. before it is wrapped in ``nn.Sequential`` below.
        self.receptive_fields = self.calculate_receptive_fields()
        self.res_blocks = nn.Sequential(*self.res_blocks)

    def calculate_receptive_fields(self):
        """Return the sum of the receptive fields of all blocks."""
        receptive_fields = 0
        for block in self.res_blocks:
            receptive_fields += block.receptive_fields
        return receptive_fields

    def build_dilations(self):
        """Return the list of dilations, one entry per block."""
        return [
            2**exponent
            for _ in range(self.stack_size)
            for exponent in range(self.stack_num)
        ]

    def stack_tcn_blocks(self):
        """Create the blocks as an ``nn.LayerList`` following the dilations."""
        dilations = self.build_dilations()
        res_blocks = nn.LayerList()

        # Only the first block sees ``in_channels``.
        res_blocks.append(
            TCNBlock(
                self.in_channels,
                self.res_channels,
                self.kernel_size,
                dilations[0],
                self.causal, ))
        for dilation in dilations[1:]:
            res_blocks.append(
                TCNBlock(
                    self.res_channels,
                    self.res_channels,
                    self.kernel_size,
                    dilation,
                    self.causal, ))
        return res_blocks

    def forward(self, inputs: paddle.Tensor):
        """Run all blocks sequentially."""
        outputs = self.res_blocks(inputs)
        return outputs


class MDTC(nn.Layer):
    """Backbone made of a preprocessing ``TCNBlock`` and several ``TCNStack``s.

    The outputs of all stacks are cropped to the length of the last stack's
    output and summed.

    Args:
        stack_num: Number of ``TCNStack`` modules.
        stack_size: Number of blocks inside each stack (dilations
            ``1 .. 2**(stack_size - 1)``).
        in_channels: Feature dimension of the input.
        res_channels: Channel count used inside the network; exposed as
            ``hidden_dim``.
        kernel_size: Kernel size of every block; must be odd.
        causal: If True all padding is placed before the first frame;
            otherwise it is split evenly between both ends.
    """

    def __init__(
            self,
            stack_num: int,
            stack_size: int,
            in_channels: int,
            res_channels: int,
            kernel_size: int,
            causal: bool=True, ):
        super(MDTC, self).__init__()
        assert kernel_size % 2 == 1
        self.kernel_size = kernel_size
        self.causal = causal
        self.preprocessor = TCNBlock(
            in_channels, res_channels, kernel_size, dilation=1, causal=causal)
        self.relu = nn.ReLU()
        self.blocks = nn.LayerList()
        self.receptive_fields = self.preprocessor.receptive_fields
        for _ in range(stack_num):
            # NOTE: this class' ``stack_size`` is TCNStack's ``stack_num``
            # (number of dilations); the pattern is not repeated.
            self.blocks.append(
                TCNStack(
                    in_channels=res_channels,
                    stack_num=stack_size,
                    stack_size=1,
                    res_channels=res_channels,
                    kernel_size=kernel_size,
                    causal=causal))
            self.receptive_fields += self.blocks[-1].receptive_fields
        self.half_receptive_fields = self.receptive_fields // 2
        self.hidden_dim = res_channels

    def forward(self, x: paddle.Tensor):
        """Compute backbone features.

        Args:
            x: Input features; axis 1 is padded as time and the tensor is
                transposed with ``[0, 2, 1]`` before the convolutions, so it
                is expected as (batch, time, in_channels).

        Returns:
            Tuple ``(outputs, None)`` where ``outputs`` is transposed back to
            (batch, time, res_channels). The second element is always None.
        """
        # Pad the time axis by the total receptive field so the unpadded
        # convolutions do not run out of frames.
        if self.causal:
            outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
                            'constant')
        else:
            outputs = F.pad(
                x,
                (0, 0, self.half_receptive_fields, self.half_receptive_fields,
                 0, 0),
                'constant', )
        outputs = outputs.transpose([0, 2, 1])
        outputs_list = []
        outputs = self.relu(self.preprocessor(outputs))
        for block in self.blocks:
            outputs = block(outputs)
            outputs_list.append(outputs)

        # Crop every stack output to the (shortest) length of the last one.
        normalized_outputs = []
        output_size = outputs_list[-1].shape[-1]
        for stack_out in outputs_list:
            remove_length = stack_out.shape[-1] - output_size
            if self.causal and remove_length > 0:
                normalized_outputs.append(stack_out[:, :, remove_length:])
            elif not self.causal and remove_length > 1:
                half_remove_length = remove_length // 2
                normalized_outputs.append(
                    stack_out[:, :, half_remove_length:-half_remove_length])
            else:
                normalized_outputs.append(stack_out)

        outputs = paddle.zeros_like(
            outputs_list[-1], dtype=outputs_list[-1].dtype)
        for stack_out in normalized_outputs:
            outputs += stack_out
        outputs = outputs.transpose([0, 2, 1])
        return outputs, None


class KWSModel(nn.Layer):
    """Keyword-spotting head: backbone -> linear -> sigmoid.

    Args:
        backbone: Feature extractor exposing a ``hidden_dim`` attribute.
        num_keywords: Number of keyword outputs per frame.
    """

    def __init__(self, backbone, num_keywords):
        super(KWSModel, self).__init__()
        self.backbone = backbone
        self.linear = nn.Linear(self.backbone.hidden_dim, num_keywords)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return per-frame keyword probabilities for input ``x``."""
        outputs = self.backbone(x)
        # A backbone such as MDTC returns ``(features, None)``; keep only the
        # features so the linear layer receives a tensor. Backbones that
        # return a bare tensor are passed through unchanged.
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]
        outputs = self.linear(outputs)
        return self.activation(outputs)