# Environment setup for the hey_snips/kws0 example. Meant to be sourced from
# the example directory (run.sh does `. ./path.sh`), so ${PWD} is
# examples/hey_snips/kws0 and the repository root is three levels up.
MAIN_ROOT="$(realpath "${PWD}/../../../")"
export MAIN_ROOT

# Make the repository's top-level scripts and utils/ callable by name.
export PATH="${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}"
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Append the previous PYTHONPATH only when it is non-empty, so that an unset
# variable does not leave a dangling ':' (an empty path element).
export PYTHONPATH="${MAIN_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

# Same guard for the loader path: an empty element in LD_LIBRARY_PATH makes
# the dynamic loader search the current directory.
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/lib/"

# Directory holding the experiment entry points for the selected model.
MODEL=mdtc
export BIN_DIR="${MAIN_ROOT}/paddlespeech/kws/exps/${MODEL}"
# Stage driver for the hey_snips keyword-spotting example:
#   stage 0 - data preparation   (local/data.sh)
#   stage 1 - training           (local/train.sh)
#   stage 2 - test               (local/test.sh)
# Every variable assigned before parse_options.sh is sourced is a default
# that parse_options.sh is expected to let the caller override.

. ./path.sh
set -e

stage=0
stop_stage=50

# Directory handed to every stage as the data info directory.
dir=data/

# NOTE(review): this default name looks inherited from a speaker-verification
# recipe (ecapa-tdnn / vox) rather than the mdtc keyword-spotting model; it is
# kept unchanged so existing experiment outputs stay where they are.
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
conf_path=conf/mdtc.yaml
gpus=0,1,2,3

source "${MAIN_ROOT}/utils/parse_options.sh" || exit 1

mkdir -p "${exp_dir}"

if [ "${stage}" -le 0 ] && [ "${stop_stage}" -ge 0 ]; then
    # stage 0: data preparation
    # Exit statuses must be in 0-255, so fail with 1 rather than -1.
    bash ./local/data.sh "${dir}" "${conf_path}" || exit 1
fi

if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    # stage 1: training on the GPUs listed in ${gpus}
    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh "${dir}" "${exp_dir}" "${conf_path}"
fi

if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
    # stage 2: test on a single GPU
    CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh "${dir}" "${exp_dir}" "${conf_path}"
fi
# Label file used when --test_data is not given. Kept from the original
# hard-coded configuration so that existing invocations behave the same.
# NOTE(review): this is a machine-specific path; pass --test_data instead.
DEFAULT_TEST_DATA = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/data/test/data.list'


def load_label_and_score(keyword, label_file, score_file):
    """Load per-utterance scores and split them into keyword/filler tables.

    Args:
        keyword: Integer label of the keyword being evaluated. It is compared
            with the second column of the score file and with the 'txt'
            field of the label file.
        label_file: Path to a text file holding one JSON object per line,
            each with the fields 'key', 'txt' and 'duration'.
        score_file: Path to a text file whose lines have the form
            ``<key> <keyword index> <score> <score> ...``.

    Returns:
        A tuple ``(keyword_table, filler_table, filler_duration)``. The two
        tables map an utterance key to its list of float scores:
        ``keyword_table`` for utterances whose 'txt' equals ``keyword`` and
        ``filler_table`` for all others. ``filler_duration`` is the sum of
        the 'duration' fields of the filler utterances.

    Raises:
        ValueError: If a line is malformed, a label entry lacks a required
            field, or a labelled utterance has no scores for ``keyword``.
    """
    # score_table: {uttid: [score, ...]} for the requested keyword only.
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line_no, line in enumerate(fin, 1):
            arr = line.split()
            if not arr:
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            if len(arr) < 2:
                raise ValueError('{}:{}: expected "<key> <keyword> <scores...>"'
                                 .format(score_file, line_no))
            key, current_keyword = arr[0], arr[1]
            # Only the first entry of an utterance is kept.
            if int(current_keyword) == keyword and key not in score_table:
                score_table[key] = [float(s) for s in arr[2:]]

    keyword_table = {}
    filler_table = {}
    filler_duration = 0.0
    with open(label_file, 'r', encoding='utf8') as fin:
        for line_no, line in enumerate(fin, 1):
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            missing = [f for f in ('key', 'txt', 'duration') if f not in obj]
            if missing:
                raise ValueError('{}:{}: missing field(s) {}'.format(
                    label_file, line_no, ', '.join(missing)))
            key = obj['key']
            if key not in score_table:
                raise ValueError('{}:{}: no scores found for utterance {!r}'
                                 .format(label_file, line_no, key))
            if obj['txt'] == keyword:
                keyword_table[key] = score_table[key]
            else:
                filler_table[key] = score_table[key]
                filler_duration += obj['duration']
    return keyword_table, filler_table, filler_duration


