From b291c693868b65473f032699ff1f381131a1acfc Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 5 Oct 2021 09:41:38 +0000 Subject: [PATCH] add checkpoint to save parameters --- deepspeech/training/trainer.py | 6 +++--- deepspeech/utils/checkpoint.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 75652ead..c3e1bec8 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -163,9 +163,9 @@ class Trainer(): "epoch": self.epoch, "lr": self.optimizer.get_lr() }) - self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration - if tag is None else tag, self.model, - self.optimizer, infos) + self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration + if tag is None else tag, self.model, + self.optimizer, infos) def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 8e31edfa..796cafe0 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -39,13 +39,13 @@ class Checkpoint(): self.latest_n = latest_n self._save_all = (kbest_n == -1) - def add_checkpoint(self, - checkpoint_dir, - tag_or_iteration: Union[int, Text], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None, - metric_type="val_loss"): + def save_parameters(self, + checkpoint_dir, + tag_or_iteration: Union[int, Text], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None, + metric_type="val_loss"): """Save checkpoint in best_n and latest_n. Args: