fix private function

pull/680/head
Haoxin Ma 3 years ago
parent 6d92417edd
commit 08b6213bc8

@ -151,12 +151,11 @@ class Trainer():
resume training.
"""
scratch = None
infos = self.checkpoint._load_parameters(
infos = self.checkpoint.load_latest_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path,
checkpoint_file='checkpoint_latest')
checkpoint_path=self.args.checkpoint_path)
if infos:
# restore from ckpt
self.iteration = infos["step"]

@ -39,22 +39,6 @@ class Checkpoint(object):
self.latest_n = latest_n
self._save_all = (kbest_n == -1)
def _should_save_best(self, metric: float) -> bool:
if not self._best_full():
return True
# already full
worst_record_path = max(self.best_records, key=self.best_records.get)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric = self.best_records[worst_record_path]
return metric < worst_metric
def _best_full(self):
return (not self._save_all) and len(self.best_records) == self.kbest_n
def _latest_full(self):
return len(self.latest_records) == self.latest_n
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration,
@ -73,12 +57,68 @@ class Checkpoint(object):
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
#save latest
self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
model, optimizer, infos)
self._save_latest_checkpoint_and_update(
checkpoint_dir, tag_or_iteration, model, optimizer, infos)
if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_latest_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path,
"checkpoint_latest")
def load_best_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path,
"checkpoint_best")
def _should_save_best(self, metric: float) -> bool:
if not self._best_full():
return True
# already full
worst_record_path = max(self.best_records, key=self.best_records.get)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric = self.best_records[worst_record_path]
return metric < worst_metric
def _best_full(self):
return (not self._save_all) and len(self.best_records) == self.kbest_n
def _latest_full(self):
return len(self.latest_records) == self.latest_n
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
tag_or_iteration, model, optimizer,
infos):
@ -93,8 +133,8 @@ class Checkpoint(object):
self._del_checkpoint(checkpoint_dir, worst_record_path)
# add the new one
self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
infos)
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
self.best_records[tag_or_iteration] = metric
def _save_latest_checkpoint_and_update(
@ -108,8 +148,8 @@ class Checkpoint(object):
self._del_checkpoint(checkpoint_dir, to_del_fn)
self.latest_records.append(tag_or_iteration)
self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
infos)
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
checkpoint_path = os.path.join(checkpoint_dir,
@ -153,7 +193,6 @@ class Checkpoint(object):
for i in self.latest_records:
handle.write("model_checkpoint_path:{}\n".format(i))
def _load_parameters(self,
model,
optimizer=None,
@ -209,6 +248,7 @@ class Checkpoint(object):
configs = json.load(fin)
return configs
@mp_tools.rank_zero_only
def _save_parameters(self,
checkpoint_dir: str,

Loading…
Cancel
Save