You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
5.9 KiB
149 lines
5.9 KiB
3 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import argparse
|
||
|
import os
|
||
|
|
||
|
import paddle
|
||
|
from model import SoundClassifier
|
||
|
from paddleaudio.datasets import ESC50
|
||
|
from paddleaudio.models.panns import cnn14
|
||
|
from paddleaudio.utils import logger
|
||
|
from paddleaudio.utils import Timer
|
||
|
|
||
|
def _positive_int(value):
    """Parse a command-line value as a strictly positive ``int``.

    ``--save_freq`` and ``--log_freq`` are used as modulo divisors in the
    training loop, so 0 would crash the run with ZeroDivisionError partway
    through training; a batch size of 0 cannot form any batch either.
    Rejecting such values at parse time gives the user an immediate,
    readable error instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a positive integer.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {!r}'.format(value)) from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {!r}'.format(value))
    return number


# yapf: disable
# ``description=`` (not the first positional argument, which is ``prog``)
# is where the module docstring belongs.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate of the Adam optimizer.")
parser.add_argument("--batch_size", type=_positive_int, default=16, help="Number of examples per training batch on each process.")
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers in dataloader.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save model checkpoints.")
parser.add_argument("--save_freq", type=_positive_int, default=10, help="Save checkpoint every n epoch.")
parser.add_argument("--log_freq", type=_positive_int, default=10, help="Log the training information every n steps.")
args = parser.parse_args()
# yapf: enable
|
||
|
|
||
|
def _evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put into eval mode; the caller is responsible for switching
    it back to train mode (the training loop does so at the start of every
    epoch). Gradient tracking is disabled for the whole pass so no autograd
    graph is built. Presumably this also keeps a ``paddle.DataParallel``
    wrapper from preparing a backward/gradient sync that only rank 0 would
    run (see the DataParallel docs).

    Args:
        model: the (possibly ``DataParallel``-wrapped) classifier.
        loader: an iterable yielding ``(feats, labels)`` batches.

    Returns:
        float: correct predictions divided by evaluated samples, or 0.0 when
        the loader yields no samples. This avoids a ZeroDivisionError that
        would otherwise abort the run before the checkpoint is written.
    """
    model.eval()
    num_corrects = 0
    num_samples = 0
    with paddle.no_grad():
        for feats, labels in loader:
            logits = model(feats)
            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]
    return num_corrects / num_samples if num_samples else 0.0


def _save_checkpoint(model, optimizer, checkpoint_dir, epoch):
    """Write model and optimizer state to ``<checkpoint_dir>/epoch_<epoch>/``.

    Produces ``model.pdparams`` (model ``state_dict``) and ``model.pdopt``
    (optimizer ``state_dict``) inside that directory, creating it if needed.
    """
    save_dir = os.path.join(checkpoint_dir, 'epoch_{}'.format(epoch))
    logger.info('Saving model checkpoint to {}'.format(save_dir))
    os.makedirs(save_dir, exist_ok=True)
    paddle.save(model.state_dict(), os.path.join(save_dir, 'model.pdparams'))
    paddle.save(optimizer.state_dict(), os.path.join(save_dir, 'model.pdopt'))


if __name__ == "__main__":
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    if nranks > 1:
        paddle.distributed.init_parallel_env()
    local_rank = paddle.distributed.get_rank()

    backbone = cnn14(pretrained=True, extract_embedding=True)
    model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
    # Only wrap for multi-process training, i.e. when the parallel env above
    # has actually been initialised.
    if nranks > 1:
        model = paddle.DataParallel(model)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate, parameters=model.parameters())
    criterion = paddle.nn.loss.CrossEntropyLoss()

    train_ds = ESC50(mode='train', feat_type='melspectrogram')
    dev_ds = ESC50(mode='dev', feat_type='melspectrogram')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        return_list=True,
        use_buffer_reader=True, )

    # The dev loader does not depend on the epoch, so build it once rather
    # than on every checkpoint.
    dev_sampler = paddle.io.BatchSampler(
        dev_ds, batch_size=args.batch_size, shuffle=False, drop_last=False)
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_sampler,
        num_workers=args.num_workers,
        return_list=True, )

    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * args.epochs)
    timer.start()

    for epoch in range(1, args.epochs + 1):
        model.train()

        # Running totals since the last log line; reset every log_freq steps.
        running_loss = 0.0
        num_corrects = 0
        num_samples = 0
        for batch_idx, (feats, labels) in enumerate(train_loader):
            logits = model(feats)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # float() works for both shape-[1] and 0-D loss tensors, whereas
            # ``loss.numpy()[0]`` raises IndexError on a 0-D tensor.
            running_loss += float(loss)

            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]

            timer.count()

            if (batch_idx + 1) % args.log_freq == 0:
                if local_rank == 0:
                    lr = optimizer.get_lr()
                    avg_loss = running_loss / args.log_freq
                    avg_acc = num_corrects / num_samples

                    print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                        epoch, args.epochs, batch_idx + 1, steps_per_epoch)
                    print_msg += ' loss={:.4f}'.format(avg_loss)
                    print_msg += ' acc={:.4f}'.format(avg_acc)
                    print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
                        lr, timer.timing, timer.eta)
                    logger.train(print_msg)

                running_loss = 0.0
                num_corrects = 0
                num_samples = 0

        # Evaluate and checkpoint once the epoch's batches are exhausted.
        # Only rank 0 does this, so checkpoints are written exactly once.
        if epoch % args.save_freq == 0 and local_rank == 0:
            with logger.processing('Evaluation on validation dataset'):
                dev_acc = _evaluate(model, dev_loader)

            print_msg = '[Evaluation result]'
            print_msg += ' dev_acc={:.4f}'.format(dev_acc)
            logger.eval(print_msg)

            _save_checkpoint(model, optimizer, args.checkpoint_dir, epoch)
|