avg model dump val loss mean

pull/928/head
Hui Zhang 3 years ago
parent a4e27da64b
commit ee6446a3aa

@ -25,8 +25,8 @@ def main(args):
paddle.set_device('cpu')
val_scores = []
beat_val_scores = []
selected_epochs = []
beat_val_scores = None
selected_epochs = None
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
@ -80,9 +80,10 @@ def main(args):
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
"val_loss_mean": np.mean(beat_val_scores),
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),
})
f.write(data + "\n")

Loading…
Cancel
Save