diff --git a/.notebook/python_test.ipynb b/.notebook/python_test.ipynb index 99bbe0caa..0e6bde47f 100644 --- a/.notebook/python_test.ipynb +++ b/.notebook/python_test.ipynb @@ -637,7 +637,7 @@ { "cell_type": "code", "execution_count": 59, - "id": "norwegian-cleveland", + "id": "engaged-offense", "metadata": {}, "outputs": [ { @@ -660,7 +660,7 @@ { "cell_type": "code", "execution_count": 35, - "id": "endless-kidney", + "id": "level-fairy", "metadata": {}, "outputs": [ { @@ -705,7 +705,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "later-louisiana", + "id": "beautiful-geometry", "metadata": {}, "outputs": [ { @@ -728,7 +728,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "funded-nudist", + "id": "african-trustee", "metadata": {}, "outputs": [ { @@ -748,7 +748,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "contrary-affiliation", + "id": "ready-wages", "metadata": {}, "outputs": [], "source": [ @@ -758,7 +758,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "friendly-interpretation", + "id": "distinguished-printer", "metadata": {}, "outputs": [ { @@ -776,7 +776,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "bottom-wilderness", + "id": "precious-limit", "metadata": {}, "outputs": [ { @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": 17, - "id": "acquired-jacksonville", + "id": "chemical-convenience", "metadata": {}, "outputs": [ { @@ -839,7 +839,7 @@ { "cell_type": "code", "execution_count": 18, - "id": "entertaining-capture", + "id": "round-remark", "metadata": {}, "outputs": [ { @@ -871,7 +871,7 @@ { "cell_type": "code", "execution_count": 19, - "id": "amber-grade", + "id": "smaller-shower", "metadata": {}, "outputs": [ { @@ -903,7 +903,7 @@ { "cell_type": "code", "execution_count": 31, - "id": "hidden-playback", + "id": "integrated-block", "metadata": {}, "outputs": [ { @@ -935,7 +935,7 @@ { "cell_type": "code", "execution_count": 32, - "id": "twelve-university", + "id": "favorite-failure", "metadata": {}, "outputs": [ { @@ -966,7 +966,7 @@ { "cell_type": "code", "execution_count": 20, - "id": "minor-endorsement", + "id": "boolean-saint", "metadata": {}, "outputs": [], "source": [ @@ -977,7 +977,7 @@ { "cell_type": "code", "execution_count": 46, - "id": "upper-majority", + "id": "senior-hospital", "metadata": {}, "outputs": [ { @@ -997,7 +997,7 @@ { "cell_type": "code", "execution_count": 30, - "id": "supreme-coverage", + "id": "consolidated-incident", "metadata": {}, "outputs": [], "source": [ @@ -1007,7 +1007,7 @@ { "cell_type": "code", "execution_count": 31, - "id": "tough-domain", + "id": "pursuant-paragraph", "metadata": {}, "outputs": [], "source": [ @@ -1017,7 +1017,7 @@ { "cell_type": "code", "execution_count": 47, - "id": "indian-empire", + "id": "mexican-apollo", "metadata": {}, "outputs": [ { @@ -1038,7 +1038,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "horizontal-paragraph", + "id": "encouraging-integration", "metadata": {}, "outputs": [], "source": [ @@ -1049,7 +1049,7 @@ { "cell_type": "code", "execution_count": 56, - "id": "homeless-zoning", + "id": "trying-auckland", "metadata": {}, "outputs": [], "source": [ @@ -1059,7 +1059,7 @@ { "cell_type": "code", "execution_count": 58, - "id": "floating-atmosphere", + "id": "national-edward", "metadata": {}, "outputs": [], "source": [ @@ -1069,7 +1069,7 @@ { "cell_type": "code", "execution_count": 60, - "id": "stupid-reducing", + "id": "aerial-campaign", "metadata": {}, "outputs": [], "source": [ @@ -1079,7 +1079,7 @@ { "cell_type": "code", "execution_count": 66, - "id": "practical-airline", + "id": "instant-violence", "metadata": {}, "outputs": [], "source": [ @@ -1089,7 +1089,7 @@ { "cell_type": "code", "execution_count": 95, - "id": "apart-comfort", + "id": "medical-globe", "metadata": {}, "outputs": [ { @@ -1110,7 +1110,7 @@ { "cell_type": "code", "execution_count": 81, - "id": "underlying-brand", + "id": "three-contrast", "metadata": {}, "outputs": [ { @@ -1131,7 +1131,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "german-things", + "id": "cross-atlas", "metadata": {}, "outputs": [], "source": [ @@ -1161,7 +1161,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "third-regression", + "id": "empirical-defense", "metadata": {}, "outputs": [], "source": [ @@ -1172,7 +1172,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "thick-korea", + "id": "rocky-listening", "metadata": {}, "outputs": [ { @@ -1201,7 +1201,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "institutional-hands", + "id": "surrounded-absolute", "metadata": {}, "outputs": [ { @@ -1230,7 +1230,7 @@ { "cell_type": "code", "execution_count": 15, - "id": "brave-native", + "id": "differential-surgery", "metadata": {}, "outputs": [ { @@ -1260,7 +1260,7 @@ { "cell_type": "code", "execution_count": 29, - "id": "turkish-ticket", + "id": "durable-powell", "metadata": {}, "outputs": [ { @@ -1290,7 +1290,7 @@ { "cell_type": "code", "execution_count": 30, - "id": "executed-excerpt", + "id": "young-continuity", "metadata": {}, "outputs": [ { @@ -1308,7 +1308,7 @@ { "cell_type": "code", "execution_count": 22, - "id": "continental-boring", + "id": "geological-sarah", "metadata": {}, "outputs": [ { @@ -1343,7 +1343,7 @@ { "cell_type": "code", "execution_count": 23, - "id": "linear-assembly", + "id": "possible-angle", "metadata": {}, "outputs": [ { @@ -1376,7 +1376,7 @@ { "cell_type": "code", "execution_count": 33, - "id": "applied-louis", + "id": "novel-sucking", "metadata": {}, "outputs": [], "source": [ @@ -1386,7 +1386,7 @@ { "cell_type": "code", "execution_count": 34, - "id": "historic-struggle", + "id": "fixed-wallet", "metadata": {}, "outputs": [ { @@ -1428,7 +1428,7 @@ { "cell_type": "code", "execution_count": 35, - "id": "monthly-roads", + "id": "north-seattle", "metadata": {}, "outputs": [], "source": [ @@ -1438,7 +1438,7 @@ { "cell_type": "code", "execution_count": 38, - "id": "boxed-peoples", + "id": "above-western", "metadata": {}, "outputs": [ { @@ -1471,17 +1471,99 @@ { "cell_type": "code", "execution_count": 41, - "id": "fresh-tender", + "id": "choice-diabetes", "metadata": {}, "outputs": [], "source": [ "!ls" ] }, + { + "cell_type": "code", + "execution_count": 3, + "id": "white-vessel", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2, 2)\n", + "[ 1 20]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "l = [(1, 20), (2, 30)]\n", + "scores = np.array(l)\n", + "print(scores.shape)\n", + "print(scores[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "treated-freedom", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0 1]\n" + ] + } + ], + "source": [ + "sort_idx = np.argsort(scores[:, -1])\n", + "print(sort_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "convinced-safety", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 1 20]\n", + " [ 2 30]]\n" + ] + } + ], + "source": [ + "sorted_val_scores = scores[sort_idx][::1]\n", + "print(sorted_val_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "blond-bunny", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 1 20]\n", + " [ 2 30]]\n" + ] + } + ], + "source": [ + "sorted_val_scores = scores[sort_idx]\n", + "print(sorted_val_scores)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "religious-peripheral", + "id": "utility-monroe", "metadata": {}, "outputs": [], "source": [] diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 7c623a03f..b445a501b 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -89,9 +89,8 @@ class U2Trainer(Trainer): if (batch_index + 1) % train_conf.accum_grad == 0: if dist.get_rank() == 0 and self.visualizer: - for k, v in losses_np.items(): - self.visualizer.add_scalar("train/{}".format(k), v, - self.iteration) + losses_np.update({"lr": self.lr_scheduler()}) + self.visualizer.add_scalars("step", losses_np, self.iteration) self.optimizer.step() self.optimizer.clear_grad() self.lr_scheduler.step() @@ -144,7 +143,7 @@ class U2Trainer(Trainer): raise e valid_losses = self.valid() - self.save(infos=valid_losses) + self.save(tag=self.epoch, infos=valid_losses) self.new_epoch() @mp_tools.rank_zero_only @@ -172,9 +171,8 @@ class U2Trainer(Trainer): logger.info(msg) if self.visualizer: - for k, v in valid_losses.items(): - self.visualizer.add_scalar("valid/{}".format(k), v, - self.iteration) + valid_losses.update({"lr": self.lr_scheduler()}) + self.visualizer.add_scalars('epoch', valid_losses, self.epoch) return valid_losses def setup_dataloader(self): diff --git a/utils/avg_model.py b/utils/avg_model.py new file mode 100644 index 000000000..a8a1c0f5a --- /dev/null +++ b/utils/avg_model.py @@ -0,0 +1,96 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import glob +import json +import os + +import numpy as np +import paddle + + +def main(args): + checkpoints = [] + val_scores = [] + + if args.val_best: + jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') + for y in jsons: + dic_json = json.load(y) + loss = dic_json['valid_loss'] + epoch = dic_json['epoch'] + if epoch >= args.min_epoch and epoch <= args.max_epoch: + val_scores.append((epoch, loss)) + + val_scores = np.array(val_scores) + sort_idx = np.argsort(val_scores[:, 1]) + sorted_val_scores = val_scores[sort_idx] + path_list = [ + args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) + for epoch in sorted_val_scores[:args.num, 0] + ] + print("best val scores = " + str(sorted_val_scores[:args.num, 1])) + print("selected epochs = " + str(sorted_val_scores[:args.num, 0].astype( + np.int64))) + else: + path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams') + path_list = sorted(path_list, key=os.path.getmtime) + path_list = path_list[-args.num:] + + print(path_list) + + avg = None + num = args.num + assert num == len(path_list) + for path in path_list: + print(f'Processing {path}') + states = paddle.load(path) + if avg is None: + avg = states + else: + for k in avg.keys(): + avg[k] += states[k] + # average + for k in avg.keys(): + if avg[k] is not None: + avg[k] = paddle.divide(avg[k], num) + + paddle.save(avg, args.dst_model) + print(f'Saving to {args.dst_model}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='average model') + parser.add_argument('--dst_model', required=True, help='averaged model') + parser.add_argument( + '--ckpt_dir', required=True, help='ckpt model dir for average') + parser.add_argument( + '--val_best', action="store_true", help='averaged model') + parser.add_argument( + '--num', default=5, type=int, help='nums for averaged model') + parser.add_argument( + '--min_epoch', + default=0, + type=int, + help='min epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') + + args = parser.parse_args() + print(args) + + main(args)