|
|
@ -39,22 +39,6 @@ class Checkpoint(object):
|
|
|
|
self.latest_n = latest_n
|
|
|
|
self.latest_n = latest_n
|
|
|
|
self._save_all = (kbest_n == -1)
|
|
|
|
self._save_all = (kbest_n == -1)
|
|
|
|
|
|
|
|
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
|
|
|
|
|
|
|
|
if not self._best_full():
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# already full
|
|
|
|
|
|
|
|
worst_record_path = max(self.best_records, key=self.best_records.get)
|
|
|
|
|
|
|
|
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
|
|
|
|
|
|
|
|
worst_metric = self.best_records[worst_record_path]
|
|
|
|
|
|
|
|
return metric < worst_metric
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _best_full(self):
|
|
|
|
|
|
|
|
return (not self._save_all) and len(self.best_records) == self.kbest_n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _latest_full(self):
|
|
|
|
|
|
|
|
return len(self.latest_records) == self.latest_n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_checkpoint(self,
|
|
|
|
def add_checkpoint(self,
|
|
|
|
checkpoint_dir,
|
|
|
|
checkpoint_dir,
|
|
|
|
tag_or_iteration,
|
|
|
|
tag_or_iteration,
|
|
|
@ -64,7 +48,7 @@ class Checkpoint(object):
|
|
|
|
metric_type="val_loss"):
|
|
|
|
metric_type="val_loss"):
|
|
|
|
if (metric_type not in infos.keys()):
|
|
|
|
if (metric_type not in infos.keys()):
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
optimizer, infos)
|
|
|
|
optimizer, infos)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
#save best
|
|
|
|
#save best
|
|
|
@ -73,15 +57,71 @@ class Checkpoint(object):
|
|
|
|
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
|
|
|
|
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
|
|
|
|
optimizer, infos)
|
|
|
|
optimizer, infos)
|
|
|
|
#save latest
|
|
|
|
#save latest
|
|
|
|
self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
|
|
|
|
self._save_latest_checkpoint_and_update(
|
|
|
|
model, optimizer, infos)
|
|
|
|
checkpoint_dir, tag_or_iteration, model, optimizer, infos)
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_latest_parameters(self,
                           model,
                           optimizer=None,
                           checkpoint_dir=None,
                           checkpoint_path=None):
    """Restore model (and optionally optimizer) state from the latest checkpoint.

    Thin wrapper around ``_load_parameters`` that targets the
    "checkpoint_latest" record file.

    Args:
        model (Layer): model to load parameters into.
        optimizer (Optimizer, optional): optimizer whose state should also be
            restored. Defaults to None.
        checkpoint_dir (str, optional): directory where checkpoints are saved.
        checkpoint_path (str, optional): if given, load the checkpoint stored
            at this path (prefix) and ignore ``checkpoint_dir``.
            Defaults to None.

    Returns:
        configs (dict): saved meta info such as epoch or step and lr.
    """
    return self._load_parameters(model, optimizer, checkpoint_dir,
                                 checkpoint_path, "checkpoint_latest")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_best_parameters(self,
                         model,
                         optimizer=None,
                         checkpoint_dir=None,
                         checkpoint_path=None):
    """Restore model (and optionally optimizer) state from the best checkpoint.

    Thin wrapper around ``_load_parameters`` that targets the
    "checkpoint_best" record file.

    Args:
        model (Layer): model to load parameters into.
        optimizer (Optimizer, optional): optimizer whose state should also be
            restored. Defaults to None.
        checkpoint_dir (str, optional): directory where checkpoints are saved.
        checkpoint_path (str, optional): if given, load the checkpoint stored
            at this path (prefix) and ignore ``checkpoint_dir``.
            Defaults to None.

    Returns:
        configs (dict): saved meta info such as epoch or step and lr.
    """
    return self._load_parameters(model, optimizer, checkpoint_dir,
                                 checkpoint_path, "checkpoint_best")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
    """Return True if a checkpoint with this metric should be kept as "best".

    Smaller metric values win: a candidate replaces an existing record only
    when it is strictly lower than the current worst (largest) recorded value.
    """
    # While the k-best set still has room, accept unconditionally.
    if not self._best_full():
        return True

    # already full
    worst_record_path = max(self.best_records, key=self.best_records.get)
    # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
    worst_metric = self.best_records[worst_record_path]
    # Strict comparison: ties with the worst entry are NOT saved.
    return metric < worst_metric
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _best_full(self):
    """Return True once the k-best record set holds kbest_n entries.

    Always False when ``self._save_all`` is set (kbest_n == -1 keeps every
    checkpoint, so the set never fills).
    """
    return (not self._save_all) and len(self.best_records) == self.kbest_n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _latest_full(self):
    """Return True once ``latest_records`` holds ``latest_n`` entries."""
    return len(self.latest_records) == self.latest_n
