fix load param

pull/701/head
Hui Zhang 4 years ago
parent a37192c809
commit 2820537fcc

@ -16,8 +16,8 @@ import json
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Union
from typing import Text from typing import Text
from typing import Union
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -146,7 +146,7 @@ class Checkpoint():
Returns: Returns:
configs (dict): epoch or step, lr and other meta info should be saved. configs (dict): epoch or step, lr and other meta info should be saved.
""" """
return self._load_parameters(model, optimizer, checkpoint_dir, return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest") checkpoint_path, "checkpoint_latest")
def load_best_parameters(self, def load_best_parameters(self,
@ -166,7 +166,7 @@ class Checkpoint():
Returns: Returns:
configs (dict): epoch or step, lr and other meta info should be saved. configs (dict): epoch or step, lr and other meta info should be saved.
""" """
return self._load_parameters(model, optimizer, checkpoint_dir, return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best") checkpoint_path, "checkpoint_best")
def _should_save_best(self, metric: float) -> bool: def _should_save_best(self, metric: float) -> bool:

Loading…
Cancel
Save