Update train.py

1.benchmark needn't eval 
2.let benchmark test can modify the training epoch
pull/2386/head
Zhao Yuting 3 years ago committed by GitHub
parent 324b166c52
commit 92b00cf61e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -91,7 +91,7 @@ if __name__ == '__main__':
steps_per_epoch = len(train_sampler) steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * config['epochs']) timer = Timer(steps_per_epoch * config['epochs'])
timer.start() timer.start()
config['epochs'] = args.benchmark_max_step
for epoch in range(1, config['epochs'] + 1): for epoch in range(1, config['epochs'] + 1):
model.train() model.train()
@ -137,33 +137,9 @@ if __name__ == '__main__':
if epoch % config[ if epoch % config[
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0: 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
dev_sampler = paddle.io.BatchSampler(
dev_ds,
batch_size=config['batch_size'],
shuffle=False,
drop_last=False)
dev_loader = paddle.io.DataLoader(
dev_ds,
batch_sampler=dev_sampler,
num_workers=config['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
model.eval()
num_corrects = 0 num_corrects = 0
num_samples = 0 num_samples = 0
with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(dev_loader):
keys, feats, labels, lengths = batch
logits = model(feats)
loss, corrects, acc = criterion(logits, labels, lengths)
num_corrects += corrects
num_samples += feats.shape[0]
eval_acc = num_corrects / num_samples
print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(eval_acc)
logger.eval(print_msg) logger.eval(print_msg)

Loading…
Cancel
Save