|
|
@ -17,6 +17,7 @@ import os
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
from pathlib import Path
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Union
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
from typing import Text
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import distributed as dist
|
|
|
|
from paddle import distributed as dist
|
|
|
@ -30,7 +31,7 @@ logger = Log(__name__).getlog()
|
|
|
|
__all__ = ["Checkpoint"]
|
|
|
|
__all__ = ["Checkpoint"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Checkpoint(object):
|
|
|
|
class Checkpoint():
|
|
|
|
def __init__(self, kbest_n: int=5, latest_n: int=1):
|
|
|
|
def __init__(self, kbest_n: int=5, latest_n: int=1):
|
|
|
|
self.best_records: Mapping[Path, float] = {}
|
|
|
|
self.best_records: Mapping[Path, float] = {}
|
|
|
|
self.latest_records = []
|
|
|
|
self.latest_records = []
|
|
|
@ -40,11 +41,21 @@ class Checkpoint(object):
|
|
|
|
|
|
|
|
|
|
|
|
def add_checkpoint(self,
|
|
|
|
def add_checkpoint(self,
|
|
|
|
checkpoint_dir,
|
|
|
|
checkpoint_dir,
|
|
|
|
tag_or_iteration,
|
|
|
|
tag_or_iteration: Union[int, Text],
|
|
|
|
model,
|
|
|
|
model: paddle.nn.Layer,
|
|
|
|
optimizer,
|
|
|
|
optimizer: Optimizer=None,
|
|
|
|
infos,
|
|
|
|
infos: dict=None,
|
|
|
|
metric_type="val_loss"):
|
|
|
|
metric_type="val_loss"):
|
|
|
|
|
|
|
|
"""Save checkpoint in best_n and latest_n.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
checkpoint_dir (str): the directory where checkpoint is saved.
|
|
|
|
|
|
|
|
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
|
|
|
|
|
|
|
|
model (Layer): model to be checkpointed.
|
|
|
|
|
|
|
|
optimizer (Optimizer, optional): optimizer to be checkpointed.
|
|
|
|
|
|
|
|
infos (dict or None)): any info you want to save.
|
|
|
|
|
|
|
|
metric_type (str, optional): metric type. Defaults to "val_loss".
|
|
|
|
|
|
|
|
"""
|
|
|
|
if (metric_type not in infos.keys()):
|
|
|
|
if (metric_type not in infos.keys()):
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
|
|
|
|
optimizer, infos)
|
|
|
|
optimizer, infos)
|
|
|
@ -62,6 +73,62 @@ class Checkpoint(object):
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_parameters(self,
|
|
|
|
|
|
|
|
model,
|
|
|
|
|
|
|
|
optimizer=None,
|
|
|
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
|
|
|
checkpoint_path=None,
|
|
|
|
|
|
|
|
record_file="checkpoint_latest"):
|
|
|
|
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
|
|
|
|
Defaults to None.
|
|
|
|
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
|
|
|
|
record_file "checkpoint_latest" or "checkpoint_best"
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
configs = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if checkpoint_path is not None:
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
elif checkpoint_dir is not None and record_file is not None:
|
|
|
|
|
|
|
|
# load checkpint from record file
|
|
|
|
|
|
|
|
checkpoint_record = os.path.join(checkpoint_dir, record_file)
|
|
|
|
|
|
|
|
iteration = self._load_checkpoint_idx(checkpoint_record)
|
|
|
|
|
|
|
|
if iteration == -1:
|
|
|
|
|
|
|
|
return configs
|
|
|
|
|
|
|
|
checkpoint_path = os.path.join(checkpoint_dir,
|
|
|
|
|
|
|
|
"{}".format(iteration))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rank = dist.get_rank()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
params_path = checkpoint_path + ".pdparams"
|
|
|
|
|
|
|
|
model_dict = paddle.load(params_path)
|
|
|
|
|
|
|
|
model.set_state_dict(model_dict)
|
|
|
|
|
|
|
|
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimizer_path = checkpoint_path + ".pdopt"
|
|
|
|
|
|
|
|
if optimizer and os.path.isfile(optimizer_path):
|
|
|
|
|
|
|
|
optimizer_dict = paddle.load(optimizer_path)
|
|
|
|
|
|
|
|
optimizer.set_state_dict(optimizer_dict)
|
|
|
|
|
|
|
|
logger.info("Rank {}: loaded optimizer state from {}".format(
|
|
|
|
|
|
|
|
rank, optimizer_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
info_path = re.sub('.pdparams$', '.json', params_path)
|
|
|
|
|
|
|
|
if os.path.exists(info_path):
|
|
|
|
|
|
|
|
with open(info_path, 'r') as fin:
|
|
|
|
|
|
|
|
configs = json.load(fin)
|
|
|
|
|
|
|
|
return configs
|
|
|
|
|
|
|
|
|
|
|
|
def load_latest_parameters(self,
|
|
|
|
def load_latest_parameters(self,
|
|
|
|
model,
|
|
|
|
model,
|
|
|
|
optimizer=None,
|
|
|
|
optimizer=None,
|
|
|
@ -192,61 +259,6 @@ class Checkpoint(object):
|
|
|
|
for i in self.latest_records:
|
|
|
|
for i in self.latest_records:
|
|
|
|
handle.write("model_checkpoint_path:{}\n".format(i))
|
|
|
|
handle.write("model_checkpoint_path:{}\n".format(i))
|
|
|
|
|
|
|
|
|
|
|
|
def _load_parameters(self,
|
|
|
|
|
|
|
|
model,
|
|
|
|
|
|
|
|
optimizer=None,
|
|
|
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
|
|
|
checkpoint_path=None,
|
|
|
|
|
|
|
|
checkpoint_file=None):
|
|
|
|
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
|
|
|
|
Defaults to None.
|
|
|
|
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
|
|
|
|
checkpoint_file "checkpoint_latest" or "checkpoint_best"
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
configs = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if checkpoint_path is not None:
|
|
|
|
|
|
|
|
tag = os.path.basename(checkpoint_path).split(":")[-1]
|
|
|
|
|
|
|
|
elif checkpoint_dir is not None and checkpoint_file is not None:
|
|
|
|
|
|
|
|
checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file)
|
|
|
|
|
|
|
|
iteration = self._load_checkpoint_idx(checkpoint_record)
|
|
|
|
|
|
|
|
if iteration == -1:
|
|
|
|
|
|
|
|
return configs
|
|
|
|
|
|
|
|
checkpoint_path = os.path.join(checkpoint_dir,
|
|
|
|
|
|
|
|
"{}".format(iteration))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rank = dist.get_rank()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
params_path = checkpoint_path + ".pdparams"
|
|
|
|
|
|
|
|
model_dict = paddle.load(params_path)
|
|
|
|
|
|
|
|
model.set_state_dict(model_dict)
|
|
|
|
|
|
|
|
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimizer_path = checkpoint_path + ".pdopt"
|
|
|
|
|
|
|
|
if optimizer and os.path.isfile(optimizer_path):
|
|
|
|
|
|
|
|
optimizer_dict = paddle.load(optimizer_path)
|
|
|
|
|
|
|
|
optimizer.set_state_dict(optimizer_dict)
|
|
|
|
|
|
|
|
logger.info("Rank {}: loaded optimizer state from {}".format(
|
|
|
|
|
|
|
|
rank, optimizer_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
info_path = re.sub('.pdparams$', '.json', params_path)
|
|
|
|
|
|
|
|
if os.path.exists(info_path):
|
|
|
|
|
|
|
|
with open(info_path, 'r') as fin:
|
|
|
|
|
|
|
|
configs = json.load(fin)
|
|
|
|
|
|
|
|
return configs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
def _save_parameters(self,
|
|
|
|
def _save_parameters(self,
|
|
|
|
checkpoint_dir: str,
|
|
|
|
checkpoint_dir: str,
|
|
|
|