From ee6446a3aa625de5bcf8ea1523fa6dd112c9526e Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 23 Oct 2021 15:31:43 +0000 Subject: [PATCH] avg model dump val loss mean --- utils/avg_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/utils/avg_model.py b/utils/avg_model.py index 1fc00cb6..7c05ec78 100755 --- a/utils/avg_model.py +++ b/utils/avg_model.py @@ -25,8 +25,8 @@ def main(args): paddle.set_device('cpu') val_scores = [] - beat_val_scores = [] - selected_epochs = [] + beat_val_scores = None + selected_epochs = None jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') jsons = sorted(jsons, key=os.path.getmtime, reverse=True) @@ -80,9 +80,10 @@ def main(args): data = json.dumps({ "mode": 'val_best' if args.val_best else 'latest', "avg_ckpt": args.dst_model, - "ckpt": path_list, - "epoch": selected_epochs.tolist(), - "val_loss": beat_val_scores.tolist(), + "val_loss_mean": np.mean(beat_val_scores), + "ckpts": path_list, + "epochs": selected_epochs.tolist(), + "val_losses": beat_val_scores.tolist(), }) f.write(data + "\n")