From f26161d1174a60c3f5da36a233a5a4e8ab3390b8 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 21 Apr 2023 11:21:08 +0000 Subject: [PATCH] fix avg ckpts --- paddlespeech/dataset/s2t/avg_model.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/paddlespeech/dataset/s2t/avg_model.py b/paddlespeech/dataset/s2t/avg_model.py index 99111ccc1..c5753b726 100755 --- a/paddlespeech/dataset/s2t/avg_model.py +++ b/paddlespeech/dataset/s2t/avg_model.py @@ -53,36 +53,34 @@ def average_checkpoints(dst_model="", paddle.set_device('cpu') val_scores = [] - beat_val_scores = None - selected_epochs = None - - jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') + jsons = glob.glob(f'{ckpt_dir}/[!train]*.json') jsons = sorted(jsons, key=os.path.getmtime, reverse=True) for y in jsons: with open(y, 'r') as f: dic_json = json.load(f) loss = dic_json['val_loss'] epoch = dic_json['epoch'] - if epoch >= args.min_epoch and epoch <= args.max_epoch: + if epoch >= min_epoch and epoch <= max_epoch: val_scores.append((epoch, loss)) + assert val_scores, f"Not find any valid checkpoints: {val_scores}" val_scores = np.array(val_scores) - if args.val_best: + if val_best: sort_idx = np.argsort(val_scores[:, 1]) sorted_val_scores = val_scores[sort_idx] else: sorted_val_scores = val_scores - beat_val_scores = sorted_val_scores[:args.num, 1] - selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64) + beat_val_scores = sorted_val_scores[:num, 1] + selected_epochs = sorted_val_scores[:num, 0].astype(np.int64) avg_val_score = np.mean(beat_val_scores) print("selected val scores = " + str(beat_val_scores)) print("selected epochs = " + str(selected_epochs)) print("averaged val score = " + str(avg_val_score)) path_list = [ - args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) - for epoch in sorted_val_scores[:args.num, 0] + ckpt_dir + '/{}.pdparams'.format(int(epoch)) + for epoch in sorted_val_scores[:num, 0] ] print(path_list)