def _num_thresholds(step):
    """Return how many thresholds ``0, step, 2*step, ...`` lie in [0, 1]."""
    if not 0.0 < step <= 1.0:
        raise ValueError('step must be in (0, 1], got {}'.format(step))
    # The small epsilon guards against 1.0 / step landing just below an
    # integer because of float rounding.
    return int(1.0 / step + 1e-9) + 1


def _count_false_alarms(score_list, threshold, window_shift):
    """Count triggers in one filler utterance.

    After a trigger the next ``window_shift`` positions are skipped so a
    single burst of high scores is counted once.
    """
    count = 0
    i = 0
    total = len(score_list)
    while i < total:
        if score_list[i] >= threshold:
            count += 1
            i += window_shift
        else:
            i += 1
    return count


def compute_det_stats(keyword_table,
                      filler_table,
                      filler_duration,
                      step=0.01,
                      window_shift=50):
    """Yield one DET operating point per threshold.

    Thresholds are generated as ``i * step`` from an integer counter rather
    than by repeated addition, so accumulated float error cannot drop the
    final threshold (1.0 for the default step).

    Args:
        keyword_table: ``{key: [score, ...]}`` for keyword utterances.
        filler_table: ``{key: [score, ...]}`` for filler utterances.
        filler_duration: Total filler duration; it is divided by 3600 to
            obtain hours, so it is presumably given in seconds.
        step: Distance between consecutive thresholds, in (0, 1].
        window_shift: Number of score positions skipped after a trigger.

    Yields:
        Tuples ``(threshold, false_alarm_per_hour, false_reject_rate)``.
        A rate whose denominator is empty (no keyword utterances, or zero
        filler duration) is reported as 0.0.

    Raises:
        ValueError: If ``step`` or ``window_shift`` is out of range.
    """
    if window_shift < 1:
        raise ValueError(
            'window_shift must be >= 1, got {}'.format(window_shift))
    filler_hours = filler_duration / 3600.0
    for i in range(_num_thresholds(step)):
        threshold = min(i * step, 1.0)

        # A keyword utterance is rejected when even its best score stays
        # below the threshold.
        num_false_reject = sum(1 for score_list in keyword_table.values()
                               if float(max(score_list)) < threshold)
        num_false_alarm = sum(
            _count_false_alarms(score_list, threshold, window_shift)
            for score_list in filler_table.values())

        false_reject_rate = 0.0
        if keyword_table:
            false_reject_rate = num_false_reject / len(keyword_table)

        false_alarm_per_hour = 0.0
        if filler_hours != 0:
            # Same floor as before so the reported rate is never exactly 0.
            false_alarm_per_hour = max(num_false_alarm, 1e-6) / filler_hours

        yield threshold, false_alarm_per_hour, false_reject_rate


class Args:
    """Command-line settings of the DET computation.

    Usage: ``compute_det.py <exp_dir> [--test_data F] [--keyword N] ...``.
    With only ``<exp_dir>`` given, the resulting attribute values are the
    same as the previous hard-coded ones.

    Attributes:
        test_data, keyword, score_file, stats_file, step, window_shift.
    """

    def __init__(self, argv=None):
        import argparse

        parser = argparse.ArgumentParser(description='compute det curve')
        parser.add_argument(
            'exp_dir',
            help='experiment directory; default location of score.txt and '
            'of the stats file')
        parser.add_argument(
            '--test_data', default=DEFAULT_TEST_DATA, help='label file')
        parser.add_argument(
            '--keyword', type=int, default=0, help='keyword label')
        parser.add_argument(
            '--score_file',
            default=None,
            help='score file (default: <exp_dir>/score.txt)')
        parser.add_argument(
            '--step', type=float, default=0.01, help='threshold step')
        parser.add_argument(
            '--window_shift',
            type=int,
            default=50,
            help='window_shift is used to skip the frames after triggered')
        parser.add_argument(
            '--stats_file',
            default=None,
            help='false reject/alarm stats file '
            '(default: <exp_dir>/stats.<keyword>.txt)')
        # Unknown extra arguments are ignored, as the previous
        # sys.argv[1]-only handling did.
        parsed, _ = parser.parse_known_args(argv)

        exp_dir = os.path.abspath(parsed.exp_dir)
        self.test_data = parsed.test_data
        self.keyword = parsed.keyword
        self.score_file = parsed.score_file or os.path.join(exp_dir,
                                                            'score.txt')
        self.stats_file = parsed.stats_file or os.path.join(
            exp_dir, 'stats.{}.txt'.format(parsed.keyword))
        self.step = parsed.step
        self.window_shift = parsed.window_shift


