add checkpoint to save parameters

pull/879/head
Hui Zhang 3 years ago
parent 37563d975e
commit b291c69386

@ -163,9 +163,9 @@ class Trainer():
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
})
self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output

@ -39,13 +39,13 @@ class Checkpoint():
self.latest_n = latest_n
self._save_all = (kbest_n == -1)
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
def save_parameters(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n.
Args:

Loading…
Cancel
Save