commit
962a278996
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
from .dataset import AudioClassificationDataset
|
||||
|
||||
__all__ = ['HeySnips']
|
||||
|
||||
|
||||
class HeySnips(AudioClassificationDataset):
|
||||
meta_info = collections.namedtuple('META_INFO',
|
||||
('key', 'label', 'duration', 'wav'))
|
||||
|
||||
def __init__(self,
|
||||
data_dir: os.PathLike,
|
||||
mode: str='train',
|
||||
feat_type: str='kaldi_fbank',
|
||||
sample_rate: int=16000,
|
||||
**kwargs):
|
||||
self.data_dir = data_dir
|
||||
files, labels = self._get_data(mode)
|
||||
super(HeySnips, self).__init__(
|
||||
files=files,
|
||||
labels=labels,
|
||||
feat_type=feat_type,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs)
|
||||
|
||||
def _get_meta_info(self, mode) -> List[collections.namedtuple]:
|
||||
ret = []
|
||||
with open(os.path.join(self.data_dir, '{}.json'.format(mode)),
|
||||
'r') as f:
|
||||
data = json.load(f)
|
||||
for item in data:
|
||||
sample = collections.OrderedDict()
|
||||
if item['duration'] > 0:
|
||||
sample['key'] = item['id']
|
||||
sample['label'] = 0 if item['is_hotword'] == 1 else -1
|
||||
sample['duration'] = item['duration']
|
||||
sample['wav'] = os.path.join(self.data_dir,
|
||||
item['audio_file_path'])
|
||||
ret.append(self.meta_info(*sample.values()))
|
||||
return ret
|
||||
|
||||
def _get_data(self, mode: str) -> Tuple[List[str], List[int]]:
|
||||
meta_info = self._get_meta_info(mode)
|
||||
|
||||
files = []
|
||||
labels = []
|
||||
self.keys = []
|
||||
self.durations = []
|
||||
for sample in meta_info:
|
||||
key, target, duration, wav = sample
|
||||
files.append(wav)
|
||||
labels.append(int(target))
|
||||
self.keys.append(key)
|
||||
self.durations.append(float(duration))
|
||||
|
||||
return files, labels
|
@ -0,0 +1,8 @@
|
||||
|
||||
## Metrics
|
||||
|
||||
We mesure FRRs with fixing false alarms in one hour:
|
||||
|
||||
|Model|False Alarm| False Reject Rate|
|
||||
|--|--|--|
|
||||
|MDTC| 1| 0.003559 |
|
@ -0,0 +1,22 @@
|
||||
# MDTC Keyword Spotting with HeySnips Dataset
|
||||
|
||||
## Dataset
|
||||
|
||||
Before running scripts, you **MUST** follow this instruction to download the dataset: https://github.com/sonos/keyword-spotting-research-datasets
|
||||
|
||||
After you download and decompress the dataset archive, you should **REPLACE** the value of `data_dir` in `conf/*.yaml` to complete dataset config.
|
||||
|
||||
## Get Started
|
||||
|
||||
In this section, we will train the [MDTC](https://arxiv.org/pdf/2102.13552.pdf) model and evaluate on "Hey Snips" dataset.
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1 ./run.sh conf/mdtc.yaml
|
||||
```
|
||||
|
||||
This script contains training and scoring steps. You can just set the `CUDA_VISIBLE_DEVICES` environment var to run on single gpu or multi-gpus.
|
||||
|
||||
The vars `stage` and `stop_stage` in `./run.sh` controls the running steps:
|
||||
- stage 1: Training from scratch.
|
||||
- stage 2: Evaluating model on test dataset and computing detection error tradeoff(DET) of all trigger thresholds.
|
||||
- stage 3: Plotting the DET cruve for visualizaiton.
|
@ -0,0 +1,39 @@
|
||||
data:
|
||||
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
|
||||
dataset: 'paddleaudio.datasets:HeySnips'
|
||||
|
||||
model:
|
||||
num_keywords: 1
|
||||
backbone: 'paddlespeech.kws.models:MDTC'
|
||||
config:
|
||||
stack_num: 3
|
||||
stack_size: 4
|
||||
in_channels: 80
|
||||
res_channels: 32
|
||||
kernel_size: 5
|
||||
|
||||
feature:
|
||||
feat_type: 'kaldi_fbank'
|
||||
sample_rate: 16000
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
n_mels: 80
|
||||
|
||||
training:
|
||||
epochs: 100
|
||||
num_workers: 16
|
||||
batch_size: 100
|
||||
checkpoint_dir: './checkpoint'
|
||||
save_freq: 10
|
||||
log_freq: 10
|
||||
learning_rate: 0.001
|
||||
weight_decay: 0.00005
|
||||
grad_clip: 5.0
|
||||
|
||||
scoring:
|
||||
batch_size: 100
|
||||
num_workers: 16
|
||||
checkpoint: './checkpoint/epoch_100/model.pdparams'
|
||||
score_file: './scores.txt'
|
||||
stats_file: './stats.0.txt'
|
||||
img_file: './det.png'
|
@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips
|
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
python3 ${BIN_DIR}/score.py --cfg_path=$1
|
||||
|
||||
python3 ${BIN_DIR}/compute_det.py --cfg_path=$1
|
@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
ngpu=$1
|
||||
cfg_path=$2
|
||||
|
||||
if [ ${ngpu} -gt 0 ]; then
|
||||
python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
|
||||
--cfg_path ${cfg_path}
|
||||
else
|
||||
echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning."
|
||||
python3 ${BIN_DIR}/train.py \
|
||||
--cfg_path ${cfg_path}
|
||||
fi
|
@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||
|
||||
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||
export LC_ALL=C
|
||||
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||
|
||||
MODEL=mdtc
|
||||
export BIN_DIR=${MAIN_ROOT}/paddlespeech/kws/exps/${MODEL}
|
@ -0,0 +1,41 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
source path.sh
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
if [ $# != 1 ];then
|
||||
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
stage=1
|
||||
stop_stage=3
|
||||
|
||||
cfg_path=$1
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
./local/train.sh ${ngpu} ${cfg_path} || exit -1
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
./local/score.sh ${cfg_path} || exit -1
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
./local/plot.sh ${cfg_path} || exit -1
|
||||
fi
|
@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .models.mdtc import MDTC
|
@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
|
||||
import paddle
|
||||
|
||||
|
||||
def collate_features(batch):
|
||||
# (key, feat, label)
|
||||
collate_start = time.time()
|
||||
keys = []
|
||||
feats = []
|
||||
labels = []
|
||||
lengths = []
|
||||
for sample in batch:
|
||||
keys.append(sample[0])
|
||||
feats.append(sample[1])
|
||||
labels.append(sample[2])
|
||||
lengths.append(sample[1].shape[0])
|
||||
|
||||
max_length = max(lengths)
|
||||
for i in range(len(feats)):
|
||||
feats[i] = paddle.nn.functional.pad(
|
||||
feats[i], [0, max_length - feats[i].shape[0], 0, 0],
|
||||
data_format='NLC')
|
||||
|
||||
return keys, paddle.stack(feats), paddle.to_tensor(
|
||||
labels), paddle.to_tensor(lengths)
|
@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from wekws(https://github.com/wenet-e2e/wekws)
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import paddle
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
||||
|
||||
# yapf: disable
|
||||
parser = argparse.ArgumentParser(__doc__)
|
||||
parser.add_argument("--cfg_path", type=str, required=True)
|
||||
parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
|
||||
parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
|
||||
parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
|
||||
args = parser.parse_args()
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def load_label_and_score(keyword_index: int,
|
||||
ds: paddle.io.Dataset,
|
||||
score_file: os.PathLike):
|
||||
score_table = {} # {utt_id: scores_over_frames}
|
||||
with open(score_file, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
arr = line.strip().split()
|
||||
key = arr[0]
|
||||
current_keyword = arr[1]
|
||||
str_list = arr[2:]
|
||||
if int(current_keyword) == keyword_index:
|
||||
scores = list(map(float, str_list))
|
||||
if key not in score_table:
|
||||
score_table.update({key: scores})
|
||||
keyword_table = {} # scores of keyword utt_id
|
||||
filler_table = {} # scores of non-keyword utt_id
|
||||
filler_duration = 0.0
|
||||
|
||||
for key, index, duration in zip(ds.keys, ds.labels, ds.durations):
|
||||
assert key in score_table
|
||||
if index == keyword_index:
|
||||
keyword_table[key] = score_table[key]
|
||||
else:
|
||||
filler_table[key] = score_table[key]
|
||||
filler_duration += duration
|
||||
|
||||
return keyword_table, filler_table, filler_duration
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
|
||||
with open(args.cfg_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
data_conf = config['data']
|
||||
feat_conf = config['feature']
|
||||
scoring_conf = config['scoring']
|
||||
|
||||
# Dataset
|
||||
ds_class = dynamic_import(data_conf['dataset'])
|
||||
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
|
||||
|
||||
score_file = os.path.abspath(scoring_conf['score_file'])
|
||||
stats_file = os.path.abspath(scoring_conf['stats_file'])
|
||||
|
||||
keyword_table, filler_table, filler_duration = load_label_and_score(
|
||||
args.keyword, test_ds, score_file)
|
||||
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
|
||||
pbar = tqdm(total=int(1.0 / args.step))
|
||||
with open(stats_file, 'w', encoding='utf8') as fout:
|
||||
keyword_index = args.keyword_index
|
||||
threshold = 0.0
|
||||
while threshold <= 1.0:
|
||||
num_false_reject = 0
|
||||
# transverse the all keyword_table
|
||||
for key, score_list in keyword_table.items():
|
||||
# computer positive test sample, use the max score of list.
|
||||
score = max(score_list)
|
||||
if float(score) < threshold:
|
||||
num_false_reject += 1
|
||||
num_false_alarm = 0
|
||||
# transverse the all filler_table
|
||||
for key, score_list in filler_table.items():
|
||||
i = 0
|
||||
while i < len(score_list):
|
||||
if score_list[i] >= threshold:
|
||||
num_false_alarm += 1
|
||||
i += args.window_shift
|
||||
else:
|
||||
i += 1
|
||||
if len(keyword_table) != 0:
|
||||
false_reject_rate = num_false_reject / len(keyword_table)
|
||||
num_false_alarm = max(num_false_alarm, 1e-6)
|
||||
if filler_duration != 0:
|
||||
false_alarm_per_hour = num_false_alarm / \
|
||||
(filler_duration / 3600.0)
|
||||
fout.write('{:.6f} {:.6f} {:.6f}\n'.format(
|
||||
threshold, false_alarm_per_hour, false_reject_rate))
|
||||
threshold += args.step
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
print('DET saved to: {}'.format(stats_file))
|
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from wekws(https://github.com/wenet-e2e/wekws)
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
# yapf: disable
|
||||
parser = argparse.ArgumentParser(__doc__)
|
||||
parser.add_argument("--cfg_path", type=str, required=True)
|
||||
parser.add_argument("--keyword", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def load_stats_file(stats_file):
|
||||
values = []
|
||||
with open(stats_file, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
arr = line.strip().split()
|
||||
threshold, fa_per_hour, frr = arr
|
||||
values.append([float(fa_per_hour), float(frr) * 100])
|
||||
values.reverse()
|
||||
return np.array(values)
|
||||
|
||||
|
||||
def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
|
||||
y_step):
|
||||
plt.figure(dpi=200)
|
||||
plt.rcParams['xtick.direction'] = 'in'
|
||||
plt.rcParams['ytick.direction'] = 'in'
|
||||
plt.rcParams['font.size'] = 12
|
||||
|
||||
for index, keyword in enumerate(keywords):
|
||||
values = load_stats_file(stats_file)
|
||||
plt.plot(values[:, 0], values[:, 1], label=keyword)
|
||||
|
||||
plt.xlim([0, xlim])
|
||||
plt.ylim([0, ylim])
|
||||
plt.xticks(range(0, xlim + x_step, x_step))
|
||||
plt.yticks(range(0, ylim + y_step, y_step))
|
||||
plt.xlabel('False Alarm Per Hour')
|
||||
plt.ylabel('False Rejection Rate (\\%)')
|
||||
plt.grid(linestyle='--')
|
||||
plt.legend(loc='best', fontsize=16)
|
||||
plt.savefig(figure_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
|
||||
with open(args.cfg_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
scoring_conf = config['scoring']
|
||||
img_file = os.path.abspath(scoring_conf['img_file'])
|
||||
stats_file = os.path.abspath(scoring_conf['stats_file'])
|
||||
keywords = [args.keyword]
|
||||
plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
|
||||
|
||||
print('DET curve image saved to: {}'.format(img_file))
|
@ -0,0 +1,79 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from wekws(https://github.com/wenet-e2e/wekws)
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import paddle
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from paddlespeech.kws.exps.mdtc.collate import collate_features
|
||||
from paddlespeech.kws.models.mdtc import KWSModel
|
||||
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
||||
|
||||
# yapf: disable
|
||||
parser = argparse.ArgumentParser(__doc__)
|
||||
parser.add_argument("--cfg_path", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
# yapf: enable
|
||||
|
||||
if __name__ == '__main__':
|
||||
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
|
||||
with open(args.cfg_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
model_conf = config['model']
|
||||
data_conf = config['data']
|
||||
feat_conf = config['feature']
|
||||
scoring_conf = config['scoring']
|
||||
|
||||
# Dataset
|
||||
ds_class = dynamic_import(data_conf['dataset'])
|
||||
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
|
||||
test_sampler = paddle.io.BatchSampler(
|
||||
test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
|
||||
test_loader = paddle.io.DataLoader(
|
||||
test_ds,
|
||||
batch_sampler=test_sampler,
|
||||
num_workers=scoring_conf['num_workers'],
|
||||
return_list=True,
|
||||
use_buffer_reader=True,
|
||||
collate_fn=collate_features, )
|
||||
|
||||
# Model
|
||||
backbone_class = dynamic_import(model_conf['backbone'])
|
||||
backbone = backbone_class(**model_conf['config'])
|
||||
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
|
||||
model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
|
||||
model.eval()
|
||||
|
||||
with paddle.no_grad(), open(
|
||||
scoring_conf['score_file'], 'w', encoding='utf8') as fout:
|
||||
for batch_idx, batch in enumerate(
|
||||
tqdm(test_loader, total=len(test_loader))):
|
||||
keys, feats, labels, lengths = batch
|
||||
logits = model(feats)
|
||||
num_keywords = logits.shape[2]
|
||||
for i in range(len(keys)):
|
||||
key = keys[i]
|
||||
score = logits[i][:lengths[i]]
|
||||
for keyword_i in range(num_keywords):
|
||||
keyword_scores = score[:, keyword_i]
|
||||
score_frames = ' '.join(
|
||||
['{:.6f}'.format(x) for x in keyword_scores.tolist()])
|
||||
fout.write(
|
||||
'{} {} {}\n'.format(key, keyword_i, score_frames))
|
||||
|
||||
print('Result saved to: {}'.format(scoring_conf['score_file']))
|
@ -0,0 +1,168 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import paddle
|
||||
import yaml
|
||||
|
||||
from paddleaudio.utils import logger
|
||||
from paddleaudio.utils import Timer
|
||||
from paddlespeech.kws.exps.mdtc.collate import collate_features
|
||||
from paddlespeech.kws.models.loss import max_pooling_loss
|
||||
from paddlespeech.kws.models.mdtc import KWSModel
|
||||
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
||||
|
||||
# yapf: disable
|
||||
parser = argparse.ArgumentParser(__doc__)
|
||||
parser.add_argument("--cfg_path", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
# yapf: enable
|
||||
|
||||
if __name__ == '__main__':
|
||||
nranks = paddle.distributed.get_world_size()
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
paddle.distributed.init_parallel_env()
|
||||
local_rank = paddle.distributed.get_rank()
|
||||
|
||||
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
|
||||
with open(args.cfg_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
model_conf = config['model']
|
||||
data_conf = config['data']
|
||||
feat_conf = config['feature']
|
||||
training_conf = config['training']
|
||||
|
||||
# Dataset
|
||||
ds_class = dynamic_import(data_conf['dataset'])
|
||||
train_ds = ds_class(
|
||||
data_dir=data_conf['data_dir'], mode='train', **feat_conf)
|
||||
dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
|
||||
|
||||
train_sampler = paddle.io.DistributedBatchSampler(
|
||||
train_ds,
|
||||
batch_size=training_conf['batch_size'],
|
||||
shuffle=True,
|
||||
drop_last=False)
|
||||
train_loader = paddle.io.DataLoader(
|
||||
train_ds,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=training_conf['num_workers'],
|
||||
return_list=True,
|
||||
use_buffer_reader=True,
|
||||
collate_fn=collate_features, )
|
||||
|
||||
# Model
|
||||
backbone_class = dynamic_import(model_conf['backbone'])
|
||||
backbone = backbone_class(**model_conf['config'])
|
||||
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
|
||||
model = paddle.DataParallel(model)
|
||||
clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
learning_rate=training_conf['learning_rate'],
|
||||
weight_decay=training_conf['weight_decay'],
|
||||
parameters=model.parameters(),
|
||||
grad_clip=clip)
|
||||
criterion = max_pooling_loss
|
||||
|
||||
steps_per_epoch = len(train_sampler)
|
||||
timer = Timer(steps_per_epoch * training_conf['epochs'])
|
||||
timer.start()
|
||||
|
||||
for epoch in range(1, training_conf['epochs'] + 1):
|
||||
model.train()
|
||||
|
||||
avg_loss = 0
|
||||
num_corrects = 0
|
||||
num_samples = 0
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
keys, feats, labels, lengths = batch
|
||||
logits = model(feats)
|
||||
loss, corrects, acc = criterion(logits, labels, lengths)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if isinstance(optimizer._learning_rate,
|
||||
paddle.optimizer.lr.LRScheduler):
|
||||
optimizer._learning_rate.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
# Calculate loss
|
||||
avg_loss += loss.numpy()[0]
|
||||
|
||||
# Calculate metrics
|
||||
num_corrects += corrects
|
||||
num_samples += feats.shape[0]
|
||||
|
||||
timer.count()
|
||||
|
||||
if (batch_idx + 1
|
||||
) % training_conf['log_freq'] == 0 and local_rank == 0:
|
||||
lr = optimizer.get_lr()
|
||||
avg_loss /= training_conf['log_freq']
|
||||
avg_acc = num_corrects / num_samples
|
||||
|
||||
print_msg = 'Epoch={}/{}, Step={}/{}'.format(
|
||||
epoch, training_conf['epochs'], batch_idx + 1,
|
||||
steps_per_epoch)
|
||||
print_msg += ' loss={:.4f}'.format(avg_loss)
|
||||
print_msg += ' acc={:.4f}'.format(avg_acc)
|
||||
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
|
||||
lr, timer.timing, timer.eta)
|
||||
logger.train(print_msg)
|
||||
|
||||
avg_loss = 0
|
||||
num_corrects = 0
|
||||
num_samples = 0
|
||||
|
||||
if epoch % training_conf[
|
||||
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
|
||||
dev_sampler = paddle.io.BatchSampler(
|
||||
dev_ds,
|
||||
batch_size=training_conf['batch_size'],
|
||||
shuffle=False,
|
||||
drop_last=False)
|
||||
dev_loader = paddle.io.DataLoader(
|
||||
dev_ds,
|
||||
batch_sampler=dev_sampler,
|
||||
num_workers=training_conf['num_workers'],
|
||||
return_list=True,
|
||||
use_buffer_reader=True,
|
||||
collate_fn=collate_features, )
|
||||
|
||||
model.eval()
|
||||
num_corrects = 0
|
||||
num_samples = 0
|
||||
with logger.processing('Evaluation on validation dataset'):
|
||||
for batch_idx, batch in enumerate(dev_loader):
|
||||
keys, feats, labels, lengths = batch
|
||||
logits = model(feats)
|
||||
loss, corrects, acc = criterion(logits, labels, lengths)
|
||||
num_corrects += corrects
|
||||
num_samples += feats.shape[0]
|
||||
|
||||
eval_acc = num_corrects / num_samples
|
||||
print_msg = '[Evaluation result]'
|
||||
print_msg += ' dev_acc={:.4f}'.format(eval_acc)
|
||||
|
||||
logger.eval(print_msg)
|
||||
|
||||
# Save model
|
||||
save_dir = os.path.join(training_conf['checkpoint_dir'],
|
||||
'epoch_{}'.format(epoch))
|
||||
logger.info('Saving model checkpoint to {}'.format(save_dir))
|
||||
paddle.save(model.state_dict(),
|
||||
os.path.join(save_dir, 'model.pdparams'))
|
||||
paddle.save(optimizer.state_dict(),
|
||||
os.path.join(save_dir, 'model.pdopt'))
|
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .mdtc import KWSModel
|
||||
from .mdtc import MDTC
|
@ -0,0 +1,81 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from wekws(https://github.com/wenet-e2e/wekws)
|
||||
import paddle
|
||||
|
||||
|
||||
def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
|
||||
batch_size = lengths.shape[0]
|
||||
max_len = int(lengths.max().item())
|
||||
seq = paddle.arange(max_len, dtype=paddle.int64)
|
||||
seq = seq.expand((batch_size, max_len))
|
||||
return seq >= lengths.unsqueeze(1)
|
||||
|
||||
|
||||
def fill_mask_elements(condition: paddle.Tensor, value: float,
|
||||
x: paddle.Tensor) -> paddle.Tensor:
|
||||
assert condition.shape == x.shape
|
||||
values = paddle.ones_like(x, dtype=x.dtype) * value
|
||||
return paddle.where(condition, values, x)
|
||||
|
||||
|
||||
def max_pooling_loss(logits: paddle.Tensor,
|
||||
target: paddle.Tensor,
|
||||
lengths: paddle.Tensor,
|
||||
min_duration: int=0):
|
||||
|
||||
mask = padding_mask(lengths)
|
||||
num_utts = logits.shape[0]
|
||||
num_keywords = logits.shape[2]
|
||||
|
||||
loss = 0.0
|
||||
for i in range(num_utts):
|
||||
for j in range(num_keywords):
|
||||
# Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
|
||||
if target[i] == j:
|
||||
# For the keyword, do max-polling
|
||||
prob = logits[i, :, j]
|
||||
m = mask[i]
|
||||
if min_duration > 0:
|
||||
m[:min_duration] = True
|
||||
prob = fill_mask_elements(m, 0.0, prob)
|
||||
prob = paddle.clip(prob, 1e-8, 1.0)
|
||||
max_prob = prob.max()
|
||||
loss += -paddle.log(max_prob)
|
||||
else:
|
||||
# For other keywords or filler, do min-polling
|
||||
prob = 1 - logits[i, :, j]
|
||||
prob = fill_mask_elements(mask[i], 1.0, prob)
|
||||
prob = paddle.clip(prob, 1e-8, 1.0)
|
||||
min_prob = prob.min()
|
||||
loss += -paddle.log(min_prob)
|
||||
loss = loss / num_utts
|
||||
|
||||
# Compute accuracy of current batch
|
||||
mask = mask.unsqueeze(-1)
|
||||
logits = fill_mask_elements(mask, 0.0, logits)
|
||||
max_logits = logits.max(1)
|
||||
num_correct = 0
|
||||
for i in range(num_utts):
|
||||
max_p = max_logits[i].max(0).item()
|
||||
idx = max_logits[i].argmax(0).item()
|
||||
# Predict correct as the i'th keyword
|
||||
if max_p > 0.5 and idx == target[i].item():
|
||||
num_correct += 1
|
||||
# Predict correct as the filler, filler id < 0
|
||||
if max_p < 0.5 and target[i].item() < 0:
|
||||
num_correct += 1
|
||||
acc = num_correct / num_utts
|
||||
# acc = 0.0
|
||||
return loss, num_correct, acc
|
@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from wekws(https://github.com/wenet-e2e/wekws)
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class DSDilatedConv1d(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
dilation: int=1,
|
||||
stride: int=1,
|
||||
bias: bool=True, ):
|
||||
super(DSDilatedConv1d, self).__init__()
|
||||
self.receptive_fields = dilation * (kernel_size - 1)
|
||||
self.conv = nn.Conv1D(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
stride=stride,
|
||||
groups=in_channels,
|
||||
bias_attr=bias, )
|
||||
self.bn = nn.BatchNorm1D(in_channels)
|
||||
self.pointwise = nn.Conv1D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias_attr=bias)
|
||||
|
||||
def forward(self, inputs: paddle.Tensor):
|
||||
outputs = self.conv(inputs)
|
||||
outputs = self.bn(outputs)
|
||||
outputs = self.pointwise(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class TCNBlock(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
res_channels: int,
|
||||
kernel_size: int,
|
||||
dilation: int,
|
||||
causal: bool, ):
|
||||
super(TCNBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.res_channels = res_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation = dilation
|
||||
self.causal = causal
|
||||
self.receptive_fields = dilation * (kernel_size - 1)
|
||||
self.half_receptive_fields = self.receptive_fields // 2
|
||||
self.conv1 = DSDilatedConv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=res_channels,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation, )
|
||||
self.bn1 = nn.BatchNorm1D(res_channels)
|
||||
self.relu1 = nn.ReLU()
|
||||
|
||||
self.conv2 = nn.Conv1D(
|
||||
in_channels=res_channels, out_channels=res_channels, kernel_size=1)
|
||||
self.bn2 = nn.BatchNorm1D(res_channels)
|
||||
self.relu2 = nn.ReLU()
|
||||
|
||||
def forward(self, inputs: paddle.Tensor):
|
||||
outputs = self.relu1(self.bn1(self.conv1(inputs)))
|
||||
outputs = self.bn2(self.conv2(outputs))
|
||||
if self.causal:
|
||||
inputs = inputs[:, :, self.receptive_fields:]
|
||||
else:
|
||||
inputs = inputs[:, :, self.half_receptive_fields:
|
||||
-self.half_receptive_fields]
|
||||
if self.in_channels == self.res_channels:
|
||||
res_out = self.relu2(outputs + inputs)
|
||||
else:
|
||||
res_out = self.relu2(outputs)
|
||||
return res_out
|
||||
|
||||
|
||||
class TCNStack(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stack_num: int,
|
||||
stack_size: int,
|
||||
res_channels: int,
|
||||
kernel_size: int,
|
||||
causal: bool, ):
|
||||
super(TCNStack, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.stack_num = stack_num
|
||||
self.stack_size = stack_size
|
||||
self.res_channels = res_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.causal = causal
|
||||
self.res_blocks = self.stack_tcn_blocks()
|
||||
self.receptive_fields = self.calculate_receptive_fields()
|
||||
self.res_blocks = nn.Sequential(*self.res_blocks)
|
||||
|
||||
def calculate_receptive_fields(self):
|
||||
receptive_fields = 0
|
||||
for block in self.res_blocks:
|
||||
receptive_fields += block.receptive_fields
|
||||
return receptive_fields
|
||||
|
||||
def build_dilations(self):
|
||||
dilations = []
|
||||
for s in range(0, self.stack_size):
|
||||
for l in range(0, self.stack_num):
|
||||
dilations.append(2**l)
|
||||
return dilations
|
||||
|
||||
def stack_tcn_blocks(self):
|
||||
dilations = self.build_dilations()
|
||||
res_blocks = nn.LayerList()
|
||||
|
||||
res_blocks.append(
|
||||
TCNBlock(
|
||||
self.in_channels,
|
||||
self.res_channels,
|
||||
self.kernel_size,
|
||||
dilations[0],
|
||||
self.causal, ))
|
||||
for dilation in dilations[1:]:
|
||||
res_blocks.append(
|
||||
TCNBlock(
|
||||
self.res_channels,
|
||||
self.res_channels,
|
||||
self.kernel_size,
|
||||
dilation,
|
||||
self.causal, ))
|
||||
return res_blocks
|
||||
|
||||
def forward(self, inputs: paddle.Tensor):
|
||||
outputs = self.res_blocks(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class MDTC(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
stack_num: int,
|
||||
stack_size: int,
|
||||
in_channels: int,
|
||||
res_channels: int,
|
||||
kernel_size: int,
|
||||
causal: bool=True, ):
|
||||
super(MDTC, self).__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
self.kernel_size = kernel_size
|
||||
self.causal = causal
|
||||
self.preprocessor = TCNBlock(
|
||||
in_channels, res_channels, kernel_size, dilation=1, causal=causal)
|
||||
self.relu = nn.ReLU()
|
||||
self.blocks = nn.LayerList()
|
||||
self.receptive_fields = self.preprocessor.receptive_fields
|
||||
for i in range(stack_num):
|
||||
self.blocks.append(
|
||||
TCNStack(res_channels, stack_size, 1, res_channels, kernel_size,
|
||||
causal))
|
||||
self.receptive_fields += self.blocks[-1].receptive_fields
|
||||
self.half_receptive_fields = self.receptive_fields // 2
|
||||
self.hidden_dim = res_channels
|
||||
|
||||
def forward(self, x: paddle.Tensor):
|
||||
if self.causal:
|
||||
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
||||
'constant')
|
||||
else:
|
||||
outputs = F.pad(
|
||||
x,
|
||||
(0, 0, self.half_receptive_fields, self.half_receptive_fields,
|
||||
0, 0),
|
||||
'constant', )
|
||||
outputs = outputs.transpose([0, 2, 1])
|
||||
outputs_list = []
|
||||
outputs = self.relu(self.preprocessor(outputs))
|
||||
for block in self.blocks:
|
||||
outputs = block(outputs)
|
||||
outputs_list.append(outputs)
|
||||
|
||||
normalized_outputs = []
|
||||
output_size = outputs_list[-1].shape[-1]
|
||||
for x in outputs_list:
|
||||
remove_length = x.shape[-1] - output_size
|
||||
if self.causal and remove_length > 0:
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
elif not self.causal and remove_length > 1:
|
||||
half_remove_length = remove_length // 2
|
||||
normalized_outputs.append(
|
||||
x[:, :, half_remove_length:-half_remove_length])
|
||||
else:
|
||||
normalized_outputs.append(x)
|
||||
|
||||
outputs = paddle.zeros_like(
|
||||
outputs_list[-1], dtype=outputs_list[-1].dtype)
|
||||
for x in normalized_outputs:
|
||||
outputs += x
|
||||
outputs = outputs.transpose([0, 2, 1])
|
||||
return outputs, None
|
||||
|
||||
|
||||
class KWSModel(nn.Layer):
|
||||
def __init__(self, backbone, num_keywords):
|
||||
super(KWSModel, self).__init__()
|
||||
self.backbone = backbone
|
||||
self.linear = nn.Linear(self.backbone.hidden_dim, num_keywords)
|
||||
self.activation = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self.backbone(x)
|
||||
outputs = self.linear(outputs)
|
||||
return self.activation(outputs)
|
Loading…
Reference in new issue