Merge pull request #1783 from KPatr1ck/kws

[KWS] Update KWS example.
Hui Zhang committed 2 years ago via GitHub
commit 2504ccbf8e

@@ -1,39 +1,49 @@
data:
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
dataset: 'paddleaudio.datasets:HeySnips'
# https://yaml.org/type/float.html
###########################################
# Data #
###########################################
dataset: 'paddleaudio.datasets:HeySnips'
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
model:
num_keywords: 1
backbone: 'paddlespeech.kws.models:MDTC'
config:
stack_num: 3
stack_size: 4
in_channels: 80
res_channels: 32
kernel_size: 5
############################################
# Network Architecture #
############################################
backbone: 'paddlespeech.kws.models:MDTC'
num_keywords: 1
stack_num: 3
stack_size: 4
in_channels: 80
res_channels: 32
kernel_size: 5
feature:
feat_type: 'kaldi_fbank'
sample_rate: 16000
frame_shift: 10
frame_length: 25
n_mels: 80
###########################################
# Feature #
###########################################
feat_type: 'kaldi_fbank'
sample_rate: 16000
frame_shift: 10
frame_length: 25
n_mels: 80
training:
epochs: 100
num_workers: 16
batch_size: 100
checkpoint_dir: './checkpoint'
save_freq: 10
log_freq: 10
learning_rate: 0.001
weight_decay: 0.00005
grad_clip: 5.0
###########################################
# Training #
###########################################
epochs: 100
num_workers: 16
batch_size: 100
checkpoint_dir: './checkpoint'
save_freq: 10
log_freq: 10
learning_rate: 0.001
weight_decay: 0.00005
grad_clip: 5.0
scoring:
batch_size: 100
num_workers: 16
checkpoint: './checkpoint/epoch_100/model.pdparams'
score_file: './scores.txt'
stats_file: './stats.0.txt'
img_file: './det.png'
###########################################
# Scoring #
###########################################
batch_size: 100
num_workers: 16
checkpoint: './checkpoint/epoch_100/model.pdparams'
score_file: './scores.txt'
stats_file: './stats.0.txt'
img_file: './det.png'
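
Note: the config is flattened from nested sections (data/model/feature/training/scoring) to top-level keys; the banner lines are plain YAML comments. A minimal sketch of how the updated exp scripts consume it, assuming yacs is installed and the file above is saved as ./mdtc.yaml:

    # Minimal sketch: load the flat config the way the updated exps do.
    # Assumes yacs is installed and the YAML above is at ./mdtc.yaml.
    from yacs.config import CfgNode

    config = CfgNode(new_allowed=True)  # accept keys not predefined in the node
    config.merge_from_file('./mdtc.yaml')

    # Keys are now flat, so no more config['training']['batch_size'] nesting:
    print(config['batch_size'], config['learning_rate'], config['feat_type'])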

@@ -1,2 +1,25 @@
#!/bin/bash
python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 3 ];then
echo "usage: ${0} config_path checkpoint output_file"
exit -1
fi
keyword=$1
stats_file=$2
img_file=$3
python3 ${BIN_DIR}/plot_det_curve.py --keyword_label ${keyword} --stats_file ${stats_file} --img_file ${img_file}

@@ -1,5 +1,27 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python3 ${BIN_DIR}/score.py --cfg_path=$1
if [ $# != 4 ];then
echo "usage: ${0} checkpoint score_file stats_file"
exit -1
fi
python3 ${BIN_DIR}/compute_det.py --cfg_path=$1
cfg_path=$1
ckpt=$2
score_file=$3
stats_file=$4
python3 ${BIN_DIR}/score.py --config ${cfg_path} --ckpt ${ckpt} --score_file ${score_file} || exit -1
python3 ${BIN_DIR}/compute_det.py --config ${cfg_path} --score_file ${score_file} --stats_file ${stats_file} || exit -1
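
Design note: score.sh no longer pulls the checkpoint and output paths out of the config's scoring section; they are now positional arguments, so one config file can drive several evaluation runs, and run.sh (below) threads the same score_file/stats_file paths on through DET computation and plotting.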

@@ -1,13 +1,31 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 2 ];then
echo "usage: ${0} num_gpus config_path"
exit -1
fi
ngpu=$1
cfg_path=$2
if [ ${ngpu} -gt 0 ]; then
python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
--cfg_path ${cfg_path}
--config ${cfg_path}
else
echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning."
python3 ${BIN_DIR}/train.py \
--cfg_path ${cfg_path}
--config ${cfg_path}
fi

@@ -32,10 +32,16 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
./local/train.sh ${ngpu} ${cfg_path} || exit -1
fi
ckpt=./checkpoint/epoch_100/model.pdparams
score_file=./scores.txt
stats_file=./stats.0.txt
img_file=./det.png
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/score.sh ${cfg_path} || exit -1
./local/score.sh ${cfg_path} ${ckpt} ${score_file} ${stats_file} || exit -1
fi
keyword=HeySnips
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/plot.sh ${cfg_path} || exit -1
./local/plot.sh ${keyword} ${stats_file} ${img_file} || exit -1
fi

@@ -12,24 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws (https://github.com/wenet-e2e/wekws)
import argparse
import os
import paddle
import yaml
from tqdm import tqdm
from yacs.config import CfgNode
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
args = parser.parse_args()
# yapf: enable
def load_label_and_score(keyword_index: int,
ds: paddle.io.Dataset,
@@ -61,26 +52,52 @@ def load_label_and_score(keyword_index: int,
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
parser = default_argument_parser()
parser.add_argument(
'--keyword_index', type=int, default=0, help='keyword index')
parser.add_argument(
'--step',
type=float,
default=0.01,
help='threshold step of trigger score')
parser.add_argument(
'--window_shift',
type=int,
default=50,
help='window_shift is used to skip the frames after triggered')
parser.add_argument(
"--score_file",
type=str,
required=True,
help='output file of trigger scores')
parser.add_argument(
'--stats_file',
type=str,
default='./stats.0.txt',
help='output file of detection error tradeoff')
args = parser.parse_args()
data_conf = config['data']
feat_conf = config['feature']
scoring_conf = config['scoring']
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
score_file = os.path.abspath(scoring_conf['score_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
ds_class = dynamic_import(config['dataset'])
test_ds = ds_class(
data_dir=config['data_dir'],
mode='test',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
keyword_table, filler_table, filler_duration = load_label_and_score(
args.keyword, test_ds, score_file)
args.keyword_index, test_ds, args.score_file)
print('Total filler duration (hours): {}'.format(filler_duration / 3600.0))
pbar = tqdm(total=int(1.0 / args.step))
with open(stats_file, 'w', encoding='utf8') as fout:
with open(args.stats_file, 'w', encoding='utf8') as fout:
keyword_index = args.keyword_index
threshold = 0.0
while threshold <= 1.0:
@@ -113,4 +130,4 @@ if __name__ == '__main__':
pbar.update(1)
pbar.close()
print('DET saved to: {}'.format(stats_file))
print('DET saved to: {}'.format(args.stats_file))
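
The three exp entry points now share one CLI pattern: default_argument_parser() from paddlespeech.s2t.training.cli (which provides --config, among other common flags) plus script-specific options, with the YAML merged into a yacs CfgNode. A condensed sketch of that skeleton, using the option names from the diff above:

    # Condensed sketch of the CLI skeleton shared by compute_det.py,
    # score.py and train.py after this PR; bodies trimmed.
    from yacs.config import CfgNode
    from paddlespeech.s2t.training.cli import default_argument_parser

    parser = default_argument_parser()  # provides --config, among others
    parser.add_argument('--score_file', type=str, required=True,
                        help='output file of trigger scores')
    parser.add_argument('--stats_file', type=str, default='./stats.0.txt',
                        help='output file of detection error tradeoff')
    args = parser.parse_args()

    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)  # flat keys, see mdtc.yaml above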

@@ -17,12 +17,12 @@ import os
import matplotlib.pyplot as plt
import numpy as np
import yaml
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument("--keyword", type=str, required=True)
parser.add_argument('--keyword_label', type=str, required=True, help='keyword string shown on image')
parser.add_argument('--stats_file', type=str, required=True, help='output file of detection error tradeoff')
parser.add_argument('--img_file', type=str, default='./det.png', help='output det image')
args = parser.parse_args()
# yapf: enable
@@ -61,14 +61,8 @@ def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
scoring_conf = config['scoring']
img_file = os.path.abspath(scoring_conf['img_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
keywords = [args.keyword]
plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
img_file = os.path.abspath(args.img_file)
stats_file = os.path.abspath(args.stats_file)
plot_det_curve([args.keyword_label], stats_file, img_file, 10, 2, 10, 2)
print('DET curve image saved to: {}'.format(img_file))
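
For one-off plots, plot_det_curve can also be called directly; a hedged sketch, with the module path assumed from the exps layout (the trailing numbers are the xlim/x_step/ylim/y_step arguments passed in the diff):

    # Hedged sketch: the module path is assumed, not confirmed by the diff.
    from paddlespeech.kws.exps.mdtc.plot_det_curve import plot_det_curve

    # keywords list, stats file, output image, xlim, x_step, ylim, y_step
    plot_det_curve(['HeySnips'], './stats.0.txt', './det.png', 10, 2, 10, 2)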

@@ -12,55 +12,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws (https://github.com/wenet-e2e/wekws)
import argparse
import os
import paddle
import yaml
from tqdm import tqdm
from yacs.config import CfgNode
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
parser = default_argument_parser()
parser.add_argument(
"--ckpt",
type=str,
required=True,
help='model checkpoint for evaluation.')
parser.add_argument(
"--score_file",
type=str,
default='./scores.txt',
help='output file of trigger scores')
args = parser.parse_args()
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
scoring_conf = config['scoring']
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
ds_class = dynamic_import(config['dataset'])
test_ds = ds_class(
data_dir=config['data_dir'],
mode='test',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
test_sampler = paddle.io.BatchSampler(
test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
test_ds, batch_size=config['batch_size'], drop_last=False)
test_loader = paddle.io.DataLoader(
test_ds,
batch_sampler=test_sampler,
num_workers=scoring_conf['num_workers'],
num_workers=config['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
# Model
backbone_class = dynamic_import(model_conf['backbone'])
backbone = backbone_class(**model_conf['config'])
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
backbone_class = dynamic_import(config['backbone'])
backbone = backbone_class(
stack_num=config['stack_num'],
stack_size=config['stack_size'],
in_channels=config['in_channels'],
res_channels=config['res_channels'],
kernel_size=config['kernel_size'], )
model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
model.set_state_dict(paddle.load(args.ckpt))
model.eval()
with paddle.no_grad(), open(
scoring_conf['score_file'], 'w', encoding='utf8') as fout:
with paddle.no_grad(), open(args.score_file, 'w', encoding='utf8') as f:
for batch_idx, batch in enumerate(
tqdm(test_loader, total=len(test_loader))):
keys, feats, labels, lengths = batch
@@ -73,7 +85,6 @@ if __name__ == '__main__':
keyword_scores = score[:, keyword_i]
score_frames = ' '.join(
['{:.6f}'.format(x) for x in keyword_scores.tolist()])
fout.write(
'{} {} {}\n'.format(key, keyword_i, score_frames))
f.write('{} {} {}\n'.format(key, keyword_i, score_frames))
print('Result saved to: {}'.format(scoring_conf['score_file']))
print('Result saved to: {}'.format(args.score_file))
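
Each line of the score file has the form '<utt_key> <keyword_index> <frame_score> <frame_score> ...' (see the f.write above), which is what compute_det.py's load_label_and_score later consumes. A minimal stand-alone reader for reference; read_score_file is a hypothetical helper, not part of the repo:

    # Hypothetical helper, not in the repo: parse one score line per utterance.
    def read_score_file(score_file: str) -> dict:
        scores = {}
        with open(score_file, 'r', encoding='utf8') as f:
            for line in f:
                fields = line.strip().split()
                key, keyword_index = fields[0], int(fields[1])
                scores[key] = (keyword_index, [float(s) for s in fields[2:]])
        return scores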

@@ -11,77 +11,88 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle
import yaml
from yacs.config import CfgNode
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__':
parser = default_argument_parser()
args = parser.parse_args()
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
nranks = paddle.distributed.get_world_size()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
local_rank = paddle.distributed.get_rank()
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
training_conf = config['training']
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
ds_class = dynamic_import(config['dataset'])
train_ds = ds_class(
data_dir=data_conf['data_dir'], mode='train', **feat_conf)
dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
data_dir=config['data_dir'],
mode='train',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
dev_ds = ds_class(
data_dir=config['data_dir'],
mode='dev',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
train_sampler = paddle.io.DistributedBatchSampler(
train_ds,
batch_size=training_conf['batch_size'],
batch_size=config['batch_size'],
shuffle=True,
drop_last=False)
train_loader = paddle.io.DataLoader(
train_ds,
batch_sampler=train_sampler,
num_workers=training_conf['num_workers'],
num_workers=config['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
# Model
backbone_class = dynamic_import(model_conf['backbone'])
backbone = backbone_class(**model_conf['config'])
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
backbone_class = dynamic_import(config['backbone'])
backbone = backbone_class(
stack_num=config['stack_num'],
stack_size=config['stack_size'],
in_channels=config['in_channels'],
res_channels=config['res_channels'],
kernel_size=config['kernel_size'], )
model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
model = paddle.DataParallel(model)
clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
clip = paddle.nn.ClipGradByGlobalNorm(config['grad_clip'])
optimizer = paddle.optimizer.Adam(
learning_rate=training_conf['learning_rate'],
weight_decay=training_conf['weight_decay'],
learning_rate=config['learning_rate'],
weight_decay=config['weight_decay'],
parameters=model.parameters(),
grad_clip=clip)
criterion = max_pooling_loss
steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * training_conf['epochs'])
timer = Timer(steps_per_epoch * config['epochs'])
timer.start()
for epoch in range(1, training_conf['epochs'] + 1):
for epoch in range(1, config['epochs'] + 1):
model.train()
avg_loss = 0
@@ -107,15 +118,13 @@ if __name__ == '__main__':
timer.count()
if (batch_idx + 1
) % training_conf['log_freq'] == 0 and local_rank == 0:
if (batch_idx + 1) % config['log_freq'] == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= training_conf['log_freq']
avg_loss /= config['log_freq']
avg_acc = num_corrects / num_samples
print_msg = 'Epoch={}/{}, Step={}/{}'.format(
epoch, training_conf['epochs'], batch_idx + 1,
steps_per_epoch)
epoch, config['epochs'], batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
@@ -126,17 +135,17 @@ if __name__ == '__main__':
num_corrects = 0
num_samples = 0
if epoch % training_conf[
if epoch % config[
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
dev_sampler = paddle.io.BatchSampler(
dev_ds,
batch_size=training_conf['batch_size'],
batch_size=config['batch_size'],
shuffle=False,
drop_last=False)
dev_loader = paddle.io.DataLoader(
dev_ds,
batch_sampler=dev_sampler,
num_workers=training_conf['num_workers'],
num_workers=config['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
@@ -159,7 +168,7 @@ if __name__ == '__main__':
logger.eval(print_msg)
# Save model
save_dir = os.path.join(training_conf['checkpoint_dir'],
save_dir = os.path.join(config['checkpoint_dir'],
'epoch_{}'.format(epoch))
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
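
Note: the save layout written here, '<checkpoint_dir>/epoch_<N>/model.pdparams', is what the defaults elsewhere assume: the checkpoint path './checkpoint/epoch_100/model.pdparams' in the scoring section of mdtc.yaml and the hard-coded ckpt path in run.sh both point at the epoch-100 snapshot produced with save_freq 10 over 100 epochs.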
