From f2a42bd30fb391a36a624c6c5afc0430d1ec1d42 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 28 Apr 2021 09:18:13 +0000
Subject: [PATCH] more avg and test info

---
 deepspeech/__init__.py           |  1 +
 deepspeech/exps/u2/model.py      | 27 +++++++++++++++++++++++++++
 examples/aishell/s1/local/avg.sh |  4 ++--
 utils/avg_model.py               | 29 +++++++++++++++++++++--------
 4 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 1c11de385..1a6690d7a 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -125,6 +125,7 @@ if not hasattr(paddle, 'cat'):
 def item(x: paddle.Tensor):
     return x.numpy().item()
 
+
 if not hasattr(paddle.Tensor, 'item'):
     logger.warn(
         "override item of paddle.Tensor if exists or register, remove this when fixed!"
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 5b1ed7e40..ce0b0a0ae 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains U2 model."""
+import json
+import os
 import sys
 import time
 from collections import defaultdict
@@ -439,6 +441,31 @@ class U2Tester(U2Trainer):
                 error_rate_type, num_ins, num_ins, errors_sum / len_refs)
             logger.info(msg)
 
+        # test meta results
+        err_meta_path = os.path.splitext(self.args.checkpoint_path)[0] + '.err'
+        err_type_str = "{}".format(error_rate_type)
+        with open(err_meta_path, 'w') as f:
+            data = json.dumps({
+                "epoch":
+                self.epoch,
+                "step":
+                self.iteration,
+                "rtf":
+                rtf,
+                error_rate_type:
+                errors_sum / len_refs,
+                "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
+                "process_hour":
+                num_time / 1000.0 / 3600.0,
+                "num_examples":
+                num_ins,
+                "err_sum":
+                errors_sum,
+                "ref_len":
+                len_refs,
+            })
+            f.write(data + '\n')
+
     def run_test(self):
         self.resume_or_scratch()
         try:
diff --git a/examples/aishell/s1/local/avg.sh b/examples/aishell/s1/local/avg.sh
index 7e2befee6..17dbd87b1 100644
--- a/examples/aishell/s1/local/avg.sh
+++ b/examples/aishell/s1/local/avg.sh
@@ -7,7 +7,7 @@ fi
 
 ckpt_path=${1}
 average_num=${2}
-decode_checkpoint=${ckpt_path}/avg_${average_num}.pt
+decode_checkpoint=${ckpt_path}/avg_${average_num}.pdparams
 
 python3 -u ${MAIN_ROOT}/utils/avg_model.py \
 --dst_model ${decode_checkpoint} \
@@ -21,4 +21,4 @@ if [ $? -ne 0 ]; then
 fi
 
 
-exit 0
\ No newline at end of file
+exit 0
diff --git a/utils/avg_model.py b/utils/avg_model.py
index a8a1c0f5a..a002c1b0d 100644
--- a/utils/avg_model.py
+++ b/utils/avg_model.py
@@ -21,14 +21,15 @@ import paddle
 
 
 def main(args):
-    checkpoints = []
     val_scores = []
-
+    best_val_scores = []
+    selected_epochs = []
     if args.val_best:
         jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
         for y in jsons:
-            dic_json = json.load(y)
-            loss = dic_json['valid_loss']
+            with open(y, 'r') as f:
+                dic_json = json.load(f)
+            loss = dic_json['val_loss']
             epoch = dic_json['epoch']
             if epoch >= args.min_epoch and epoch <= args.max_epoch:
                 val_scores.append((epoch, loss))
@@ -40,9 +41,11 @@ def main(args):
             args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
             for epoch in sorted_val_scores[:args.num, 0]
         ]
-        print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
-        print("selected epochs = " + str(sorted_val_scores[:args.num, 0].astype(
-            np.int64)))
+
+        best_val_scores = sorted_val_scores[:args.num, 1]
+        selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
+        print("best val scores = " + str(best_val_scores))
+        print("selected epochs = " + str(selected_epochs))
     else:
         path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
         path_list = sorted(path_list, key=os.path.getmtime)
@@ -64,11 +67,21 @@ def main(args):
     # average
     for k in avg.keys():
         if avg[k] is not None:
-            avg[k] = paddle.divide(avg[k], num)
+            avg[k] /= num
 
     paddle.save(avg, args.dst_model)
     print(f'Saving to {args.dst_model}')
 
+    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    with open(meta_path, 'w') as f:
+        data = json.dumps({
+            "avg_ckpt": args.dst_model,
+            "ckpt": path_list,
+            "epoch": selected_epochs.tolist(),
+            "val_loss": best_val_scores.tolist(),
+        })
+        f.write(data + "\n")
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='average model')
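
Note (not part of the patch): utils/avg_model.py now writes a small JSON meta file next to the averaged checkpoint. The minimal sketch below shows one way such a file could be read back after running local/avg.sh; the checkpoint directory and average number are hypothetical placeholders, and only the JSON keys ("avg_ckpt", "ckpt", "epoch", "val_loss") come from the patch itself.

# Minimal sketch, assuming avg.sh was run on a hypothetical checkpoint
# directory with average_num=20, producing avg_20.pdparams there.
import json
import os

ckpt_dir = "exp/conformer/checkpoints"   # hypothetical checkpoint directory
average_num = 20                         # hypothetical number of averaged checkpoints
avg_ckpt = os.path.join(ckpt_dir, "avg_{}.pdparams".format(average_num))

# avg_model.py writes the meta file next to the averaged checkpoint,
# replacing the extension with '.avg.json'.
meta_path = os.path.splitext(avg_ckpt)[0] + '.avg.json'
with open(meta_path, 'r') as f:
    meta = json.load(f)

print("averaged checkpoint:", meta["avg_ckpt"])   # path of the saved average
print("source checkpoints:", meta["ckpt"])        # .pdparams files that were averaged
print("selected epochs:", meta["epoch"])          # epochs picked via --val_best
print("best val losses:", meta["val_loss"])       # their validation losses

The test-side change in deepspeech/exps/u2/model.py follows the same pattern: it writes a one-line JSON record next to the test checkpoint (with an '.err' extension) holding epoch, step, RTF, the error rate, and dataset/process hours, so it can be inspected the same way.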