|
|
|
@ -24,7 +24,6 @@ from paddle.optimizer import Optimizer
|
|
|
|
|
|
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
# import operator
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
@ -64,10 +63,10 @@ class Checkpoint(object):
|
|
|
|
|
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
|
|
|
|
|
|
|
|
|
|
def load_latest_parameters(self,
|
|
|
|
|
model,
|
|
|
|
|
optimizer=None,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
checkpoint_path=None):
|
|
|
|
|
model,
|
|
|
|
|
optimizer=None,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
checkpoint_path=None):
|
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
|
Args:
|
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
@ -80,14 +79,14 @@ class Checkpoint(object):
|
|
|
|
|
Returns:
|
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
|
"""
|
|
|
|
|
return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path,
|
|
|
|
|
"checkpoint_latest")
|
|
|
|
|
return self._load_parameters(model, optimizer, checkpoint_dir,
|
|
|
|
|
checkpoint_path, "checkpoint_latest")
|
|
|
|
|
|
|
|
|
|
def load_best_parameters(self,
|
|
|
|
|
model,
|
|
|
|
|
optimizer=None,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
checkpoint_path=None):
|
|
|
|
|
model,
|
|
|
|
|
optimizer=None,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
checkpoint_path=None):
|
|
|
|
|
"""Load a last model checkpoint from disk.
|
|
|
|
|
Args:
|
|
|
|
|
model (Layer): model to load parameters.
|
|
|
|
@ -100,8 +99,8 @@ class Checkpoint(object):
|
|
|
|
|
Returns:
|
|
|
|
|
configs (dict): epoch or step, lr and other meta info should be saved.
|
|
|
|
|
"""
|
|
|
|
|
return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path,
|
|
|
|
|
"checkpoint_best")
|
|
|
|
|
return self._load_parameters(model, optimizer, checkpoint_dir,
|
|
|
|
|
checkpoint_path, "checkpoint_best")
|
|
|
|
|
|
|
|
|
|
def _should_save_best(self, metric: float) -> bool:
|
|
|
|
|
if not self._best_full():
|
|
|
|
@ -248,7 +247,6 @@ class Checkpoint(object):
|
|
|
|
|
configs = json.load(fin)
|
|
|
|
|
return configs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
|
def _save_parameters(self,
|
|
|
|
|
checkpoint_dir: str,
|
|
|
|
|