fix load param

pull/701/head
Hui Zhang 4 years ago
parent a37192c809
commit 2820537fcc

@ -16,8 +16,8 @@ import json
import os
import re
from pathlib import Path
from typing import Union
from typing import Text
from typing import Union
import paddle
from paddle import distributed as dist
@ -74,11 +74,11 @@ class Checkpoint():
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
record_file="checkpoint_latest"):
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
@ -146,8 +146,8 @@ class Checkpoint():
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self._load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest")
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest")
def load_best_parameters(self,
model,
@ -166,8 +166,8 @@ class Checkpoint():
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self._load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best")
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best")
def _should_save_best(self, metric: float) -> bool:
if not self._best_full():

Loading…
Cancel
Save