Add KWS example.

pull/1558/head
KP 3 years ago
parent b60b1dadde
commit f9761d532c

@ -63,10 +63,12 @@ class HeySnips(AudioClassificationDataset):
files = []
labels = []
self.keys = []
self.durations = []
for sample in meta_info:
key, target, _, wav = sample
key, target, duration, wav = sample
files.append(wav)
labels.append(int(target))
self.keys.append(key)
self.durations.append(float(duration))
return files, labels

@ -0,0 +1,22 @@
# MDTC Keyword Spotting with HeySnips Dataset
## Dataset
Before running the scripts, you **MUST** follow the instructions at this link to download the dataset: https://github.com/sonos/keyword-spotting-research-datasets
After you download and decompress the dataset archive, you should **REPLACE** the value of `data_dir` in `conf/*.yaml` with the path to the decompressed dataset to complete the dataset configuration.
## Get Started
In this section, we will train the [MDTC](https://arxiv.org/pdf/2102.13552.pdf) model and evaluate it on the "Hey Snips" dataset.
```sh
CUDA_VISIBLE_DEVICES=0,1 ./run.sh conf/mdtc.yaml
```
This script contains the training and scoring steps. You can set the `CUDA_VISIBLE_DEVICES` environment variable to run on a single GPU or on multiple GPUs.
The variables `stage` and `stop_stage` in `./run.sh` control which steps are run:
- stage 1: Training from scratch.
- stage 2: Evaluating the model on the test dataset and computing the detection error tradeoff (DET) over all trigger thresholds.
- stage 3: Plotting the DET curve for visualization.

@ -0,0 +1,8 @@
## Metrics
We measure the false reject rate (FRR) at a fixed number of false alarms per hour:
|Model|False Alarm| False Reject Rate|
|--|--|--|
|MDTC| 1| 0.003559 |

@ -0,0 +1,39 @@
data:
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
dataset: 'paddleaudio.datasets:HeySnips'
model:
num_keywords: 1
backbone: 'paddlespeech.kws.models:MDTC'
config:
stack_num: 3
stack_size: 4
in_channels: 80
res_channels: 32
kernel_size: 5
feature:
feat_type: 'kaldi_fbank'
sample_rate: 16000
frame_shift: 10
frame_length: 25
n_mels: 80
training:
epochs: 100
num_workers: 16
batch_size: 100
checkpoint_dir: './checkpoint'
save_freq: 10
log_freq: 10
learning_rate: 0.001
weight_decay: 0.00005
grad_clip: 5.0
scoring:
batch_size: 100
num_workers: 16
checkpoint: './checkpoint/epoch_100/model.pdparams'
score_file: './scores.txt'
stats_file: './stats.0.txt'
img_file: './det.png'

@ -0,0 +1,2 @@
#!/bin/bash
# Plot the DET curve for the HeySnips keyword.
# Usage: <this script> <cfg_path>
#   cfg_path  YAML config file forwarded to plot_det_curve.py
# BIN_DIR must already be set in the environment; this script does not define it.
# Expansions are quoted so paths containing spaces are passed as single arguments.
python3 "${BIN_DIR}/plot_det_curve.py" --cfg_path="$1" --keyword HeySnips

@ -0,0 +1,5 @@
#!/bin/bash
# Score the test set, then compute DET statistics from the scores.
# Usage: <this script> <cfg_path>
#   cfg_path  YAML config file forwarded to both Python scripts
# BIN_DIR must already be set in the environment; this script does not define it.

# Stop here if scoring fails, so compute_det.py never runs on a missing or
# stale score file.
python3 "${BIN_DIR}/score.py" --cfg_path="$1" || exit 1
python3 "${BIN_DIR}/compute_det.py" --cfg_path="$1"

@ -0,0 +1,12 @@
#!/bin/bash
# Launch training, distributed when at least one GPU is requested.
# Usage: <this script> <ngpu> <cfg_path>
#   ngpu      number of GPUs; a value > 0 launches through paddle.distributed.launch
#             on the devices listed in CUDA_VISIBLE_DEVICES
#   cfg_path  YAML config file forwarded to train.py
# BIN_DIR must already be set in the environment; this script does not define it.
ngpu=$1
cfg_path=$2

# Default an empty ngpu to 0 so the numeric test below is always well-formed.
if [ "${ngpu:-0}" -gt 0 ]; then
    python3 -m paddle.distributed.launch --gpus "${CUDA_VISIBLE_DEVICES}" "${BIN_DIR}/train.py" \
        --cfg_path "${cfg_path}"
else
    python3 "${BIN_DIR}/train.py" \
        --cfg_path "${cfg_path}"
fi

@ -13,35 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
. ./path.sh
set -e
source path.sh
stage=0
stop_stage=50
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
# vox2 wav path, we must convert the m4a format to wav format
dir=data/ # data info directory
stage=1
stop_stage=3
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
conf_path=conf/mdtc.yaml
gpus=0,1,2,3
cfg_path=$1
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
mkdir -p ${exp_dir}
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
bash ./local/data.sh ${dir} ${conf_path}|| exit -1;
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
./local/train.sh ${ngpu} ${cfg_path} || exit -1
fi
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/score.sh ${cfg_path} || exit -1
fi
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/plot.sh ${cfg_path} || exit -1
fi

@ -0,0 +1,39 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import paddle
def collate_features(batch):
    """Collate samples into one padded batch.

    Args:
        batch: Iterable of samples indexed as ``(key, feat, label)``; each
            ``feat`` must expose ``shape[0]``.

    Returns:
        Tuple ``(keys, feats, labels, lengths)``: ``keys`` is a list of the
        sample keys, ``feats`` is the stack of padded features, and
        ``labels`` / ``lengths`` are tensors built from the per-sample labels
        and the per-sample ``feat.shape[0]`` values.

    Raises:
        ValueError: If ``batch`` is empty (``max`` of an empty list).
    """
    keys = []
    feats = []
    labels = []
    lengths = []
    for sample in batch:
        keys.append(sample[0])
        feats.append(sample[1])
        labels.append(sample[2])
        lengths.append(sample[1].shape[0])

    max_length = max(lengths)
    # Pad every feature by (max_length - its own length).
    # NOTE(review): the pad list / data_format are presumably meant to
    # right-pad the first (time) axis — confirm against paddle's pad semantics.
    padded = [
        paddle.nn.functional.pad(
            feat, [0, max_length - length, 0, 0], data_format='NLC')
        for feat, length in zip(feats, lengths)
    ]

    return keys, paddle.stack(padded), paddle.to_tensor(
        labels), paddle.to_tensor(lengths)

@ -11,15 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import sys
import yaml
from tqdm import tqdm
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
def load_label_and_score(keyword, label_file, score_file):
# score_table: {uttid: [keywordlist]}
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument('--keyword', type=int, default=0, help='keyword label')
parser.add_argument('--step', type=float, default=0.01, help='threshold step')
parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
args = parser.parse_args()
# yapf: enable
def load_label_and_score(keyword, ds, score_file):
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
for line in fin:
@ -34,59 +45,40 @@ def load_label_and_score(keyword, label_file, score_file):
keyword_table = {}
filler_table = {}
filler_duration = 0.0
with open(label_file, 'r', encoding='utf8') as fin:
for line in fin:
obj = json.loads(line.strip())
assert 'key' in obj
assert 'txt' in obj
assert 'duration' in obj
key = obj['key']
index = obj['txt']
duration = obj['duration']
assert key in score_table
if index == keyword:
keyword_table[key] = score_table[key]
else:
filler_table[key] = score_table[key]
filler_duration += duration
for key, index, duration in zip(ds.keys, ds.labels, ds.durations):
assert key in score_table
if index == keyword:
keyword_table[key] = score_table[key]
else:
filler_table[key] = score_table[key]
filler_duration += duration
return keyword_table, filler_table, filler_duration
class Args:
def __init__(self):
self.test_data = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/data/test/data.list'
self.keyword = 0
self.score_file = os.path.join(
os.path.abspath(sys.argv[1]), 'score.txt')
self.stats_file = os.path.join(
os.path.abspath(sys.argv[1]), 'stats.0.txt')
self.step = 0.01
self.window_shift = 50
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
data_conf = config['data']
feat_conf = config['feature']
scoring_conf = config['scoring']
args = Args()
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
if __name__ == '__main__':
# parser = argparse.ArgumentParser(description='compute det curve')
# parser.add_argument('--test_data', required=True, help='label file')
# parser.add_argument('--keyword', type=int, default=0, help='keyword label')
# parser.add_argument('--score_file', required=True, help='score file')
# parser.add_argument('--step', type=float, default=0.01,
# help='threshold step')
# parser.add_argument('--window_shift', type=int, default=50,
# help='window_shift is used to skip the frames after triggered')
# parser.add_argument('--stats_file',
# required=True,
# help='false reject/alarm stats file')
# args = parser.parse_args()
score_file = os.path.abspath(scoring_conf['score_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
window_shift = args.window_shift
keyword_table, filler_table, filler_duration = load_label_and_score(
args.keyword, args.test_data, args.score_file)
args.keyword, test_ds, score_file)
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
pbar = tqdm(total=int(1.0 / args.step))
with open(args.stats_file, 'w', encoding='utf8') as fout:
keyword_index = int(args.keyword)
with open(stats_file, 'w', encoding='utf8') as fout:
keyword_index = args.keyword
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
@ -103,7 +95,7 @@ if __name__ == '__main__':
while i < len(score_list):
if score_list[i] >= threshold:
num_false_alarm += 1
i += window_shift
i += args.window_shift
else:
i += 1
if len(keyword_table) != 0:
@ -118,4 +110,4 @@ if __name__ == '__main__':
pbar.update(1)
pbar.close()
print('DET saved to: {}'.format(args.stats_file))
print('DET saved to: {}'.format(stats_file))

@ -11,11 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import yaml
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument("--keyword", type=str, required=True)
args = parser.parse_args()
# yapf: enable
def load_stats_file(stats_file):
@ -29,7 +38,7 @@ def load_stats_file(stats_file):
return np.array(values)
def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
y_step):
plt.figure(dpi=200)
plt.rcParams['xtick.direction'] = 'in'
@ -37,7 +46,6 @@ def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
plt.rcParams['font.size'] = 12
for index, keyword in enumerate(keywords):
stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt')
values = load_stats_file(stats_file)
plt.plot(values[:, 0], values[:, 1], label=keyword)
@ -53,11 +61,14 @@ def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
keywords = ['Hey_Snips']
img_path = os.path.join(os.path.abspath(sys.argv[1]), 'det.png')
plot_det_curve(keywords,
os.path.abspath(sys.argv[1]), img_path, 10, 2, 10, 2)
scoring_conf = config['scoring']
img_file = os.path.abspath(scoring_conf['img_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
keywords = [args.keyword]
plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
print('DET curve image saved to: {}'.format(img_path))
print('DET curve image saved to: {}'.format(img_file))

@ -11,80 +11,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import sys
import time
import paddle
from mdtc import KWSModel
from mdtc import MDTC
import yaml
from tqdm import tqdm
from paddleaudio.datasets import HeySnips
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
def collate_features(batch):
# (key, feat, label) in one sample
collate_start = time.time()
keys = []
feats = []
labels = []
lengths = []
for sample in batch:
keys.append(sample[0])
feats.append(sample[1])
labels.append(sample[2])
lengths.append(sample[1].shape[0])
max_length = max(lengths)
for i in range(len(feats)):
feats[i] = paddle.nn.functional.pad(
feats[i], [0, max_length - feats[i].shape[0], 0, 0],
data_format='NLC')
return keys, paddle.stack(feats), paddle.to_tensor(
labels), paddle.to_tensor(lengths)
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
scoring_conf = config['scoring']
if __name__ == '__main__':
# Dataset
feat_conf = {
# 'n_mfcc': 80,
'n_mels': 80,
'frame_shift': 10,
'frame_length': 25,
# 'dither': 1.0,
}
test_ds = HeySnips(
mode='test', feat_type='kaldi_fbank', sample_rate=16000, **feat_conf)
ds_class = dynamic_import(data_conf['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
test_sampler = paddle.io.BatchSampler(
test_ds, batch_size=32, drop_last=False)
test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
test_loader = paddle.io.DataLoader(
test_ds,
batch_sampler=test_sampler,
num_workers=16,
num_workers=scoring_conf['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
# Model
backbone = MDTC(
stack_num=3,
stack_size=4,
in_channels=80,
res_channels=32,
kernel_size=5,
causal=True, )
model = KWSModel(backbone=backbone, num_keywords=1)
model = paddle.DataParallel(model)
# kws_checkpoint = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/checkpoint/epoch_10_0.8903940343290826/model.pdparams'
kws_checkpoint = os.path.join(
os.path.abspath(sys.argv[1]), 'model.pdparams')
model.set_state_dict(paddle.load(kws_checkpoint))
backbone_class = dynamic_import(model_conf['backbone'])
backbone = backbone_class(**model_conf['config'])
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
model.eval()
score_abs_path = os.path.join(os.path.abspath(sys.argv[1]), 'score.txt')
with paddle.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
with paddle.no_grad(), open(
scoring_conf['score_file'], 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(
tqdm(test_loader, total=len(test_loader))):
keys, feats, labels, lengths = batch
@ -100,4 +76,4 @@ if __name__ == '__main__':
fout.write(
'{} {} {}\n'.format(key, keyword_i, score_frames))
print('Scores saved to: {}'.format(score_abs_path))
print('Result saved to: {}'.format(scoring_conf['score_file']))

@ -11,77 +11,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import paddle
from loss import max_pooling_loss
from mdtc import KWSModel
from mdtc import MDTC
import yaml
from paddleaudio.datasets import HeySnips
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
def collate_features(batch):
# (key, feat, label)
collate_start = time.time()
keys = []
feats = []
labels = []
lengths = []
for sample in batch:
keys.append(sample[0])
feats.append(sample[1])
labels.append(sample[2])
lengths.append(sample[1].shape[0])
max_length = max(lengths)
for i in range(len(feats)):
feats[i] = paddle.nn.functional.pad(
feats[i], [0, max_length - feats[i].shape[0], 0, 0],
data_format='NLC')
if __name__ == '__main__':
nranks = paddle.distributed.get_world_size()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
local_rank = paddle.distributed.get_rank()
return keys, paddle.stack(feats), paddle.to_tensor(
labels), paddle.to_tensor(lengths)
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
training_conf = config['training']
if __name__ == '__main__':
# Dataset
feat_conf = {
# 'n_mfcc': 80,
'n_mels': 80,
'frame_shift': 10,
'frame_length': 25,
# 'dither': 1.0,
}
data_dir = '/ssd1/chenxiaojie06/datasets/hey_snips/hey_snips_research_6k_en_train_eval_clean_ter'
train_ds = HeySnips(
data_dir=data_dir,
mode='train',
feat_type='kaldi_fbank',
sample_rate=16000,
**feat_conf)
dev_ds = HeySnips(
data_dir=data_dir,
mode='dev',
feat_type='kaldi_fbank',
sample_rate=16000,
**feat_conf)
training_conf = {
'epochs': 100,
'learning_rate': 0.001,
'weight_decay': 0.00005,
'num_workers': 16,
'batch_size': 100,
'checkpoint_dir': './checkpoint',
'save_freq': 10,
'log_freq': 10,
}
train_sampler = paddle.io.BatchSampler(
ds_class = dynamic_import(data_conf['dataset'])
train_ds = ds_class(
data_dir=data_conf['data_dir'], mode='train', **feat_conf)
dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
train_sampler = paddle.io.DistributedBatchSampler(
train_ds,
batch_size=training_conf['batch_size'],
shuffle=True,
@ -95,16 +65,11 @@ if __name__ == '__main__':
collate_fn=collate_features, )
# Model
backbone = MDTC(
stack_num=3,
stack_size=4,
in_channels=80,
res_channels=32,
kernel_size=5,
causal=True, )
model = KWSModel(backbone=backbone, num_keywords=1)
backbone_class = dynamic_import(model_conf['backbone'])
backbone = backbone_class(**model_conf['config'])
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
model = paddle.DataParallel(model)
clip = paddle.nn.ClipGradByGlobalNorm(5.0)
clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
optimizer = paddle.optimizer.Adam(
learning_rate=training_conf['learning_rate'],
weight_decay=training_conf['weight_decay'],
@ -122,9 +87,7 @@ if __name__ == '__main__':
avg_loss = 0
num_corrects = 0
num_samples = 0
batch_start = time.time()
for batch_idx, batch in enumerate(train_loader):
# print('Fetch one batch: {:.4f}'.format(time.time()-batch_start))
keys, feats, labels, lengths = batch
logits = model(feats)
loss, corrects, acc = criterion(logits, labels, lengths)
@ -144,7 +107,8 @@ if __name__ == '__main__':
timer.count()
if (batch_idx + 1) % training_conf['log_freq'] == 0:
if (batch_idx + 1
) % training_conf['log_freq'] == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= training_conf['log_freq']
avg_acc = num_corrects / num_samples
@ -161,10 +125,9 @@ if __name__ == '__main__':
avg_loss = 0
num_corrects = 0
num_samples = 0
batch_start = time.time()
if epoch % training_conf[
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch:
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
dev_sampler = paddle.io.BatchSampler(
dev_ds,
batch_size=training_conf['batch_size'],
@ -197,7 +160,7 @@ if __name__ == '__main__':
# Save model
save_dir = os.path.join(training_conf['checkpoint_dir'],
'epoch_{}_{:.4f}'.format(epoch, eval_acc))
'epoch_{}'.format(epoch))
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mdtc import KWSModel
from .mdtc import MDTC

@ -0,0 +1,80 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle
def fill_mask_elements(condition, value, x):
    """Return ``x`` with ``value`` substituted wherever ``condition`` is True.

    ``condition`` and ``x`` must have identical shapes; this is asserted.
    """
    assert condition.shape == x.shape
    # Constant tensor shaped like x, used as the "true" branch of where().
    return paddle.where(condition,
                        paddle.ones_like(x, dtype=x.dtype) * value, x)
def max_pooling_loss(logits: paddle.Tensor,
                     target: paddle.Tensor,
                     lengths: paddle.Tensor,
                     min_duration: int=0):
    """Max-pooling keyword-spotting loss and batch accuracy.

    Args:
        logits: Per-frame keyword scores indexed as
            ``[utterance, frame, keyword]``. They are clipped to
            ``[1e-8, 1.0]`` and passed to ``log``, so they are presumably
            probabilities — TODO confirm against the model's output activation.
        target: Per-utterance label. A value equal to a keyword index marks
            that keyword; a negative value is treated as filler.
        lengths: Per-utterance number of valid frames; frames at or beyond
            this index are masked out as padding.
        min_duration: When > 0, the first ``min_duration`` frames are also
            masked out for the target keyword's max-pooling term.

    Returns:
        Tuple ``(loss, num_correct, acc)``: the loss summed over utterances
        and keywords then divided by the number of utterances, the count of
        correctly classified utterances, and ``num_correct / num_utts``.
    """
    mask = padding_mask(lengths)
    num_utts = logits.shape[0]
    num_keywords = logits.shape[2]

    loss = 0.0
    for i in range(num_utts):
        for j in range(num_keywords):
            # Cross-entropy term: CE = -(t * log(p) + (1 - t) * log(1 - p))
            if target[i] == j:
                # For the keyword, do max-pooling over the valid frames.
                prob = logits[i, :, j]
                # Clone before the in-place write below so the shared mask
                # (reused for other keywords and the accuracy step) is never
                # modified.
                m = mask[i].clone()
                if min_duration > 0:
                    m[:min_duration] = True
                prob = fill_mask_elements(m, 0.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                max_prob = prob.max()
                loss += -paddle.log(max_prob)
            else:
                # For other keywords or filler, do min-pooling on (1 - p).
                prob = 1 - logits[i, :, j]
                prob = fill_mask_elements(mask[i], 1.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                min_prob = prob.min()
                loss += -paddle.log(min_prob)
    loss = loss / num_utts

    # Compute accuracy of current batch.
    # Expand the mask across the keyword axis so its shape matches `logits`;
    # fill_mask_elements asserts equal shapes, which would otherwise fail
    # whenever num_keywords > 1.
    mask = mask.unsqueeze(-1).expand(logits.shape)
    logits = fill_mask_elements(mask, 0.0, logits)
    max_logits = logits.max(1)
    num_correct = 0
    for i in range(num_utts):
        max_p = max_logits[i].max(0).item()
        idx = max_logits[i].argmax(0).item()
        # Predict correct as the i'th keyword
        if max_p > 0.5 and idx == target[i].item():
            num_correct += 1
        # Predict correct as the filler, filler id < 0
        if max_p < 0.5 and target[i].item() < 0:
            num_correct += 1
    acc = num_correct / num_utts
    return loss, num_correct, acc
def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a boolean padding mask from per-row lengths.

    Args:
        lengths: Tensor of lengths, one entry per row of the batch.

    Returns:
        Boolean tensor of shape ``(len(lengths), max(lengths))`` that is True
        at every position whose index is >= the corresponding row's length.
    """
    num_rows = lengths.shape[0]
    longest = int(lengths.max().item())
    # One row of frame indices 0..longest-1, repeated for every batch row.
    positions = paddle.arange(
        longest, dtype=paddle.int64).expand((num_rows, longest))
    return positions >= lengths.unsqueeze(1)

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@ -163,7 +164,7 @@ class MDTC(nn.Layer):
in_channels: int,
res_channels: int,
kernel_size: int,
causal: bool, ):
causal: bool=True, ):
super(MDTC, self).__init__()
assert kernel_size % 2 == 1
self.kernel_size = kernel_size
@ -230,17 +231,3 @@ class KWSModel(nn.Layer):
outputs = self.backbone(x)
outputs = self.linear(outputs)
return self.activation(outputs)
if __name__ == '__main__':
paddle.set_device('cpu')
from paddleaudio.features import LogMelSpectrogram
mdtc = MDTC(3, 4, 80, 32, 5, causal=True)
x = paddle.randn(shape=(32, 16000 * 5))
feature_extractor = LogMelSpectrogram(sr=16000, n_fft=512, n_mels=80)
feats = feature_extractor(x).transpose([0, 2, 1])
print(feats.shape)
res, _ = mdtc(feats)
print(res.shape)

Loading…
Cancel
Save