if __name__ == '__main__':
    # Parsed here rather than at import time so that importing this module
    # never touches sys.argv.
    args = Args()

    keyword_table, filler_table, filler_duration = load_label_and_score(
        args.keyword, args.test_data, args.score_file)
    print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))

    pbar = tqdm(total=_num_thresholds(args.step))
    with open(args.stats_file, 'w', encoding='utf8') as fout:
        for threshold, false_alarm_per_hour, false_reject_rate in \
                compute_det_stats(keyword_table, filler_table,
                                  filler_duration, args.step,
                                  args.window_shift):
            fout.write('{:.6f} {:.6f} {:.6f}\n'.format(
                threshold, false_alarm_per_hour, false_reject_rate))
            pbar.update(1)

    pbar.close()
    print('DET saved to: {}'.format(args.stats_file))
def load_stats_file(stats_file):
    """Read a DET stats file into an array of plot points.

    Each non-blank line must hold three whitespace-separated numbers:
    ``<threshold> <false alarms per hour> <false reject rate>``.

    Args:
        stats_file: Path to the stats file.

    Returns:
        A float array of shape ``(n, 2)`` with rows
        ``[false_alarm_per_hour, false_reject_rate * 100]``, in reverse file
        order. An empty file gives shape ``(0, 2)`` so that column slicing
        by the caller still works.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields.
    """
    values = []
    with open(stats_file, 'r', encoding='utf8') as fin:
        for line_no, line in enumerate(fin, 1):
            arr = line.split()
            if not arr:
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            if len(arr) != 3:
                raise ValueError('{}:{}: expected 3 fields, got {}'.format(
                    stats_file, line_no, len(arr)))
            _, fa_per_hour, frr = arr
            # The rejection rate is plotted as a percentage.
            values.append([float(fa_per_hour), float(frr) * 100])
    values.reverse()
    return np.array(values, dtype=float).reshape(-1, 2)


def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
                   y_step):
    """Plot one DET curve per keyword and save the figure.

    The curve of the keyword at position ``index`` in ``keywords`` is read
    from ``<stats_dir>/stats.<index>.txt``.

    Args:
        keywords: Keyword names, used as legend labels.
        stats_dir: Directory containing the stats files.
        figure_file: Output path of the image.
        xlim: Upper x-axis limit (false alarms per hour).
        x_step: Distance between x ticks.
        ylim: Upper y-axis limit (false rejection rate in percent).
        y_step: Distance between y ticks.
    """
    style = {
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'font.size': 12,
    }
    # rc_context keeps the style local to this figure instead of changing
    # the process-wide rcParams.
    with plt.rc_context(style):
        fig = plt.figure(dpi=200)
        try:
            for index, keyword in enumerate(keywords):
                stats_file = os.path.join(stats_dir,
                                          'stats.' + str(index) + '.txt')
                values = load_stats_file(stats_file)
                plt.plot(values[:, 0], values[:, 1], label=keyword)

            plt.xlim([0, xlim])
            plt.ylim([0, ylim])
            plt.xticks(range(0, xlim + x_step, x_step))
            plt.yticks(range(0, ylim + y_step, y_step))
            plt.xlabel('False Alarm Per Hour')
            # Plain '%': without usetex the former '\\%' most likely shows
            # the backslash in the label.
            plt.ylabel('False Rejection Rate (%)')
            plt.grid(linestyle='--')
            plt.legend(loc='best', fontsize=16)
            plt.savefig(figure_file)
        finally:
            # Release the figure so repeated calls do not accumulate
            # open figures.
            plt.close(fig)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit('Usage: {} <stats_dir>'.format(sys.argv[0]))

    # The stats files are read from, and det.png is written to, this
    # directory.
    stats_dir = os.path.abspath(sys.argv[1])
    keywords = ['Hey_Snips']
    img_path = os.path.join(stats_dir, 'det.png')

    plot_det_curve(keywords, stats_dir, img_path, 10, 2, 10, 2)

    print('DET curve image saved to: {}'.format(img_path))
