|
|
@ -39,22 +39,6 @@ class Checkpoint(object):
|
|
|
|
self.latest_n = latest_n
|
|
|
|
self.latest_n = latest_n
|
|
|
|
self._save_all = (kbest_n == -1)
|
|
|
|
self._save_all = (kbest_n == -1)
|
|
|
|
|
|
|
|
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
    """Decide whether a checkpoint scoring ``metric`` belongs in the best set.

    Smaller metric values are treated as better: once the best set is
    full, a candidate is accepted only when its metric is strictly smaller
    than the largest (worst) metric stored in ``self.best_records``.

    Args:
        metric (float): score of the candidate checkpoint.

    Returns:
        bool: True if the candidate should be saved as a best checkpoint.
    """
    if not self._best_full():
        return True

    # The set can report "full" while holding no records at all (a capacity
    # of zero). There is nothing to replace in that case, and max() over an
    # empty dict would raise ValueError, so the candidate is rejected.
    if not self.best_records:
        return False

    # already full: only a strict improvement over the worst record qualifies
    worst_metric = max(self.best_records.values())
    return metric < worst_metric
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _best_full(self):
    """Report whether the best-checkpoint set has reached its capacity.

    Returns:
        bool: False whenever ``self._save_all`` is set; otherwise True
            exactly when ``self.best_records`` holds ``self.kbest_n``
            entries.
    """
    if self._save_all:
        # In save-all mode the set is never considered full.
        return False
    return self.kbest_n == len(self.best_records)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _latest_full(self):
    """Report whether the latest-checkpoint list has reached its capacity.

    Returns:
        bool: True exactly when ``self.latest_records`` holds
            ``self.latest_n`` entries.
    """
    kept = len(self.latest_records)
    return kept == self.latest_n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_checkpoint(self,
|
|
|
|
def add_checkpoint(self,
|
|
|
|
checkpoint_dir,
|
|
|
|
checkpoint_dir,
|
|
|
|
tag_or_iteration,
|
|
|
|
tag_or_iteration,
|
|
|
@ -73,12 +57,68 @@ class Checkpoint(object):
|
|
|
|
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
|
|
|
|
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
|
|
|
|
optimizer, infos)
|
|
|
|
optimizer, infos)
|
|
|
|
#save latest
|
|
|
|
#save latest
|
|
|
|
self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
|
|
|
|
self._save_latest_checkpoint_and_update(
|
|
|
|
model, optimizer, infos)
|
|
|
|
checkpoint_dir, tag_or_iteration, model, optimizer, infos)
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_latest_parameters(self,
                           model,
                           optimizer=None,
                           checkpoint_dir=None,
                           checkpoint_path=None):
    """Restore the most recently saved checkpoint from disk.

    Delegates to ``self._load_parameters`` with the record name
    "checkpoint_latest".

    Args:
        model (Layer): model whose parameters are to be loaded.
        optimizer (Optimizer, optional): optimizer whose states are loaded
            if needed. Defaults to None.
        checkpoint_dir (str, optional): directory where checkpoints are
            saved.
        checkpoint_path (str, optional): checkpoint path (prefix) to load
            from; when given, 'checkpoint_dir' is ignored.
            Defaults to None.

    Returns:
        configs (dict): epoch or step, lr and other saved meta info.
    """
    record_name = "checkpoint_latest"
    configs = self._load_parameters(model, optimizer, checkpoint_dir,
                                    checkpoint_path, record_name)
    return configs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_best_parameters(self,
                         model,
                         optimizer=None,
                         checkpoint_dir=None,
                         checkpoint_path=None):
    """Load the best model checkpoint from disk.

    Delegates to ``self._load_parameters`` with the record name
    "checkpoint_best" (as opposed to "checkpoint_latest", which
    ``load_latest_parameters`` uses).

    Args:
        model (Layer): model to load parameters.
        optimizer (Optimizer, optional): optimizer to load states if needed.
            Defaults to None.
        checkpoint_dir (str, optional): the directory where checkpoint is
            saved.
        checkpoint_path (str, optional): if specified, load the checkpoint
            stored in the checkpoint_path(prefix) and the argument
            'checkpoint_dir' will be ignored. Defaults to None.

    Returns:
        configs (dict): epoch or step, lr and other meta info should be
            saved.
    """
    return self._load_parameters(model, optimizer, checkpoint_dir,
                                 checkpoint_path, "checkpoint_best")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
    """Decide whether a checkpoint scoring ``metric`` belongs in the best set.

    Smaller metric values are treated as better: once the best set is
    full, a candidate is accepted only when its metric is strictly smaller
    than the largest (worst) metric stored in ``self.best_records``.

    Args:
        metric (float): score of the candidate checkpoint.

    Returns:
        bool: True if the candidate should be saved as a best checkpoint.
    """
    if not self._best_full():
        return True

    # The set can report "full" while holding no records at all (a capacity
    # of zero). There is nothing to replace in that case, and max() over an
    # empty dict would raise ValueError, so the candidate is rejected.
    if not self.best_records:
        return False

    # already full: only a strict improvement over the worst record qualifies
    worst_metric = max(self.best_records.values())
    return metric < worst_metric
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _best_full(self):
    """Report whether the best-checkpoint set has reached its capacity.

    Returns:
        bool: False whenever ``self._save_all`` is set; otherwise True
            exactly when ``self.best_records`` holds ``self.kbest_n``
            entries.
    """
    if self._save_all:
        # In save-all mode the set is never considered full.
        return False
    return self.kbest_n == len(self.best_records)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _latest_full(self):
    """Report whether the latest-checkpoint list has reached its capacity.

    Returns:
        bool: True exactly when ``self.latest_records`` holds
            ``self.latest_n`` entries.
    """
    kept = len(self.latest_records)
    return kept == self.latest_n