|
|
|
|
|
|
|
|
|
|
|
|
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
|
|
|
|
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
|
|
|
|
tag_or_iteration, model, optimizer,
|
|
|
|
tag_or_iteration, model, optimizer,
|
|
|
|
infos):
|
|
|
|
infos):
|
|
|
|
# remove the worst
|
|
|
|
# remove the worst
|
|
|
|
if self._best_full():
|
|
|
|
if self._best_full():
|
|
|
|
worst_record_path = max(self.best_records,
|
|
|
|
worst_record_path = max(self.best_records,
|
|
|
@ -93,8 +133,8 @@ class Checkpoint(object):
|
|
|
|
self._del_checkpoint(checkpoint_dir, worst_record_path)
|
|
|
|
self._del_checkpoint(checkpoint_dir, worst_record_path)
|
|
|
|
|
|
|
|
|
|
|
|
# add the new one
|
|
|
|
# add the new one
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
infos)
|
|
|
|
optimizer, infos)
|
|
|
|
self.best_records[tag_or_iteration] = metric
|
|
|
|
self.best_records[tag_or_iteration] = metric
|
|
|
|
|
|
|
|
|
|
|
|
def _save_latest_checkpoint_and_update(
|
|
|
|
def _save_latest_checkpoint_and_update(
|
|
|
@ -108,8 +148,8 @@ class Checkpoint(object):
|
|
|
|
self._del_checkpoint(checkpoint_dir, to_del_fn)
|
|
|
|
self._del_checkpoint(checkpoint_dir, to_del_fn)
|
|
|
|
self.latest_records.append(tag_or_iteration)
|
|
|
|
self.latest_records.append(tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
infos)
|
|
|
|
optimizer, infos)
|
|
|
|
|
|
|
|
|
|
|
|
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
|
|
|
|
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
|
|
|
|
checkpoint_path = os.path.join(checkpoint_dir,
|
|
|
|
checkpoint_path = os.path.join(checkpoint_dir,
|
|
|
@ -153,13 +193,12 @@ class Checkpoint(object):
|
|
|
|
for i in self.latest_records:
|
|
|
|
for i in self.latest_records:
|
|
|
|
handle.write("model_checkpoint_path:{}\n".format(i))
|
|
|
|
handle.write("model_checkpoint_path:{}\n".format(i))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_parameters(self,
|
|
|
|
def _load_parameters(self,
|
|
|
|
model,
|
|
|
|
model,
|
|
|
|
optimizer=None,
|
|
|
|
optimizer=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_path=None,
|
|
|
|
checkpoint_path=None,
|
|
|
|
checkpoint_file=None):
|
|
|
|
checkpoint_file=None):
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
model (Layer): model to load parameters.
|
|
|
@ -209,13 +248,14 @@ class Checkpoint(object):
|
|
|
|
configs = json.load(fin)
|
|
|
|
configs = json.load(fin)
|
|
|
|
return configs
|
|
|
|
return configs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
def _save_parameters(self,
|
|
|
|
def _save_parameters(self,
|
|
|
|
checkpoint_dir: str,
|
|
|
|
checkpoint_dir: str,
|
|
|
|
tag_or_iteration: Union[int, str],
|
|
|
|
tag_or_iteration: Union[int, str],
|
|
|
|
model: paddle.nn.Layer,
|
|
|
|
model: paddle.nn.Layer,
|
|
|
|
optimizer: Optimizer=None,
|
|
|
|
optimizer: Optimizer=None,
|
|
|
|
infos: dict=None):
|
|
|
|
infos: dict=None):
|
|
|
|
"""Checkpoint the latest trained model parameters.
|
|
|
|
"""Checkpoint the latest trained model parameters.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
checkpoint_dir (str): the directory where checkpoint is saved.
|
|
|
|
checkpoint_dir (str): the directory where checkpoint is saved.
|
|
|
|