def collate_features(batch):
    """Collate (key, feat, label) samples into one padded batch.

    Every feature is padded at the end of its first dimension up to the
    largest first-dimension size found in the batch.

    Args:
        batch: Sequence of samples; ``sample[0]`` is the key, ``sample[1]``
            the feature tensor and ``sample[2]`` the label.

    Returns:
        A tuple ``(keys, feats, labels, lengths)``: the list of keys, the
        stacked padded features, the labels as a tensor and the original
        first-dimension sizes as a tensor.
    """
    keys = [sample[0] for sample in batch]
    feats = [sample[1] for sample in batch]
    labels = [sample[2] for sample in batch]
    lengths = [feat.shape[0] for feat in feats]

    max_length = max(lengths)
    padded = [
        paddle.nn.functional.pad(
            feat, [0, max_length - feat.shape[0], 0, 0], data_format='NLC')
        for feat in feats
    ]

    return keys, paddle.stack(padded), paddle.to_tensor(
        labels), paddle.to_tensor(lengths)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit('Usage: {} <exp_dir>'.format(sys.argv[0]))
    # Directory that holds model.pdparams and receives score.txt.
    exp_dir = os.path.abspath(sys.argv[1])

    # Dataset
    feat_conf = {
        # 'n_mfcc': 80,
        'n_mels': 80,
        'frame_shift': 10,
        'frame_length': 25,
        # 'dither': 1.0,
    }
    test_ds = HeySnips(
        mode='test', feat_type='kaldi_fbank', sample_rate=16000, **feat_conf)
    test_sampler = paddle.io.BatchSampler(
        test_ds, batch_size=32, drop_last=False)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_sampler,
        num_workers=16,
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )

    # Model
    backbone = MDTC(
        stack_num=3,
        stack_size=4,
        in_channels=80,
        res_channels=32,
        kernel_size=5,
        causal=True, )
    model = KWSModel(backbone=backbone, num_keywords=1)
    # NOTE(review): presumably wrapped so the parameter names match the
    # checkpoint written by train.py, which saves the state dict of a
    # DataParallel model -- confirm.
    model = paddle.DataParallel(model)
    kws_checkpoint = os.path.join(exp_dir, 'model.pdparams')
    model.set_state_dict(paddle.load(kws_checkpoint))
    model.eval()

    # Output format, one line per (utterance, keyword):
    #   <key> <keyword index> <score> <score> ...
    score_abs_path = os.path.join(exp_dir, 'score.txt')
    with paddle.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
        for batch in tqdm(test_loader, total=len(test_loader)):
            keys, feats, labels, lengths = batch
            logits = model(feats)
            num_keywords = logits.shape[2]
            for i, key in enumerate(keys):
                # Drop the positions that exist only because of padding.
                score = logits[i][:lengths[i]]
                for keyword_i in range(num_keywords):
                    keyword_scores = score[:, keyword_i]
                    score_frames = ' '.join(
                        ['{:.6f}'.format(x) for x in keyword_scores.tolist()])
                    fout.write(
                        '{} {} {}\n'.format(key, keyword_i, score_frames))

    print('Scores saved to: {}'.format(score_abs_path))
def collate_features(batch):
    """Collate (key, feat, label) samples into one padded batch.

    Every feature is padded at the end of its first dimension up to the
    largest first-dimension size found in the batch.

    Args:
        batch: Sequence of samples; ``sample[0]`` is the key, ``sample[1]``
            the feature tensor and ``sample[2]`` the label.

    Returns:
        A tuple ``(keys, feats, labels, lengths)``: the list of keys, the
        stacked padded features, the labels as a tensor and the original
        first-dimension sizes as a tensor.
    """
    keys = [sample[0] for sample in batch]
    feats = [sample[1] for sample in batch]
    labels = [sample[2] for sample in batch]
    lengths = [feat.shape[0] for feat in feats]

    max_length = max(lengths)
    padded = [
        paddle.nn.functional.pad(
            feat, [0, max_length - feat.shape[0], 0, 0], data_format='NLC')
        for feat in feats
    ]

    return keys, paddle.stack(padded), paddle.to_tensor(
        labels), paddle.to_tensor(lengths)


