diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 000fa87ba..8c5d8d605 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -24,7 +24,6 @@ from paddle.optimizer import Optimizer from deepspeech.utils import mp_tools from deepspeech.utils.log import Log -# import operator logger = Log(__name__).getlog() @@ -38,7 +37,7 @@ class Checkpoint(object): self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - + def add_checkpoint(self, checkpoint_dir, tag_or_iteration, @@ -64,10 +63,10 @@ class Checkpoint(object): self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) def load_latest_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -80,14 +79,14 @@ class Checkpoint(object): Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ - return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, - "checkpoint_latest") + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") def load_best_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -100,8 +99,8 @@ class Checkpoint(object): Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ - return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, - "checkpoint_best") + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") def _should_save_best(self, metric: float) -> bool: if not self._best_full(): @@ -248,7 +247,6 @@ class Checkpoint(object): configs = json.load(fin) return configs - @mp_tools.rank_zero_only def _save_parameters(self, checkpoint_dir: str, diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 54ce240e7..27ede01bc 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -48,6 +48,9 @@ training: weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 904624c3c..1065dcb03 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -90,6 +90,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 116c91927..4b1430c58 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -88,6 +88,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index d1746bff3..9f06a3802 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -43,6 +43,9 @@ training: weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ec945a188..979121639 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 3939ffc68..dc2a51f92 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 8f8bf4539..989af22a0 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index a094b0fba..931d7524b 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -82,6 +82,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 790066264..606300bdf 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index aa2b145a6..72d368485 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 3813daa04..a6f730501 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 250995faa..71cbdde7f 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: