diff --git a/audio/paddleaudio/datasets/hey_snips.py b/audio/paddleaudio/datasets/hey_snips.py index 53aebdf8..7a67b843 100644 --- a/audio/paddleaudio/datasets/hey_snips.py +++ b/audio/paddleaudio/datasets/hey_snips.py @@ -63,10 +63,12 @@ class HeySnips(AudioClassificationDataset): files = [] labels = [] self.keys = [] + self.durations = [] for sample in meta_info: - key, target, _, wav = sample + key, target, duration, wav = sample files.append(wav) labels.append(int(target)) self.keys.append(key) + self.durations.append(float(duration)) return files, labels diff --git a/examples/hey_snips/README.md b/examples/hey_snips/README.md index e69de29b..be8d142b 100644 --- a/examples/hey_snips/README.md +++ b/examples/hey_snips/README.md @@ -0,0 +1,22 @@ +# MDTC Keyword Spotting with HeySnips Dataset + +## Dataset + +Before running scripts, you **MUST** follow this instruction to download the dataset: https://github.com/sonos/keyword-spotting-research-datasets + +After you download and decompress the dataset archive, you should **REPLACE** the value of `data_dir` in `conf/*.yaml` to complete dataset config. + +## Get Started + +In this section, we will train the [MDTC](https://arxiv.org/pdf/2102.13552.pdf) model and evaluate on "Hey Snips" dataset. + +```sh +CUDA_VISIBLE_DEVICES=0,1 ./run.sh conf/mdtc.yaml +``` + +This script contains training and scoring steps. You can just set the `CUDA_VISIBLE_DEVICES` environment var to run on single gpu or multi-gpus. + +The vars `stage` and `stop_stage` in `./run.sh` controls the running steps: +- stage 1: Training from scratch. +- stage 2: Evaluating model on test dataset and computing detection error tradeoff(DET) of all trigger thresholds. +- stage 3: Plotting the DET cruve for visualizaiton. diff --git a/examples/hey_snips/RESULTS.md b/examples/hey_snips/RESULTS.md new file mode 100644 index 00000000..ba263906 --- /dev/null +++ b/examples/hey_snips/RESULTS.md @@ -0,0 +1,8 @@ + +## Metrics + +We mesure FRRs with fixing false alarms in one hour: + +|Model|False Alarm| False Reject Rate| +|--|--|--| +|MDTC| 1| 0.003559 | diff --git a/examples/hey_snips/kws0/RESULTS.md b/examples/hey_snips/kws0/RESULTS.md deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/hey_snips/kws0/conf/mdtc.yaml b/examples/hey_snips/kws0/conf/mdtc.yaml new file mode 100644 index 00000000..3ce9f9d0 --- /dev/null +++ b/examples/hey_snips/kws0/conf/mdtc.yaml @@ -0,0 +1,39 @@ +data: + data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter' + dataset: 'paddleaudio.datasets:HeySnips' + +model: + num_keywords: 1 + backbone: 'paddlespeech.kws.models:MDTC' + config: + stack_num: 3 + stack_size: 4 + in_channels: 80 + res_channels: 32 + kernel_size: 5 + +feature: + feat_type: 'kaldi_fbank' + sample_rate: 16000 + frame_shift: 10 + frame_length: 25 + n_mels: 80 + +training: + epochs: 100 + num_workers: 16 + batch_size: 100 + checkpoint_dir: './checkpoint' + save_freq: 10 + log_freq: 10 + learning_rate: 0.001 + weight_decay: 0.00005 + grad_clip: 5.0 + +scoring: + batch_size: 100 + num_workers: 16 + checkpoint: './checkpoint/epoch_100/model.pdparams' + score_file: './scores.txt' + stats_file: './stats.0.txt' + img_file: './det.png' \ No newline at end of file diff --git a/examples/hey_snips/kws0/local/plot.sh b/examples/hey_snips/kws0/local/plot.sh new file mode 100755 index 00000000..5869e50b --- /dev/null +++ b/examples/hey_snips/kws0/local/plot.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips diff --git a/examples/hey_snips/kws0/local/score.sh b/examples/hey_snips/kws0/local/score.sh new file mode 100755 index 00000000..ed21d08c --- /dev/null +++ b/examples/hey_snips/kws0/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +python3 ${BIN_DIR}/score.py --cfg_path=$1 + +python3 ${BIN_DIR}/compute_det.py --cfg_path=$1 diff --git a/examples/hey_snips/kws0/local/train.sh b/examples/hey_snips/kws0/local/train.sh new file mode 100755 index 00000000..cab547b8 --- /dev/null +++ b/examples/hey_snips/kws0/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +ngpu=$1 +cfg_path=$2 + +if [ ${ngpu} -gt 0 ]; then + python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \ + --cfg_path ${cfg_path} +else + python3 ${BIN_DIR}/train.py \ + --cfg_path ${cfg_path} +fi diff --git a/examples/hey_snips/kws0/run.sh b/examples/hey_snips/kws0/run.sh index 69b1ad6a..d6d1d878 100755 --- a/examples/hey_snips/kws0/run.sh +++ b/examples/hey_snips/kws0/run.sh @@ -13,35 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -. ./path.sh set -e +source path.sh -stage=0 -stop_stage=50 +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -# data directory -# if we set the variable ${dir}, we will store the wav info to this directory -# otherwise, we will store the wav info to vox1 and vox2 directory respectively -# vox2 wav path, we must convert the m4a format to wav format -dir=data/ # data info directory +stage=1 +stop_stage=3 -exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory -conf_path=conf/mdtc.yaml -gpus=0,1,2,3 +cfg_path=$1 -source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; - -mkdir -p ${exp_dir} - -if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav - bash ./local/data.sh ${dir} ${conf_path}|| exit -1; +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + ./local/train.sh ${ngpu} ${cfg_path} || exit -1 fi -if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then - CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path} +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + ./local/score.sh ${cfg_path} || exit -1 fi -if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then - CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path} -fi +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + ./local/plot.sh ${cfg_path} || exit -1 +fi \ No newline at end of file diff --git a/paddlespeech/kws/exps/mdtc/collate.py b/paddlespeech/kws/exps/mdtc/collate.py new file mode 100644 index 00000000..dcc81123 --- /dev/null +++ b/paddlespeech/kws/exps/mdtc/collate.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import paddle + + +def collate_features(batch): + # (key, feat, label) + collate_start = time.time() + keys = [] + feats = [] + labels = [] + lengths = [] + for sample in batch: + keys.append(sample[0]) + feats.append(sample[1]) + labels.append(sample[2]) + lengths.append(sample[1].shape[0]) + + max_length = max(lengths) + for i in range(len(feats)): + feats[i] = paddle.nn.functional.pad( + feats[i], [0, max_length - feats[i].shape[0], 0, 0], + data_format='NLC') + + return keys, paddle.stack(feats), paddle.to_tensor( + labels), paddle.to_tensor(lengths) diff --git a/paddlespeech/kws/exps/mdtc/compute_det.py b/paddlespeech/kws/exps/mdtc/compute_det.py index 19a3fe14..91b02ff6 100644 --- a/paddlespeech/kws/exps/mdtc/compute_det.py +++ b/paddlespeech/kws/exps/mdtc/compute_det.py @@ -11,15 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import argparse import os -import sys +import yaml from tqdm import tqdm +from paddlespeech.s2t.utils.dynamic_import import dynamic_import -def load_label_and_score(keyword, label_file, score_file): - # score_table: {uttid: [keywordlist]} +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +parser.add_argument('--keyword', type=int, default=0, help='keyword label') +parser.add_argument('--step', type=float, default=0.01, help='threshold step') +parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered') +args = parser.parse_args() +# yapf: enable + + +def load_label_and_score(keyword, ds, score_file): score_table = {} with open(score_file, 'r', encoding='utf8') as fin: for line in fin: @@ -34,59 +45,40 @@ def load_label_and_score(keyword, label_file, score_file): keyword_table = {} filler_table = {} filler_duration = 0.0 - with open(label_file, 'r', encoding='utf8') as fin: - for line in fin: - obj = json.loads(line.strip()) - assert 'key' in obj - assert 'txt' in obj - assert 'duration' in obj - key = obj['key'] - index = obj['txt'] - duration = obj['duration'] - assert key in score_table - if index == keyword: - keyword_table[key] = score_table[key] - else: - filler_table[key] = score_table[key] - filler_duration += duration + + for key, index, duration in zip(ds.keys, ds.labels, ds.durations): + assert key in score_table + if index == keyword: + keyword_table[key] = score_table[key] + else: + filler_table[key] = score_table[key] + filler_duration += duration + return keyword_table, filler_table, filler_duration -class Args: - def __init__(self): - self.test_data = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/data/test/data.list' - self.keyword = 0 - self.score_file = os.path.join( - os.path.abspath(sys.argv[1]), 'score.txt') - self.stats_file = os.path.join( - os.path.abspath(sys.argv[1]), 'stats.0.txt') - self.step = 0.01 - self.window_shift = 50 +if __name__ == '__main__': + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + data_conf = config['data'] + feat_conf = config['feature'] + scoring_conf = config['scoring'] -args = Args() + # Dataset + ds_class = dynamic_import(data_conf['dataset']) + test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) -if __name__ == '__main__': - # parser = argparse.ArgumentParser(description='compute det curve') - # parser.add_argument('--test_data', required=True, help='label file') - # parser.add_argument('--keyword', type=int, default=0, help='keyword label') - # parser.add_argument('--score_file', required=True, help='score file') - # parser.add_argument('--step', type=float, default=0.01, - # help='threshold step') - # parser.add_argument('--window_shift', type=int, default=50, - # help='window_shift is used to skip the frames after triggered') - # parser.add_argument('--stats_file', - # required=True, - # help='false reject/alarm stats file') - # args = parser.parse_args() + score_file = os.path.abspath(scoring_conf['score_file']) + stats_file = os.path.abspath(scoring_conf['stats_file']) - window_shift = args.window_shift keyword_table, filler_table, filler_duration = load_label_and_score( - args.keyword, args.test_data, args.score_file) + args.keyword, test_ds, score_file) print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) pbar = tqdm(total=int(1.0 / args.step)) - with open(args.stats_file, 'w', encoding='utf8') as fout: - keyword_index = int(args.keyword) + with open(stats_file, 'w', encoding='utf8') as fout: + keyword_index = args.keyword threshold = 0.0 while threshold <= 1.0: num_false_reject = 0 @@ -103,7 +95,7 @@ if __name__ == '__main__': while i < len(score_list): if score_list[i] >= threshold: num_false_alarm += 1 - i += window_shift + i += args.window_shift else: i += 1 if len(keyword_table) != 0: @@ -118,4 +110,4 @@ if __name__ == '__main__': pbar.update(1) pbar.close() - print('DET saved to: {}'.format(args.stats_file)) + print('DET saved to: {}'.format(stats_file)) diff --git a/paddlespeech/kws/exps/mdtc/plot_det_curve.py b/paddlespeech/kws/exps/mdtc/plot_det_curve.py index 7986574f..ac920358 100644 --- a/paddlespeech/kws/exps/mdtc/plot_det_curve.py +++ b/paddlespeech/kws/exps/mdtc/plot_det_curve.py @@ -11,11 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import argparse import os -import sys import matplotlib.pyplot as plt import numpy as np +import yaml + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +parser.add_argument("--keyword", type=str, required=True) +args = parser.parse_args() +# yapf: enable def load_stats_file(stats_file): @@ -29,7 +38,7 @@ def load_stats_file(stats_file): return np.array(values) -def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim, +def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim, y_step): plt.figure(dpi=200) plt.rcParams['xtick.direction'] = 'in' @@ -37,7 +46,6 @@ def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim, plt.rcParams['font.size'] = 12 for index, keyword in enumerate(keywords): - stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt') values = load_stats_file(stats_file) plt.plot(values[:, 0], values[:, 1], label=keyword) @@ -53,11 +61,14 @@ def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim, if __name__ == '__main__': + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) - keywords = ['Hey_Snips'] - img_path = os.path.join(os.path.abspath(sys.argv[1]), 'det.png') - - plot_det_curve(keywords, - os.path.abspath(sys.argv[1]), img_path, 10, 2, 10, 2) + scoring_conf = config['scoring'] + img_file = os.path.abspath(scoring_conf['img_file']) + stats_file = os.path.abspath(scoring_conf['stats_file']) + keywords = [args.keyword] + plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2) - print('DET curve image saved to: {}'.format(img_path)) + print('DET curve image saved to: {}'.format(img_file)) diff --git a/paddlespeech/kws/exps/mdtc/score.py b/paddlespeech/kws/exps/mdtc/score.py index 9fdbcf49..7fe88ea3 100644 --- a/paddlespeech/kws/exps/mdtc/score.py +++ b/paddlespeech/kws/exps/mdtc/score.py @@ -11,80 +11,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import argparse import os -import sys -import time import paddle -from mdtc import KWSModel -from mdtc import MDTC +import yaml from tqdm import tqdm -from paddleaudio.datasets import HeySnips +from paddlespeech.kws.exps.mdtc.collate import collate_features +from paddlespeech.kws.models.mdtc import KWSModel +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +args = parser.parse_args() +# yapf: enable -def collate_features(batch): - # (key, feat, label) in one sample - collate_start = time.time() - keys = [] - feats = [] - labels = [] - lengths = [] - for sample in batch: - keys.append(sample[0]) - feats.append(sample[1]) - labels.append(sample[2]) - lengths.append(sample[1].shape[0]) - - max_length = max(lengths) - for i in range(len(feats)): - feats[i] = paddle.nn.functional.pad( - feats[i], [0, max_length - feats[i].shape[0], 0, 0], - data_format='NLC') - - return keys, paddle.stack(feats), paddle.to_tensor( - labels), paddle.to_tensor(lengths) +if __name__ == '__main__': + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + model_conf = config['model'] + data_conf = config['data'] + feat_conf = config['feature'] + scoring_conf = config['scoring'] -if __name__ == '__main__': # Dataset - feat_conf = { - # 'n_mfcc': 80, - 'n_mels': 80, - 'frame_shift': 10, - 'frame_length': 25, - # 'dither': 1.0, - } - test_ds = HeySnips( - mode='test', feat_type='kaldi_fbank', sample_rate=16000, **feat_conf) + ds_class = dynamic_import(data_conf['dataset']) + test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) test_sampler = paddle.io.BatchSampler( - test_ds, batch_size=32, drop_last=False) + test_ds, batch_size=scoring_conf['batch_size'], drop_last=False) test_loader = paddle.io.DataLoader( test_ds, batch_sampler=test_sampler, - num_workers=16, + num_workers=scoring_conf['num_workers'], return_list=True, use_buffer_reader=True, collate_fn=collate_features, ) # Model - backbone = MDTC( - stack_num=3, - stack_size=4, - in_channels=80, - res_channels=32, - kernel_size=5, - causal=True, ) - model = KWSModel(backbone=backbone, num_keywords=1) - model = paddle.DataParallel(model) - # kws_checkpoint = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/checkpoint/epoch_10_0.8903940343290826/model.pdparams' - kws_checkpoint = os.path.join( - os.path.abspath(sys.argv[1]), 'model.pdparams') - model.set_state_dict(paddle.load(kws_checkpoint)) + backbone_class = dynamic_import(model_conf['backbone']) + backbone = backbone_class(**model_conf['config']) + model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) + model.set_state_dict(paddle.load(scoring_conf['checkpoint'])) model.eval() - score_abs_path = os.path.join(os.path.abspath(sys.argv[1]), 'score.txt') - with paddle.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: + with paddle.no_grad(), open( + scoring_conf['score_file'], 'w', encoding='utf8') as fout: for batch_idx, batch in enumerate( tqdm(test_loader, total=len(test_loader))): keys, feats, labels, lengths = batch @@ -100,4 +76,4 @@ if __name__ == '__main__': fout.write( '{} {} {}\n'.format(key, keyword_i, score_frames)) - print('Scores saved to: {}'.format(score_abs_path)) + print('Result saved to: {}'.format(scoring_conf['score_file'])) diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py index 17a9acfc..99e72871 100644 --- a/paddlespeech/kws/exps/mdtc/train.py +++ b/paddlespeech/kws/exps/mdtc/train.py @@ -11,77 +11,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os -import time import paddle -from loss import max_pooling_loss -from mdtc import KWSModel -from mdtc import MDTC +import yaml -from paddleaudio.datasets import HeySnips from paddleaudio.utils import logger from paddleaudio.utils import Timer +from paddlespeech.kws.exps.mdtc.collate import collate_features +from paddlespeech.kws.models.loss import max_pooling_loss +from paddlespeech.kws.models.mdtc import KWSModel +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--cfg_path", type=str, required=True) +args = parser.parse_args() +# yapf: enable -def collate_features(batch): - # (key, feat, label) - collate_start = time.time() - keys = [] - feats = [] - labels = [] - lengths = [] - for sample in batch: - keys.append(sample[0]) - feats.append(sample[1]) - labels.append(sample[2]) - lengths.append(sample[1].shape[0]) - - max_length = max(lengths) - for i in range(len(feats)): - feats[i] = paddle.nn.functional.pad( - feats[i], [0, max_length - feats[i].shape[0], 0, 0], - data_format='NLC') +if __name__ == '__main__': + nranks = paddle.distributed.get_world_size() + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + local_rank = paddle.distributed.get_rank() - return keys, paddle.stack(feats), paddle.to_tensor( - labels), paddle.to_tensor(lengths) + args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) + with open(args.cfg_path, 'r') as f: + config = yaml.safe_load(f) + model_conf = config['model'] + data_conf = config['data'] + feat_conf = config['feature'] + training_conf = config['training'] -if __name__ == '__main__': # Dataset - feat_conf = { - # 'n_mfcc': 80, - 'n_mels': 80, - 'frame_shift': 10, - 'frame_length': 25, - # 'dither': 1.0, - } - data_dir = '/ssd1/chenxiaojie06/datasets/hey_snips/hey_snips_research_6k_en_train_eval_clean_ter' - train_ds = HeySnips( - data_dir=data_dir, - mode='train', - feat_type='kaldi_fbank', - sample_rate=16000, - **feat_conf) - dev_ds = HeySnips( - data_dir=data_dir, - mode='dev', - feat_type='kaldi_fbank', - sample_rate=16000, - **feat_conf) - - training_conf = { - 'epochs': 100, - 'learning_rate': 0.001, - 'weight_decay': 0.00005, - 'num_workers': 16, - 'batch_size': 100, - 'checkpoint_dir': './checkpoint', - 'save_freq': 10, - 'log_freq': 10, - } - - train_sampler = paddle.io.BatchSampler( + ds_class = dynamic_import(data_conf['dataset']) + train_ds = ds_class( + data_dir=data_conf['data_dir'], mode='train', **feat_conf) + dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf) + + train_sampler = paddle.io.DistributedBatchSampler( train_ds, batch_size=training_conf['batch_size'], shuffle=True, @@ -95,16 +65,11 @@ if __name__ == '__main__': collate_fn=collate_features, ) # Model - backbone = MDTC( - stack_num=3, - stack_size=4, - in_channels=80, - res_channels=32, - kernel_size=5, - causal=True, ) - model = KWSModel(backbone=backbone, num_keywords=1) + backbone_class = dynamic_import(model_conf['backbone']) + backbone = backbone_class(**model_conf['config']) + model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) model = paddle.DataParallel(model) - clip = paddle.nn.ClipGradByGlobalNorm(5.0) + clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip']) optimizer = paddle.optimizer.Adam( learning_rate=training_conf['learning_rate'], weight_decay=training_conf['weight_decay'], @@ -122,9 +87,7 @@ if __name__ == '__main__': avg_loss = 0 num_corrects = 0 num_samples = 0 - batch_start = time.time() for batch_idx, batch in enumerate(train_loader): - # print('Fetch one batch: {:.4f}'.format(time.time()-batch_start)) keys, feats, labels, lengths = batch logits = model(feats) loss, corrects, acc = criterion(logits, labels, lengths) @@ -144,7 +107,8 @@ if __name__ == '__main__': timer.count() - if (batch_idx + 1) % training_conf['log_freq'] == 0: + if (batch_idx + 1 + ) % training_conf['log_freq'] == 0 and local_rank == 0: lr = optimizer.get_lr() avg_loss /= training_conf['log_freq'] avg_acc = num_corrects / num_samples @@ -161,10 +125,9 @@ if __name__ == '__main__': avg_loss = 0 num_corrects = 0 num_samples = 0 - batch_start = time.time() if epoch % training_conf[ - 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch: + 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0: dev_sampler = paddle.io.BatchSampler( dev_ds, batch_size=training_conf['batch_size'], @@ -197,7 +160,7 @@ if __name__ == '__main__': # Save model save_dir = os.path.join(training_conf['checkpoint_dir'], - 'epoch_{}_{:.4f}'.format(epoch, eval_acc)) + 'epoch_{}'.format(epoch)) logger.info('Saving model checkpoint to {}'.format(save_dir)) paddle.save(model.state_dict(), os.path.join(save_dir, 'model.pdparams')) diff --git a/paddlespeech/kws/models/__init__.py b/paddlespeech/kws/models/__init__.py index 97043fd7..125a0d7a 100644 --- a/paddlespeech/kws/models/__init__.py +++ b/paddlespeech/kws/models/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .mdtc import KWSModel +from .mdtc import MDTC diff --git a/paddlespeech/kws/models/loss.py b/paddlespeech/kws/models/loss.py new file mode 100644 index 00000000..8a2e9e74 --- /dev/null +++ b/paddlespeech/kws/models/loss.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) +import paddle + + +def fill_mask_elements(condition, value, x): + assert condition.shape == x.shape + values = paddle.ones_like(x, dtype=x.dtype) * value + return paddle.where(condition, values, x) + + +def max_pooling_loss(logits: paddle.Tensor, + target: paddle.Tensor, + lengths: paddle.Tensor, + min_duration: int=0): + + mask = padding_mask(lengths) + num_utts = logits.shape[0] + num_keywords = logits.shape[2] + + loss = 0.0 + for i in range(num_utts): + for j in range(num_keywords): + # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p)) + if target[i] == j: + # For the keyword, do max-polling + prob = logits[i, :, j] + m = mask[i] + if min_duration > 0: + m[:min_duration] = True + prob = fill_mask_elements(m, 0.0, prob) + prob = paddle.clip(prob, 1e-8, 1.0) + max_prob = prob.max() + loss += -paddle.log(max_prob) + else: + # For other keywords or filler, do min-polling + prob = 1 - logits[i, :, j] + prob = fill_mask_elements(mask[i], 1.0, prob) + prob = paddle.clip(prob, 1e-8, 1.0) + min_prob = prob.min() + loss += -paddle.log(min_prob) + loss = loss / num_utts + + # Compute accuracy of current batch + mask = mask.unsqueeze(-1) + logits = fill_mask_elements(mask, 0.0, logits) + max_logits = logits.max(1) + num_correct = 0 + for i in range(num_utts): + max_p = max_logits[i].max(0).item() + idx = max_logits[i].argmax(0).item() + # Predict correct as the i'th keyword + if max_p > 0.5 and idx == target[i].item(): + num_correct += 1 + # Predict correct as the filler, filler id < 0 + if max_p < 0.5 and target[i].item() < 0: + num_correct += 1 + acc = num_correct / num_utts + # acc = 0.0 + return loss, num_correct, acc + + +def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor: + batch_size = lengths.shape[0] + max_len = int(lengths.max().item()) + seq = paddle.arange(max_len, dtype=paddle.int64) + seq = seq.expand((batch_size, max_len)) + return seq >= lengths.unsqueeze(1) diff --git a/paddlespeech/kws/models/mdtc.py b/paddlespeech/kws/models/mdtc.py index 2cb14305..5d2e5de6 100644 --- a/paddlespeech/kws/models/mdtc.py +++ b/paddlespeech/kws/models/mdtc.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified from wekws(https://github.com/wenet-e2e/wekws) import paddle import paddle.nn as nn import paddle.nn.functional as F @@ -163,7 +164,7 @@ class MDTC(nn.Layer): in_channels: int, res_channels: int, kernel_size: int, - causal: bool, ): + causal: bool=True, ): super(MDTC, self).__init__() assert kernel_size % 2 == 1 self.kernel_size = kernel_size @@ -230,17 +231,3 @@ class KWSModel(nn.Layer): outputs = self.backbone(x) outputs = self.linear(outputs) return self.activation(outputs) - - -if __name__ == '__main__': - paddle.set_device('cpu') - from paddleaudio.features import LogMelSpectrogram - mdtc = MDTC(3, 4, 80, 32, 5, causal=True) - - x = paddle.randn(shape=(32, 16000 * 5)) - feature_extractor = LogMelSpectrogram(sr=16000, n_fft=512, n_mels=80) - feats = feature_extractor(x).transpose([0, 2, 1]) - print(feats.shape) - - res, _ = mdtc(feats) - print(res.shape)