if __name__ == '__main__':
    # Dataset
    feat_conf = {
        # 'n_mfcc': 80,
        'n_mels': 80,
        'frame_shift': 10,
        'frame_length': 25,
        # 'dither': 1.0,
    }
    # NOTE(review): machine-specific dataset location; it has to be edited
    # before running elsewhere.
    data_dir = '/ssd1/chenxiaojie06/datasets/hey_snips/hey_snips_research_6k_en_train_eval_clean_ter'
    train_ds = HeySnips(
        data_dir=data_dir,
        mode='train',
        feat_type='kaldi_fbank',
        sample_rate=16000,
        **feat_conf)
    dev_ds = HeySnips(
        data_dir=data_dir,
        mode='dev',
        feat_type='kaldi_fbank',
        sample_rate=16000,
        **feat_conf)

    training_conf = {
        'epochs': 100,
        'learning_rate': 0.001,
        'weight_decay': 0.00005,
        'num_workers': 16,
        'batch_size': 100,
        'checkpoint_dir': './checkpoint',
        'save_freq': 10,
        'log_freq': 10,
    }

    train_sampler = paddle.io.BatchSampler(
        train_ds,
        batch_size=training_conf['batch_size'],
        shuffle=True,
        drop_last=False)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=training_conf['num_workers'],
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )

    # The validation loader does not depend on the epoch, so it is built
    # once instead of on every evaluation.
    dev_sampler = paddle.io.BatchSampler(
        dev_ds,
        batch_size=training_conf['batch_size'],
        shuffle=False,
        drop_last=False)
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_sampler,
        num_workers=training_conf['num_workers'],
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )

    # Model
    backbone = MDTC(
        stack_num=3,
        stack_size=4,
        in_channels=80,
        res_channels=32,
        kernel_size=5,
        causal=True, )
    model = KWSModel(backbone=backbone, num_keywords=1)
    model = paddle.DataParallel(model)
    clip = paddle.nn.ClipGradByGlobalNorm(5.0)
    optimizer = paddle.optimizer.Adam(
        learning_rate=training_conf['learning_rate'],
        weight_decay=training_conf['weight_decay'],
        parameters=model.parameters(),
        grad_clip=clip)
    criterion = max_pooling_loss

    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * training_conf['epochs'])
    timer.start()

    for epoch in range(1, training_conf['epochs'] + 1):
        model.train()

        # Running sums since the last log line.
        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        for batch_idx, batch in enumerate(train_loader):
            keys, feats, labels, lengths = batch
            logits = model(feats)
            loss, corrects, acc = criterion(logits, labels, lengths)
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # Calculate loss. float() handles both a one-element and a 0-D
            # loss tensor, whereas .numpy()[0] fails on the latter.
            avg_loss += float(loss)

            # Calculate metrics
            num_corrects += corrects
            num_samples += feats.shape[0]

            timer.count()

            if (batch_idx + 1) % training_conf['log_freq'] == 0:
                lr = optimizer.get_lr()
                avg_loss /= training_conf['log_freq']
                avg_acc = num_corrects / num_samples

                print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                    epoch, training_conf['epochs'], batch_idx + 1,
                    steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.train(print_msg)

                avg_loss = 0
                num_corrects = 0
                num_samples = 0

        # Evaluate and checkpoint after the last batch of every
        # save_freq-th epoch.
        if epoch % training_conf['save_freq'] == 0:
            model.eval()
            num_corrects = 0
            num_samples = 0
            with logger.processing('Evaluation on validation dataset'):
                # No gradients are needed for validation.
                with paddle.no_grad():
                    for dev_batch in dev_loader:
                        keys, feats, labels, lengths = dev_batch
                        logits = model(feats)
                        loss, corrects, acc = criterion(logits, labels,
                                                        lengths)
                        num_corrects += corrects
                        num_samples += feats.shape[0]

            eval_acc = num_corrects / num_samples
            print_msg = '[Evaluation result]'
            print_msg += ' dev_acc={:.4f}'.format(eval_acc)

            logger.eval(print_msg)

            # Save model
            save_dir = os.path.join(training_conf['checkpoint_dir'],
                                    'epoch_{}_{:.4f}'.format(epoch, eval_acc))
            os.makedirs(save_dir, exist_ok=True)
            logger.info('Saving model checkpoint to {}'.format(save_dir))
            paddle.save(model.state_dict(),
                        os.path.join(save_dir, 'model.pdparams'))
            paddle.save(optimizer.state_dict(),
                        os.path.join(save_dir, 'model.pdopt'))