add checkpoint to save parameters

pull/879/head
Hui Zhang 3 years ago
parent 37563d975e
commit b291c69386

@ -163,9 +163,9 @@ class Trainer():
"epoch": self.epoch, "epoch": self.epoch,
"lr": self.optimizer.get_lr() "lr": self.optimizer.get_lr()
}) })
self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model, if tag is None else tag, self.model,
self.optimizer, infos) self.optimizer, infos)
def resume_or_scratch(self): def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output """Resume from latest checkpoint at checkpoints in the output

@ -39,13 +39,13 @@ class Checkpoint():
self.latest_n = latest_n self.latest_n = latest_n
self._save_all = (kbest_n == -1) self._save_all = (kbest_n == -1)
def add_checkpoint(self, def save_parameters(self,
checkpoint_dir, checkpoint_dir,
tag_or_iteration: Union[int, Text], tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer, model: paddle.nn.Layer,
optimizer: Optimizer=None, optimizer: Optimizer=None,
infos: dict=None, infos: dict=None,
metric_type="val_loss"): metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n. """Save checkpoint in best_n and latest_n.
Args: Args:

Loading…
Cancel
Save