fix private function

pull/680/head
Haoxin Ma 3 years ago
parent 6d92417edd
commit 08b6213bc8

@@ -151,12 +151,11 @@ class Trainer():
         resume training.
         """
         scratch = None
-        infos = self.checkpoint._load_parameters(
+        infos = self.checkpoint.load_latest_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
-            checkpoint_path=self.args.checkpoint_path,
-            checkpoint_file='checkpoint_latest')
+            checkpoint_path=self.args.checkpoint_path)
         if infos:
             # restore from ckpt
             self.iteration = infos["step"]
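With this change, resuming in the Trainer goes through the new public API rather than the private helper. A minimal sketch of the call pattern, assuming a configured Checkpoint instance and checkpoint directory (the constructor arguments and the path below are illustrative, not part of this diff):

    # Sketch: resume from the newest checkpoint via the public method.
    checkpoint = Checkpoint(kbest_n=5, latest_n=1)  # assumed constructor args
    infos = checkpoint.load_latest_parameters(
        model, optimizer, checkpoint_dir="exp/checkpoints")
    if infos:  # falsy when there is no checkpoint to resume from
        iteration = infos["step"]  # meta info restored from the checkpoint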

@@ -38,23 +38,7 @@ class Checkpoint(object):
         self.kbest_n = kbest_n
         self.latest_n = latest_n
         self._save_all = (kbest_n == -1)
-
-    def _should_save_best(self, metric: float) -> bool:
-        if not self._best_full():
-            return True
-
-        # already full
-        worst_record_path = max(self.best_records, key=self.best_records.get)
-        # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
-        worst_metric = self.best_records[worst_record_path]
-        return metric < worst_metric
-
-    def _best_full(self):
-        return (not self._save_all) and len(self.best_records) == self.kbest_n
-
-    def _latest_full(self):
-        return len(self.latest_records) == self.latest_n

     def add_checkpoint(self,
                        checkpoint_dir,
                        tag_or_iteration,
@@ -64,7 +48,7 @@ class Checkpoint(object):
                        metric_type="val_loss"):
         if (metric_type not in infos.keys()):
             self._save_parameters(checkpoint_dir, tag_or_iteration, model,
-                                 optimizer, infos)
+                                  optimizer, infos)
             return

         #save best
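The early return above means that when infos lacks the metric key, the checkpoint is saved unconditionally and the k-best/latest bookkeeping below is skipped. A sketch of how a training loop might drive add_checkpoint (the infos keys other than the metric are assumptions):

    # Sketch: calling add_checkpoint from a training loop (assumed keys).
    infos = {"step": iteration, "epoch": epoch, "lr": optimizer.get_lr(),
             "val_loss": float(valid_loss)}
    checkpoint.add_checkpoint(checkpoint_dir, iteration, model, optimizer,
                              infos, metric_type="val_loss")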
@@ -73,15 +57,71 @@ class Checkpoint(object):
                 infos[metric_type], checkpoint_dir, tag_or_iteration, model,
                 optimizer, infos)

         #save latest
-        self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
-                                                model, optimizer, infos)
+        self._save_latest_checkpoint_and_update(
+            checkpoint_dir, tag_or_iteration, model, optimizer, infos)

         if isinstance(tag_or_iteration, int):
             self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)

+    def load_latest_parameters(self,
+                               model,
+                               optimizer=None,
+                               checkpoint_dir=None,
+                               checkpoint_path=None):
+        """Load the latest model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoints are saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored at checkpoint_path (a path prefix) and ignore the argument
+                'checkpoint_dir'. Defaults to None.
+        Returns:
+            configs (dict): meta info saved with the checkpoint, e.g. epoch or
+                step and lr.
+        """
+        return self._load_parameters(model, optimizer, checkpoint_dir,
+                                     checkpoint_path, "checkpoint_latest")
+
+    def load_best_parameters(self,
+                             model,
+                             optimizer=None,
+                             checkpoint_dir=None,
+                             checkpoint_path=None):
+        """Load the best model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoints are saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored at checkpoint_path (a path prefix) and ignore the argument
+                'checkpoint_dir'. Defaults to None.
+        Returns:
+            configs (dict): meta info saved with the checkpoint, e.g. epoch or
+                step and lr.
+        """
+        return self._load_parameters(model, optimizer, checkpoint_dir,
+                                     checkpoint_path, "checkpoint_best")
+
+    def _should_save_best(self, metric: float) -> bool:
+        if not self._best_full():
+            return True
+
+        # already full
+        worst_record_path = max(self.best_records, key=self.best_records.get)
+        # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
+        worst_metric = self.best_records[worst_record_path]
+        return metric < worst_metric
+
+    def _best_full(self):
+        return (not self._save_all) and len(self.best_records) == self.kbest_n
+
+    def _latest_full(self):
+        return len(self.latest_records) == self.latest_n
+
     def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
-                                        tag_or_iteration, model, optimizer,
-                                        infos):
+                                         tag_or_iteration, model, optimizer,
+                                         infos):
         # remove the worst
         if self._best_full():
             worst_record_path = max(self.best_records,
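The helpers above implement the k-best policy: best_records maps a checkpoint tag to its metric (lower is better), and once kbest_n records exist a new checkpoint is kept only if it strictly beats the current worst. The same decision, as a self-contained sketch on a plain dict:

    # Sketch: the k-best decision on a plain dict (tag -> val_loss).
    best_records = {"100": 0.52, "200": 0.47, "300": 0.61}
    kbest_n = 3

    def should_save_best(metric):
        if len(best_records) < kbest_n:  # not full yet: always keep
            return True
        worst_tag = max(best_records, key=best_records.get)
        return metric < best_records[worst_tag]  # must strictly beat the worst

    print(should_save_best(0.50))  # True: 0.50 beats the worst record, 0.61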
@@ -93,8 +133,8 @@ class Checkpoint(object):
             self._del_checkpoint(checkpoint_dir, worst_record_path)

         # add the new one
-        self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
-                              infos)
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                              optimizer, infos)
         self.best_records[tag_or_iteration] = metric

     def _save_latest_checkpoint_and_update(
@@ -108,8 +148,8 @@ class Checkpoint(object):
             self._del_checkpoint(checkpoint_dir, to_del_fn)
         self.latest_records.append(tag_or_iteration)

-        self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
-                              infos)
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                              optimizer, infos)

     def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
         checkpoint_path = os.path.join(checkpoint_dir,
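latest_records acts as a rolling window here: when latest_n entries exist, one checkpoint is deleted from disk before the new tag is appended. Assuming the oldest entry is the one evicted (the pop itself is not shown in this hunk), the bookkeeping on a plain list looks like:

    # Sketch: latest-n rolling window on a plain list (oldest evicted first).
    latest_records, latest_n = ["100", "200", "300"], 3
    new_tag = "400"
    if len(latest_records) == latest_n:  # same test as _latest_full
        to_del = latest_records.pop(0)   # assumed: drop the oldest tag...
        print("would delete checkpoint files for", to_del)
    latest_records.append(new_tag)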
@@ -153,13 +193,12 @@ class Checkpoint(object):
             for i in self.latest_records:
                 handle.write("model_checkpoint_path:{}\n".format(i))

     def _load_parameters(self,
-                        model,
-                        optimizer=None,
-                        checkpoint_dir=None,
-                        checkpoint_path=None,
-                        checkpoint_file=None):
+                         model,
+                         optimizer=None,
+                         checkpoint_dir=None,
+                         checkpoint_path=None,
+                         checkpoint_file=None):
         """Load a last model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
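_save_checkpoint_record above writes one model_checkpoint_path:<tag> line per retained checkpoint. A sketch of reading such a record back; the record file name is an assumption based on the "checkpoint_latest" prefix used elsewhere in this commit:

    # Sketch: parse a checkpoint record file (assumed file name).
    import os

    record_path = os.path.join(checkpoint_dir, "checkpoint_latest")
    with open(record_path) as handle:
        tags = [line.strip().split(":", 1)[1] for line in handle
                if line.startswith("model_checkpoint_path:")]
    # tags holds the retained checkpoint tags, in the order they were written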
@@ -209,13 +248,14 @@ class Checkpoint(object):
             configs = json.load(fin)
         return configs

     @mp_tools.rank_zero_only
     def _save_parameters(self,
-                        checkpoint_dir: str,
-                        tag_or_iteration: Union[int, str],
-                        model: paddle.nn.Layer,
-                        optimizer: Optimizer=None,
-                        infos: dict=None):
+                         checkpoint_dir: str,
+                         tag_or_iteration: Union[int, str],
+                         model: paddle.nn.Layer,
+                         optimizer: Optimizer=None,
+                         infos: dict=None):
         """Checkpoint the latest trained model parameters.
         Args:
             checkpoint_dir (str): the directory where checkpoint is saved.
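_save_parameters is decorated with @mp_tools.rank_zero_only, so in distributed training only rank 0 writes checkpoint files. Its body is not shown in this hunk; a sketch of what such a save typically looks like with the Paddle API (the file naming is an assumption):

    # Sketch: saving model/optimizer state with paddle.save (assumed names).
    import os
    import paddle

    def save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer=None):
        prefix = os.path.join(checkpoint_dir, str(tag_or_iteration))
        paddle.save(model.state_dict(), prefix + ".pdparams")  # model weights
        if optimizer is not None:
            paddle.save(optimizer.state_dict(), prefix + ".pdopt")  # optimizer state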
