You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/kws/exps/mdtc/train.py

169 lines
6.2 KiB

2 years ago
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2 years ago
import argparse
2 years ago
import os
import paddle
2 years ago
import yaml
2 years ago
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
2 years ago
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
2 years ago
2 years ago
# yapf: disable
# Command-line interface. The first positional argument of ArgumentParser is
# ``prog`` (the program name), so the module docstring must be passed as
# ``description`` instead.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--cfg_path", type=str, required=True, help="Path to the YAML training configuration file.")
args = parser.parse_args()
# yapf: enable
2 years ago
2 years ago
def _evaluate(model, dev_loader, criterion):
    """Return the accuracy of ``model`` over every batch of ``dev_loader``.

    Args:
        model: KWS model (possibly wrapped in ``paddle.DataParallel``).
        dev_loader: loader yielding ``(keys, feats, labels, lengths)`` batches,
            as produced by ``collate_features``.
        criterion: callable ``criterion(logits, labels, lengths)`` returning
            ``(loss, corrects, acc)``.

    Returns:
        Summed ``corrects`` divided by the number of evaluated samples, or 0.0
        when the loader yields no samples (instead of a ZeroDivisionError).
    """
    model.eval()
    num_corrects = 0
    num_samples = 0
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with paddle.no_grad():
        for _keys, feats, labels, lengths in dev_loader:
            logits = model(feats)
            _loss, corrects, _acc = criterion(logits, labels, lengths)
            num_corrects += corrects
            num_samples += feats.shape[0]
    if num_samples == 0:
        return 0.0
    return num_corrects / num_samples


def _save_checkpoint(model, optimizer, save_dir):
    """Write model and optimizer state dicts into ``save_dir`` (created if missing)."""
    os.makedirs(save_dir, exist_ok=True)
    logger.info('Saving model checkpoint to {}'.format(save_dir))
    paddle.save(model.state_dict(), os.path.join(save_dir, 'model.pdparams'))
    paddle.save(optimizer.state_dict(), os.path.join(save_dir, 'model.pdopt'))


if __name__ == '__main__':
    # Initialise multi-process training only when more than one rank exists.
    nranks = paddle.distributed.get_world_size()
    if nranks > 1:
        paddle.distributed.init_parallel_env()
    local_rank = paddle.distributed.get_rank()

    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
    with open(args.cfg_path, 'r') as f:
        config = yaml.safe_load(f)

    model_conf = config['model']
    data_conf = config['data']
    feat_conf = config['feature']
    training_conf = config['training']

    # Dataset
    ds_class = dynamic_import(data_conf['dataset'])
    train_ds = ds_class(
        data_dir=data_conf['data_dir'], mode='train', **feat_conf)
    dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
    train_sampler = paddle.io.DistributedBatchSampler(
        train_ds,
        batch_size=training_conf['batch_size'],
        shuffle=True,
        drop_last=False)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=training_conf['num_workers'],
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )
    # Built once and reused for every evaluation; iterated in a fixed order.
    dev_sampler = paddle.io.BatchSampler(
        dev_ds,
        batch_size=training_conf['batch_size'],
        shuffle=False,
        drop_last=False)
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_sampler,
        num_workers=training_conf['num_workers'],
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )

    # Model
    backbone_class = dynamic_import(model_conf['backbone'])
    backbone = backbone_class(**model_conf['config'])
    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
    model = paddle.DataParallel(model)

    clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
    optimizer = paddle.optimizer.Adam(
        learning_rate=training_conf['learning_rate'],
        weight_decay=training_conf['weight_decay'],
        parameters=model.parameters(),
        grad_clip=clip)
    criterion = max_pooling_loss

    log_freq = training_conf['log_freq']
    num_epochs = training_conf['epochs']
    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * num_epochs)
    timer.start()

    for epoch in range(1, num_epochs + 1):
        model.train()
        # Running statistics for the current logging window.
        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        for batch_idx, (_keys, feats, labels,
                        lengths) in enumerate(train_loader):
            logits = model(feats)
            loss, corrects, _acc = criterion(logits, labels, lengths)
            loss.backward()
            optimizer.step()
            # When the learning rate is a scheduler, advance it once per step.
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # ``item()`` works for both shape-[1] and 0-D loss tensors,
            # whereas ``loss.numpy()[0]`` raises on a 0-D tensor.
            avg_loss += loss.item()
            num_corrects += corrects
            num_samples += feats.shape[0]
            timer.count()

            if (batch_idx + 1) % log_freq == 0:
                if local_rank == 0:
                    lr = optimizer.get_lr()
                    avg_loss /= log_freq
                    avg_acc = num_corrects / num_samples
                    print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                        epoch, num_epochs, batch_idx + 1, steps_per_epoch)
                    print_msg += ' loss={:.4f}'.format(avg_loss)
                    print_msg += ' acc={:.4f}'.format(avg_acc)
                    print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
                        lr, timer.timing, timer.eta)
                    logger.train(print_msg)
                # Start a fresh logging window on every rank.
                avg_loss = 0
                num_corrects = 0
                num_samples = 0

        # Every ``save_freq`` epochs, rank 0 validates and writes a checkpoint
        # once the epoch's training steps are finished.
        if epoch % training_conf['save_freq'] == 0 and local_rank == 0:
            with logger.processing('Evaluation on validation dataset'):
                eval_acc = _evaluate(model, dev_loader, criterion)
            print_msg = '[Evaluation result]'
            print_msg += ' dev_acc={:.4f}'.format(eval_acc)
            logger.eval(print_msg)
            # Save model
            save_dir = os.path.join(training_conf['checkpoint_dir'],
                                    'epoch_{}'.format(epoch))
            _save_checkpoint(model, optimizer, save_dir)