fix load param

pull/701/head
Hui Zhang 4 years ago
parent a37192c809
commit 2820537fcc

@ -16,8 +16,8 @@ import json
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Union
from typing import Text from typing import Text
from typing import Union
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -51,7 +51,7 @@ class Checkpoint():
Args: Args:
checkpoint_dir (str): the directory where checkpoint is saved. checkpoint_dir (str): the directory where checkpoint is saved.
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag. tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
model (Layer): model to be checkpointed. model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed. optimizer (Optimizer, optional): optimizer to be checkpointed.
infos (dict or None)): any info you want to save. infos (dict or None)): any info you want to save.
metric_type (str, optional): metric type. Defaults to "val_loss". metric_type (str, optional): metric type. Defaults to "val_loss".
@ -72,22 +72,22 @@ class Checkpoint():
if isinstance(tag_or_iteration, int): if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(self, def load_parameters(self,
model, model,
optimizer=None, optimizer=None,
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None, checkpoint_path=None,
record_file="checkpoint_latest"): record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk. """Load a last model checkpoint from disk.
Args: Args:
model (Layer): model to load parameters. model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed. optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None. Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved. checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None. be ignored. Defaults to None.
record_file "checkpoint_latest" or "checkpoint_best" record_file "checkpoint_latest" or "checkpoint_best"
Returns: Returns:
configs (dict): epoch or step, lr and other meta info should be saved. configs (dict): epoch or step, lr and other meta info should be saved.
@ -134,40 +134,40 @@ class Checkpoint():
optimizer=None, optimizer=None,
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None): checkpoint_path=None):
"""Load a last model checkpoint from disk. """Load a last model checkpoint from disk.
Args: Args:
model (Layer): model to load parameters. model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed. optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None. Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved. checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None. be ignored. Defaults to None.
Returns: Returns:
configs (dict): epoch or step, lr and other meta info should be saved. configs (dict): epoch or step, lr and other meta info should be saved.
""" """
return self._load_parameters(model, optimizer, checkpoint_dir, return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest") checkpoint_path, "checkpoint_latest")
def load_best_parameters(self, def load_best_parameters(self,
model, model,
optimizer=None, optimizer=None,
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None): checkpoint_path=None):
"""Load a last model checkpoint from disk. """Load a last model checkpoint from disk.
Args: Args:
model (Layer): model to load parameters. model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed. optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None. Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved. checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None. be ignored. Defaults to None.
Returns: Returns:
configs (dict): epoch or step, lr and other meta info should be saved. configs (dict): epoch or step, lr and other meta info should be saved.
""" """
return self._load_parameters(model, optimizer, checkpoint_dir, return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best") checkpoint_path, "checkpoint_best")
def _should_save_best(self, metric: float) -> bool: def _should_save_best(self, metric: float) -> bool:
if not self._best_full(): if not self._best_full():

Loading…
Cancel
Save