|
|
|
|
|
|
|
|
|
|
|
|
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
|
|
|
|
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
|
|
|
|
tag_or_iteration, model, optimizer,
|
|
|
|
tag_or_iteration, model, optimizer,
|
|
|
|
infos):
|
|
|
|
infos):
|
|
|
@ -93,8 +133,8 @@ class Checkpoint(object):
|
|
|
|
self._del_checkpoint(checkpoint_dir, worst_record_path)
|
|
|
|
self._del_checkpoint(checkpoint_dir, worst_record_path)
|
|
|
|
|
|
|
|
|
|
|
|
# add the new one
|
|
|
|
# add the new one
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
infos)
|
|
|
|
optimizer, infos)
|
|
|
|
self.best_records[tag_or_iteration] = metric
|
|
|
|
self.best_records[tag_or_iteration] = metric
|
|
|
|
|
|
|
|
|
|
|
|
def _save_latest_checkpoint_and_update(
|
|
|
|
def _save_latest_checkpoint_and_update(
|
|
|
@ -108,8 +148,8 @@ class Checkpoint(object):
|
|
|
|
self._del_checkpoint(checkpoint_dir, to_del_fn)
|
|
|
|
self._del_checkpoint(checkpoint_dir, to_del_fn)
|
|
|
|
self.latest_records.append(tag_or_iteration)
|
|
|
|
self.latest_records.append(tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
infos)
|
|
|
|
optimizer, infos)
|
|
|
|
|
|
|
|
|
|
|
|
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
|
|
|
|
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
|
|
|
|
checkpoint_path = os.path.join(checkpoint_dir,
|
|
|
|
checkpoint_path = os.path.join(checkpoint_dir,
|
|
|
@ -153,7 +193,6 @@ class Checkpoint(object):
|
|
|
|
for i in self.latest_records:
|
|
|
|
for i in self.latest_records:
|
|
|
|
handle.write("model_checkpoint_path:{}\n".format(i))
|
|
|
|
handle.write("model_checkpoint_path:{}\n".format(i))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_parameters(self,
|
|
|
|
def _load_parameters(self,
|
|
|
|
model,
|
|
|
|
model,
|
|
|
|
optimizer=None,
|
|
|
|
optimizer=None,
|
|
|
@ -209,6 +248,7 @@ class Checkpoint(object):
|
|
|
|
configs = json.load(fin)
|
|
|
|
configs = json.load(fin)
|
|
|
|
return configs
|
|
|
|
return configs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
def _save_parameters(self,
|
|
|
|
def _save_parameters(self,
|
|
|
|
checkpoint_dir: str,
|
|
|
|
checkpoint_dir: str,
|
|
|
|