more avg and test info

pull/578/head
Hui Zhang 4 years ago
parent c693bb0829
commit f2a42bd30f

@ -125,6 +125,7 @@ if not hasattr(paddle, 'cat'):
def item(x: paddle.Tensor): def item(x: paddle.Tensor):
return x.numpy().item() return x.numpy().item()
if not hasattr(paddle.Tensor, 'item'): if not hasattr(paddle.Tensor, 'item'):
logger.warn( logger.warn(
"override item of paddle.Tensor if exists or register, remove this when fixed!" "override item of paddle.Tensor if exists or register, remove this when fixed!"

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains U2 model.""" """Contains U2 model."""
import json
import os
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
@ -439,6 +441,31 @@ class U2Tester(U2Trainer):
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg) logger.info(msg)
# test meta results
err_meta_path = os.path.splitext(self.args.checkpoint_path)[0] + '.err'
err_type_str = "{}".format(error_rate_type)
with open(err_meta_path, 'w') as f:
data = json.dumps({
"epoch":
self.epoch,
"step":
self.iteration,
"rtf":
rtf,
error_rate_type:
errors_sum / len_refs,
"dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,
"num_examples":
num_ins,
"err_sum":
errors_sum,
"ref_len":
len_refs,
})
f.write(data + '\n')
def run_test(self): def run_test(self):
self.resume_or_scratch() self.resume_or_scratch()
try: try:

@ -7,7 +7,7 @@ fi
ckpt_path=${1} ckpt_path=${1}
average_num=${2} average_num=${2}
decode_checkpoint=${ckpt_path}/avg_${average_num}.pt decode_checkpoint=${ckpt_path}/avg_${average_num}.pdparams
python3 -u ${MAIN_ROOT}/utils/avg_model.py \ python3 -u ${MAIN_ROOT}/utils/avg_model.py \
--dst_model ${decode_checkpoint} \ --dst_model ${decode_checkpoint} \

@ -21,14 +21,15 @@ import paddle
def main(args): def main(args):
checkpoints = []
val_scores = [] val_scores = []
beat_val_scores = []
selected_epochs = []
if args.val_best: if args.val_best:
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
for y in jsons: for y in jsons:
dic_json = json.load(y) with open(y, 'r') as f:
loss = dic_json['valid_loss'] dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch'] epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch: if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss)) val_scores.append((epoch, loss))
@ -40,9 +41,11 @@ def main(args):
args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0] for epoch in sorted_val_scores[:args.num, 0]
] ]
print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
print("selected epochs = " + str(sorted_val_scores[:args.num, 0].astype( beat_val_scores = sorted_val_scores[:args.num, 1]
np.int64))) selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("best val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
else: else:
path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams') path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
path_list = sorted(path_list, key=os.path.getmtime) path_list = sorted(path_list, key=os.path.getmtime)
@ -64,11 +67,21 @@ def main(args):
# average # average
for k in avg.keys(): for k in avg.keys():
if avg[k] is not None: if avg[k] is not None:
avg[k] = paddle.divide(avg[k], num) avg[k] /= num
paddle.save(avg, args.dst_model) paddle.save(avg, args.dst_model)
print(f'Saving to {args.dst_model}') print(f'Saving to {args.dst_model}')
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
})
f.write(data + "\n")
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='average model') parser = argparse.ArgumentParser(description='average model')

Loading…
Cancel
Save