From 2820537fcc0adea30e271c69803c026c94be83cc Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 6 Jul 2021 02:39:41 +0000
Subject: [PATCH] fix load param

---
 deepspeech/utils/checkpoint.py | 42 +++++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index a2f7e18a..a59f8be7 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -16,8 +16,8 @@ import json
 import os
 import re
 from pathlib import Path
-from typing import Union
 from typing import Text
+from typing import Union
 
 import paddle
 from paddle import distributed as dist
@@ -51,7 +51,7 @@ class Checkpoint():
         Args:
             checkpoint_dir (str): the directory where checkpoint is saved.
             tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
-            model (Layer): model to be checkpointed. 
+            model (Layer): model to be checkpointed.
             optimizer (Optimizer, optional): optimizer to be checkpointed.
             infos (dict or None)): any info you want to save.
             metric_type (str, optional): metric type. Defaults to "val_loss".
@@ -72,22 +72,22 @@ class Checkpoint():
 
         if isinstance(tag_or_iteration, int):
             self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
-    
+
     def load_parameters(self,
-                       model,
-                       optimizer=None,
-                       checkpoint_dir=None,
-                       checkpoint_path=None,
-                       record_file="checkpoint_latest"):
-        """Load a last model checkpoint from disk. 
+                        model,
+                        optimizer=None,
+                        checkpoint_dir=None,
+                        checkpoint_path=None,
+                        record_file="checkpoint_latest"):
+        """Load a last model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
             optimizer (Optimizer, optional): optimizer to load states if needed.
                 Defaults to None.
             checkpoint_dir (str, optional): the directory where checkpoint is saved.
             checkpoint_path (str, optional): if specified, load the checkpoint
-                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will 
-                be ignored. Defaults to None. 
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
             record_file "checkpoint_latest" or "checkpoint_best"
         Returns:
             configs (dict): epoch or step, lr and other meta info should be saved.
@@ -134,40 +134,40 @@ class Checkpoint():
                                optimizer=None,
                                checkpoint_dir=None,
                                checkpoint_path=None):
-        """Load a last model checkpoint from disk. 
+        """Load a last model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
             optimizer (Optimizer, optional): optimizer to load states if needed.
                 Defaults to None.
             checkpoint_dir (str, optional): the directory where checkpoint is saved.
             checkpoint_path (str, optional): if specified, load the checkpoint
-                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will 
-                be ignored. Defaults to None. 
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
         Returns:
             configs (dict): epoch or step, lr and other meta info should be saved.
         """
-        return self._load_parameters(model, optimizer, checkpoint_dir,
-                                     checkpoint_path, "checkpoint_latest")
+        return self.load_parameters(model, optimizer, checkpoint_dir,
+                                    checkpoint_path, "checkpoint_latest")
 
     def load_best_parameters(self,
                              model,
                              optimizer=None,
                              checkpoint_dir=None,
                              checkpoint_path=None):
-        """Load a last model checkpoint from disk. 
+        """Load a last model checkpoint from disk.
         Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                 Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
-                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will 
-                be ignored. Defaults to None. 
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
         Returns:
            configs (dict): epoch or step, lr and other meta info should be saved.
         """
-        return self._load_parameters(model, optimizer, checkpoint_dir,
-                                     checkpoint_path, "checkpoint_best")
+        return self.load_parameters(model, optimizer, checkpoint_dir,
+                                    checkpoint_path, "checkpoint_best")
 
     def _should_save_best(self, metric: float) -> bool:
         if not self._best_full():
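
The functional change is confined to the last two hunks: load_latest_parameters()
and load_best_parameters() used to delegate to self._load_parameters(), a method
that does not exist on Checkpoint (the loader is the public load_parameters()),
so both wrappers raised AttributeError as soon as they were called. The remaining
hunks only normalize import order, continuation-line indentation, and trailing
whitespace. Below is a minimal usage sketch of the call flow after the fix; the
toy Linear model, the Adam optimizer, and the "exp/checkpoints" path are
illustrative only, and Checkpoint is assumed to be constructible with its
defaults:

    import paddle

    from deepspeech.utils.checkpoint import Checkpoint

    # Any Layer/Optimizer pair works; this one is purely for illustration.
    model = paddle.nn.Linear(13, 4)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())

    ckpt = Checkpoint()
    # Resolves the tag recorded in the "checkpoint_latest" record file under
    # checkpoint_dir, restores the tagged model/optimizer state, and returns
    # the saved meta info (epoch or step, lr, ...). Before this patch the
    # call raised AttributeError instead.
    infos = ckpt.load_latest_parameters(
        model, optimizer, checkpoint_dir="exp/checkpoints")
    # Same flow, but resolves the tag from the "checkpoint_best" record.
    infos = ckpt.load_best_parameters(
        model, optimizer, checkpoint_dir="exp/checkpoints")

The two wrappers differ only in which record file they pass, so routing both
through the single public load_parameters() keeps the record lookup and state
restoration logic in one place.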