|
|
@ -16,8 +16,8 @@ import json
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
from pathlib import Path
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
from typing import Text
|
|
|
|
from typing import Text
|
|
|
|
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import distributed as dist
|
|
|
|
from paddle import distributed as dist
|
|
|
@ -51,7 +51,7 @@ class Checkpoint():
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
checkpoint_dir (str): the directory where checkpoint is saved.
|
|
|
|
checkpoint_dir (str): the directory where checkpoint is saved.
|
|
|
|
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
|
|
|
|
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
|
|
|
|
model (Layer): model to be checkpointed.
|
|
|
|
model (Layer): model to be checkpointed.
|
|
|
|
optimizer (Optimizer, optional): optimizer to be checkpointed.
|
|
|
|
optimizer (Optimizer, optional): optimizer to be checkpointed.
|
|
|
|
infos (dict or None): any info you want to save.
|
|
|
|
infos (dict or None): any info you want to save.
|
|
|
|
metric_type (str, optional): metric type. Defaults to "val_loss".
|
|
|
|
metric_type (str, optional): metric type. Defaults to "val_loss".
|
|
|
@ -72,22 +72,22 @@ class Checkpoint():
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
if isinstance(tag_or_iteration, int):
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
|
|
def load_parameters(self,
|
|
|
|
def load_parameters(self,
|
|
|
|
model,
|
|
|
|
model,
|
|
|
|
optimizer=None,
|
|
|
|
optimizer=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_path=None,
|
|
|
|
checkpoint_path=None,
|
|
|
|
record_file="checkpoint_latest"):
|
|
|
|
record_file="checkpoint_latest"):
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
Defaults to None.
|
|
|
|
Defaults to None.
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
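record_file (str, optional): name of the record file to read, "checkpoint_latest" or "checkpoint_best". Defaults to "checkpoint_latest".
|
|
|
|
record_file (str, optional): name of the record file to read, "checkpoint_latest" or "checkpoint_best". Defaults to "checkpoint_latest".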
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
@ -134,40 +134,40 @@ class Checkpoint():
|
|
|
|
optimizer=None,
|
|
|
|
optimizer=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_path=None):
|
|
|
|
checkpoint_path=None):
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
Defaults to None.
|
|
|
|
Defaults to None.
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return self._load_parameters(model, optimizer, checkpoint_dir,
|
|
|
|
return self.load_parameters(model, optimizer, checkpoint_dir,
|
|
|
|
checkpoint_path, "checkpoint_latest")
|
|
|
|
checkpoint_path, "checkpoint_latest")
|
|
|
|
|
|
|
|
|
|
|
|
def load_best_parameters(self,
|
|
|
|
def load_best_parameters(self,
|
|
|
|
model,
|
|
|
|
model,
|
|
|
|
optimizer=None,
|
|
|
|
optimizer=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_dir=None,
|
|
|
|
checkpoint_path=None):
|
|
|
|
checkpoint_path=None):
|
|
|
|
"""Load the best model checkpoint from disk.
|
|
|
|
"""Load the best model checkpoint from disk.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
optimizer (Optimizer, optional): optimizer to load states if needed.
|
|
|
|
Defaults to None.
|
|
|
|
Defaults to None.
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
be ignored. Defaults to None.
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return self._load_parameters(model, optimizer, checkpoint_dir,
|
|
|
|
return self.load_parameters(model, optimizer, checkpoint_dir,
|
|
|
|
checkpoint_path, "checkpoint_best")
|
|
|
|
checkpoint_path, "checkpoint_best")
|
|
|
|
|
|
|
|
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
|
|
|
|
if not self._best_full():
|
|
|
|
if not self._best_full():
|
|
|
|