diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py
index 94e45d590..02cd00a7c 100644
--- a/paddlespeech/kws/exps/mdtc/train.py
+++ b/paddlespeech/kws/exps/mdtc/train.py
@@ -91,7 +91,7 @@ if __name__ == '__main__':
     steps_per_epoch = len(train_sampler)
     timer = Timer(steps_per_epoch * config['epochs'])
     timer.start()
-
+    config['epochs'] = args.benchmark_max_step
     for epoch in range(1, config['epochs'] + 1):
         model.train()
@@ -137,33 +137,9 @@ if __name__ == '__main__':
             if epoch % config[
                     'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
-                dev_sampler = paddle.io.BatchSampler(
-                    dev_ds,
-                    batch_size=config['batch_size'],
-                    shuffle=False,
-                    drop_last=False)
-                dev_loader = paddle.io.DataLoader(
-                    dev_ds,
-                    batch_sampler=dev_sampler,
-                    num_workers=config['num_workers'],
-                    return_list=True,
-                    use_buffer_reader=True,
-                    collate_fn=collate_features, )
-
-                model.eval()
+
                 num_corrects = 0
                 num_samples = 0
-                with logger.processing('Evaluation on validation dataset'):
-                    for batch_idx, batch in enumerate(dev_loader):
-                        keys, feats, labels, lengths = batch
-                        logits = model(feats)
-                        loss, corrects, acc = criterion(logits, labels, lengths)
-                        num_corrects += corrects
-                        num_samples += feats.shape[0]
-
-                eval_acc = num_corrects / num_samples
-                print_msg = '[Evaluation result]'
-                print_msg += ' dev_acc={:.4f}'.format(eval_acc)
                 logger.eval(print_msg)