|
|
@ -91,7 +91,7 @@ if __name__ == '__main__':
|
|
|
|
steps_per_epoch = len(train_sampler)
|
|
|
|
steps_per_epoch = len(train_sampler)
|
|
|
|
timer = Timer(steps_per_epoch * config['epochs'])
|
|
|
|
timer = Timer(steps_per_epoch * config['epochs'])
|
|
|
|
timer.start()
|
|
|
|
timer.start()
|
|
|
|
|
|
|
|
config['epochs'] = args.benchmark_max_step
|
|
|
|
for epoch in range(1, config['epochs'] + 1):
|
|
|
|
for epoch in range(1, config['epochs'] + 1):
|
|
|
|
model.train()
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
|
@ -137,33 +137,9 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
if epoch % config[
|
|
|
|
if epoch % config[
|
|
|
|
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
|
|
|
|
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
|
|
|
|
dev_sampler = paddle.io.BatchSampler(
|
|
|
|
|
|
|
|
dev_ds,
|
|
|
|
|
|
|
|
batch_size=config['batch_size'],
|
|
|
|
|
|
|
|
shuffle=False,
|
|
|
|
|
|
|
|
drop_last=False)
|
|
|
|
|
|
|
|
dev_loader = paddle.io.DataLoader(
|
|
|
|
|
|
|
|
dev_ds,
|
|
|
|
|
|
|
|
batch_sampler=dev_sampler,
|
|
|
|
|
|
|
|
num_workers=config['num_workers'],
|
|
|
|
|
|
|
|
return_list=True,
|
|
|
|
|
|
|
|
use_buffer_reader=True,
|
|
|
|
|
|
|
|
collate_fn=collate_features, )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
num_corrects = 0
|
|
|
|
num_corrects = 0
|
|
|
|
num_samples = 0
|
|
|
|
num_samples = 0
|
|
|
|
with logger.processing('Evaluation on validation dataset'):
|
|
|
|
|
|
|
|
for batch_idx, batch in enumerate(dev_loader):
|
|
|
|
|
|
|
|
keys, feats, labels, lengths = batch
|
|
|
|
|
|
|
|
logits = model(feats)
|
|
|
|
|
|
|
|
loss, corrects, acc = criterion(logits, labels, lengths)
|
|
|
|
|
|
|
|
num_corrects += corrects
|
|
|
|
|
|
|
|
num_samples += feats.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval_acc = num_corrects / num_samples
|
|
|
|
|
|
|
|
print_msg = '[Evaluation result]'
|
|
|
|
|
|
|
|
print_msg += ' dev_acc={:.4f}'.format(eval_acc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.eval(print_msg)
|
|
|
|
logger.eval(print_msg)
|
|
|
|
|
|
|
|
|
|
|
|