From 08b6213bc8b88378cb090534be74eaeb7df306ce Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Wed, 30 Jun 2021 03:00:18 +0000 Subject: [PATCH] fix private function --- deepspeech/training/trainer.py | 5 +- deepspeech/utils/checkpoint.py | 114 ++++++++++++++++++++++----------- 2 files changed, 79 insertions(+), 40 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index cd915760..5ebba1a9 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -151,12 +151,11 @@ class Trainer(): resume training. """ scratch = None - infos = self.checkpoint._load_parameters( + infos = self.checkpoint.load_latest_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, - checkpoint_path=self.args.checkpoint_path, - checkpoint_file='checkpoint_latest') + checkpoint_path=self.args.checkpoint_path) if infos: # restore from ckpt self.iteration = infos["step"] diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index be36fdbb..000fa87b 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -38,23 +38,7 @@ class Checkpoint(object): self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - - def _should_save_best(self, metric: float) -> bool: - if not self._best_full(): - return True - - # already full - worst_record_path = max(self.best_records, key=self.best_records.get) - # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] - worst_metric = self.best_records[worst_record_path] - return metric < worst_metric - - def _best_full(self): - return (not self._save_all) and len(self.best_records) == self.kbest_n - - def _latest_full(self): - return len(self.latest_records) == self.latest_n - + def add_checkpoint(self, checkpoint_dir, tag_or_iteration, @@ -64,7 +48,7 @@ class Checkpoint(object): metric_type="val_loss"): if (metric_type not in infos.keys()): 
self._save_parameters(checkpoint_dir, tag_or_iteration, model, - optimizer, infos) + optimizer, infos) return #save best @@ -73,15 +57,71 @@ class Checkpoint(object): infos[metric_type], checkpoint_dir, tag_or_iteration, model, optimizer, infos) #save latest - self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + self._save_latest_checkpoint_and_update( + checkpoint_dir, tag_or_iteration, model, optimizer, infos) if isinstance(tag_or_iteration, int): self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) + def load_latest_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load the latest model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, + "checkpoint_latest") + + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load the best model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved.
+ """ + return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, + "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, - tag_or_iteration, model, optimizer, - infos): + tag_or_iteration, model, optimizer, + infos): # remove the worst if self._best_full(): worst_record_path = max(self.best_records, @@ -93,8 +133,8 @@ class Checkpoint(object): self._del_checkpoint(checkpoint_dir, worst_record_path) # add the new one - self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, - infos) + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) self.best_records[tag_or_iteration] = metric def _save_latest_checkpoint_and_update( @@ -108,8 +148,8 @@ class Checkpoint(object): self._del_checkpoint(checkpoint_dir, to_del_fn) self.latest_records.append(tag_or_iteration) - self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, - infos) + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): checkpoint_path = os.path.join(checkpoint_dir, @@ -153,13 +193,12 @@ class Checkpoint(object): for i in self.latest_records: handle.write("model_checkpoint_path:{}\n".format(i)) - def _load_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None, - checkpoint_file=None): + model, + optimizer=None, + checkpoint_dir=None, 
+ checkpoint_path=None, + checkpoint_file=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -209,13 +248,14 @@ class Checkpoint(object): configs = json.load(fin) return configs + @mp_tools.rank_zero_only def _save_parameters(self, - checkpoint_dir: str, - tag_or_iteration: Union[int, str], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None): + checkpoint_dir: str, + tag_or_iteration: Union[int, str], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): """Checkpoint the latest trained model parameters. Args: checkpoint_dir (str): the directory where checkpoint is saved.