From 6d92417edd57b73996cf042633ff1d06219c95f1 Mon Sep 17 00:00:00 2001
From: Haoxin Ma <745165806@qq.com>
Date: Tue, 29 Jun 2021 06:05:26 +0000
Subject: [PATCH] refactor checkpoint loaders into a single _load_parameters

---
 deepspeech/training/trainer.py |   5 +-
 deepspeech/utils/checkpoint.py | 109 +++++++++------------------------
 2 files changed, 32 insertions(+), 82 deletions(-)

diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index f8668370..cd915760 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -151,11 +151,12 @@ class Trainer():
         resume training.
         """
         scratch = None
-        infos = self.checkpoint.load_last_parameters(
+        infos = self.checkpoint._load_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
-            checkpoint_path=self.args.checkpoint_path)
+            checkpoint_path=self.args.checkpoint_path,
+            checkpoint_file='checkpoint_latest')
         if infos:
             # restore from ckpt
             self.iteration = infos["step"]
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index b29ef2ab..be36fdbb 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -39,8 +39,8 @@ class Checkpoint(object):
         self.latest_n = latest_n
         self._save_all = (kbest_n == -1)
 
-    def should_save_best(self, metric: float) -> bool:
-        if not self.best_full():
+    def _should_save_best(self, metric: float) -> bool:
+        if not self._best_full():
             return True
 
         # already full
@@ -49,10 +49,10 @@ class Checkpoint(object):
         worst_metric = self.best_records[worst_record_path]
         return metric < worst_metric
 
-    def best_full(self):
+    def _best_full(self):
         return (not self._save_all) and len(self.best_records) == self.kbest_n
 
-    def latest_full(self):
+    def _latest_full(self):
         return len(self.latest_records) == self.latest_n
 
     def add_checkpoint(self,
@@ -63,62 +63,62 @@ class Checkpoint(object):
                        infos,
                        metric_type="val_loss"):
         if (metric_type not in infos.keys()):
-            self.save_parameters(checkpoint_dir, tag_or_iteration, model,
+            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                                  optimizer, infos)
             return
 
         #save best
-        if self.should_save_best(infos[metric_type]):
-            self.save_best_checkpoint_and_update(
+        if self._should_save_best(infos[metric_type]):
+            self._save_best_checkpoint_and_update(
                 infos[metric_type], checkpoint_dir, tag_or_iteration, model,
                 optimizer, infos)
         #save latest
-        self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
+        self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
                                                model, optimizer, infos)
 
         if isinstance(tag_or_iteration, int):
-            self.save_checkpoint_record(checkpoint_dir, tag_or_iteration)
+            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
 
-    def save_best_checkpoint_and_update(self, metric, checkpoint_dir,
+    def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
                                         tag_or_iteration, model, optimizer, infos):
         # remove the worst
-        if self.best_full():
+        if self._best_full():
             worst_record_path = max(self.best_records,
                                     key=self.best_records.get)
             self.best_records.pop(worst_record_path)
             if (worst_record_path not in self.latest_records):
                 logger.info(
                     "remove the worst checkpoint: {}".format(worst_record_path))
-                self.del_checkpoint(checkpoint_dir, worst_record_path)
+                self._del_checkpoint(checkpoint_dir, worst_record_path)
 
         # add the new one
-        self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
                              infos)
         self.best_records[tag_or_iteration] = metric
 
-    def save_latest_checkpoint_and_update(
+    def _save_latest_checkpoint_and_update(
             self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
         # remove the old
-        if self.latest_full():
+        if self._latest_full():
             to_del_fn = self.latest_records.pop(0)
             if (to_del_fn not in self.best_records.keys()):
                 logger.info(
                     "remove the latest checkpoint: {}".format(to_del_fn))
-                self.del_checkpoint(checkpoint_dir, to_del_fn)
+                self._del_checkpoint(checkpoint_dir, to_del_fn)
         self.latest_records.append(tag_or_iteration)
 
-        self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
                              infos)
 
-    def del_checkpoint(self, checkpoint_dir, tag_or_iteration):
+    def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
         checkpoint_path = os.path.join(checkpoint_dir,
                                        "{}".format(tag_or_iteration))
         for filename in glob.glob(checkpoint_path + ".*"):
             os.remove(filename)
             logger.info("delete file: {}".format(filename))
 
-    def load_checkpoint_idx(self, checkpoint_record: str) -> int:
+    def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
         """Get the iteration number corresponding to the latest saved checkpoint.
         Args:
             checkpoint_path (str): the saved path of checkpoint.
@@ -134,7 +134,7 @@ class Checkpoint(object):
             iteration = int(latest_checkpoint.split(":")[-1])
         return iteration
 
-    def save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
+    def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
         """Save the iteration number of the latest model to be checkpoint record.
         Args:
             checkpoint_dir (str): the directory where checkpoint is saved.
@@ -153,65 +153,13 @@ class Checkpoint(object):
             for i in self.latest_records:
                 handle.write("model_checkpoint_path:{}\n".format(i))
 
-    def load_last_parameters(self,
-                             model,
-                             optimizer=None,
-                             checkpoint_dir=None,
-                             checkpoint_path=None):
-        """Load a last model checkpoint from disk.
-        Args:
-            model (Layer): model to load parameters.
-            optimizer (Optimizer, optional): optimizer to load states if needed.
-                Defaults to None.
-            checkpoint_dir (str, optional): the directory where checkpoint is saved.
-            checkpoint_path (str, optional): if specified, load the checkpoint
-                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
-                be ignored. Defaults to None.
-        Returns:
-            configs (dict): epoch or step, lr and other meta info should be saved.
-        """
-        configs = {}
-
-        if checkpoint_path is not None:
-            tag = os.path.basename(checkpoint_path).split(":")[-1]
-        elif checkpoint_dir is not None:
-            checkpoint_record = os.path.join(checkpoint_dir,
-                                             "checkpoint_latest")
-            iteration = self.load_checkpoint_idx(checkpoint_record)
-            if iteration == -1:
-                return configs
-            checkpoint_path = os.path.join(checkpoint_dir,
-                                           "{}".format(iteration))
-        else:
-            raise ValueError(
-                "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
-            )
-
-        rank = dist.get_rank()
-
-        params_path = checkpoint_path + ".pdparams"
-        model_dict = paddle.load(params_path)
-        model.set_state_dict(model_dict)
-        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
-
-        optimizer_path = checkpoint_path + ".pdopt"
-        if optimizer and os.path.isfile(optimizer_path):
-            optimizer_dict = paddle.load(optimizer_path)
-            optimizer.set_state_dict(optimizer_dict)
-            logger.info("Rank {}: loaded optimizer state from {}".format(
-                rank, optimizer_path))
-
-        info_path = re.sub('.pdparams$', '.json', params_path)
-        if os.path.exists(info_path):
-            with open(info_path, 'r') as fin:
-                configs = json.load(fin)
-        return configs
-
-    def load_best_parameters(self,
+    def _load_parameters(self,
                              model,
                              optimizer=None,
                              checkpoint_dir=None,
-                             checkpoint_path=None):
+                             checkpoint_path=None,
+                             checkpoint_file=None):
         """Load a last model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
@@ -221,6 +169,7 @@ class Checkpoint(object):
             checkpoint_path (str, optional): if specified, load the checkpoint
                 stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                 be ignored. Defaults to None.
+            checkpoint_file (str, optional): "checkpoint_latest" or "checkpoint_best". Defaults to None.
         Returns:
             configs (dict): epoch or step, lr and other meta info should be saved.
         """
@@ -228,16 +177,16 @@ class Checkpoint(object):
 
         if checkpoint_path is not None:
             tag = os.path.basename(checkpoint_path).split(":")[-1]
-        elif checkpoint_dir is not None:
-            checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_best")
-            iteration = self.load_checkpoint_idx(checkpoint_record)
+        elif checkpoint_dir is not None and checkpoint_file is not None:
+            checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file)
+            iteration = self._load_checkpoint_idx(checkpoint_record)
             if iteration == -1:
                 return configs
             checkpoint_path = os.path.join(checkpoint_dir,
                                            "{}".format(iteration))
         else:
             raise ValueError(
-                "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+                "Either 'checkpoint_path' or both 'checkpoint_dir' and 'checkpoint_file' should be specified!"
             )
 
         rank = dist.get_rank()
@@ -261,7 +210,7 @@ class Checkpoint(object):
         return configs
 
     @mp_tools.rank_zero_only
-    def save_parameters(self,
+    def _save_parameters(self,
                         checkpoint_dir: str,
                         tag_or_iteration: Union[int, str],
                         model: paddle.nn.Layer,
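
Reviewer note: the patch collapses the near-identical load_last_parameters and
load_best_parameters into one _load_parameters whose behaviour is selected by
the record file name. A minimal usage sketch follows; the toy model/optimizer
and the exp/checkpoints directory are hypothetical, and the
Checkpoint(kbest_n=..., latest_n=...) keyword arguments are inferred from the
attributes assigned in __init__ above, not confirmed by this diff.

# Sketch only, not part of the patch. Assumes the patched
# deepspeech/utils/checkpoint.py is importable and that training previously
# wrote checkpoints under exp/checkpoints (hypothetical path).
import paddle
from deepspeech.utils.checkpoint import Checkpoint

model = paddle.nn.Linear(10, 10)  # toy stand-in for the real model
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

ckpt = Checkpoint(kbest_n=5, latest_n=1)  # constructor args are assumed

# Resume from the newest checkpoint, mirroring the trainer.py hunk:
infos = ckpt._load_parameters(
    model,
    optimizer,
    checkpoint_dir="exp/checkpoints",
    checkpoint_file="checkpoint_latest")

# Load the best-scoring checkpoint instead (e.g. for evaluation);
# omitting the optimizer argument skips restoring its state:
infos = ckpt._load_parameters(
    model,
    checkpoint_dir="exp/checkpoints",
    checkpoint_file="checkpoint_best")

if infos:  # an empty dict means no record file was found
    print("resumed at step", infos["step"])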
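
The two record files share the convention that _save_checkpoint_record and
_load_checkpoint_idx rely on: each retained checkpoint is written as a
"model_checkpoint_path:<iteration>" line, and the loader parses the integer
after the last colon. Which line the loader consults is not visible in these
hunks, so reading the newest (last) entry in the sketch below is an assumption.

# Sketch only: peek at a checkpoint record without the Checkpoint class.
import os

def peek_iteration(checkpoint_record: str) -> int:
    """Return the iteration of the newest entry, or -1 if no record exists."""
    if not os.path.isfile(checkpoint_record):
        return -1  # same sentinel value _load_parameters checks for
    with open(checkpoint_record) as handle:
        lines = [line.strip() for line in handle if line.strip()]
    # "model_checkpoint_path:1234" -> 1234
    return int(lines[-1].split(":")[-1])

print(peek_iteration("exp/checkpoints/checkpoint_latest"))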