add checkpoint to save parameters

pull/879/head
Hui Zhang 4 years ago
parent 37563d975e
commit b291c69386

@ -163,7 +163,7 @@ class Trainer():
"epoch": self.epoch, "epoch": self.epoch,
"lr": self.optimizer.get_lr() "lr": self.optimizer.get_lr()
}) })
self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model, if tag is None else tag, self.model,
self.optimizer, infos) self.optimizer, infos)

@ -39,7 +39,7 @@ class Checkpoint():
self.latest_n = latest_n self.latest_n = latest_n
self._save_all = (kbest_n == -1) self._save_all = (kbest_n == -1)
def add_checkpoint(self, def save_parameters(self,
checkpoint_dir, checkpoint_dir,
tag_or_iteration: Union[int, Text], tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer, model: paddle.nn.Layer,

Loading…
Cancel
Save