Merge pull request #680 from PaddlePaddle/checkpoint

checkpoint refactor to save disk space
Hui Zhang committed 4 years ago (via GitHub)
commit 717fe1e4bd

@@ -18,8 +18,8 @@ import paddle
 from paddle import distributed as dist
 from tensorboardX import SummaryWriter

-from deepspeech.utils import checkpoint
 from deepspeech.utils import mp_tools
+from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log

 __all__ = ["Trainer"]

@@ -139,7 +139,7 @@ class Trainer():
             "epoch": self.epoch,
             "lr": self.optimizer.get_lr()
         })
-        checkpoint.save_parameters(self.checkpoint_dir, self.iteration
+        self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
                                    if tag is None else tag, self.model,
                                    self.optimizer, infos)

@@ -151,7 +151,7 @@ class Trainer():
             resume training.
         """
         scratch = None
-        infos = checkpoint.load_parameters(
+        infos = self.checkpoint.load_latest_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,

@@ -180,7 +180,7 @@ class Trainer():
         from_scratch = self.resume_or_scratch()
         if from_scratch:
             # save init model, i.e. 0 epoch
-            self.save(tag='init')
+            self.save(tag='init', infos=None)
         self.lr_scheduler.step(self.iteration)

         if self.parallel:

@@ -263,6 +263,10 @@ class Trainer():
         self.checkpoint_dir = checkpoint_dir

+        self.checkpoint = Checkpoint(
+            kbest_n=self.config.training.checkpoint.kbest_n,
+            latest_n=self.config.training.checkpoint.latest_n)
+
     @mp_tools.rank_zero_only
     def destory(self):
         """Close visualizer to avoid hanging after training"""

@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import json
 import os
 import re
+from pathlib import Path
 from typing import Union

 import paddle
@@ -25,17 +27,143 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

-__all__ = ["load_parameters", "save_parameters"]
+__all__ = ["Checkpoint"]


-def _load_latest_checkpoint(checkpoint_dir: str) -> int:
+class Checkpoint(object):
+    def __init__(self, kbest_n: int=5, latest_n: int=1):
+        self.best_records: Mapping[Path, float] = {}
+        self.latest_records = []
+        self.kbest_n = kbest_n
+        self.latest_n = latest_n
+        self._save_all = (kbest_n == -1)
+
+    def add_checkpoint(self,
+                       checkpoint_dir,
+                       tag_or_iteration,
+                       model,
+                       optimizer,
+                       infos,
+                       metric_type="val_loss"):
+        if (metric_type not in infos.keys()):
+            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                                  optimizer, infos)
+            return
+
+        #save best
+        if self._should_save_best(infos[metric_type]):
+            self._save_best_checkpoint_and_update(
+                infos[metric_type], checkpoint_dir, tag_or_iteration, model,
+                optimizer, infos)
+        #save latest
+        self._save_latest_checkpoint_and_update(
+            checkpoint_dir, tag_or_iteration, model, optimizer, infos)
+
+        if isinstance(tag_or_iteration, int):
+            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
+
+    def load_latest_parameters(self,
+                               model,
+                               optimizer=None,
+                               checkpoint_dir=None,
+                               checkpoint_path=None):
+        """Load a last model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        return self._load_parameters(model, optimizer, checkpoint_dir,
+                                     checkpoint_path, "checkpoint_latest")
+
+    def load_best_parameters(self,
+                             model,
+                             optimizer=None,
+                             checkpoint_dir=None,
+                             checkpoint_path=None):
+        """Load a last model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        return self._load_parameters(model, optimizer, checkpoint_dir,
+                                     checkpoint_path, "checkpoint_best")
+
+    def _should_save_best(self, metric: float) -> bool:
+        if not self._best_full():
+            return True
+
+        # already full
+        worst_record_path = max(self.best_records, key=self.best_records.get)
+        # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
+        worst_metric = self.best_records[worst_record_path]
+        return metric < worst_metric
+
+    def _best_full(self):
+        return (not self._save_all) and len(self.best_records) == self.kbest_n
+
+    def _latest_full(self):
+        return len(self.latest_records) == self.latest_n
+
+    def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
+                                         tag_or_iteration, model, optimizer,
+                                         infos):
+        # remove the worst
+        if self._best_full():
+            worst_record_path = max(self.best_records,
+                                    key=self.best_records.get)
+            self.best_records.pop(worst_record_path)
+            if (worst_record_path not in self.latest_records):
+                logger.info(
+                    "remove the worst checkpoint: {}".format(worst_record_path))
+                self._del_checkpoint(checkpoint_dir, worst_record_path)
+
+        # add the new one
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                              optimizer, infos)
+        self.best_records[tag_or_iteration] = metric
+
+    def _save_latest_checkpoint_and_update(
+            self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
+        # remove the old
+        if self._latest_full():
+            to_del_fn = self.latest_records.pop(0)
+            if (to_del_fn not in self.best_records.keys()):
+                logger.info(
+                    "remove the latest checkpoint: {}".format(to_del_fn))
+                self._del_checkpoint(checkpoint_dir, to_del_fn)
+        self.latest_records.append(tag_or_iteration)
+
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                              optimizer, infos)
+
+    def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+        for filename in glob.glob(checkpoint_path + ".*"):
+            os.remove(filename)
+            logger.info("delete file: {}".format(filename))
+
+    def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
         """Get the iteration number corresponding to the latest saved checkpoint.
         Args:
-            checkpoint_dir (str): the directory where checkpoint is saved.
+            checkpoint_path (str): the saved path of checkpoint.
         Returns:
             int: the latest iteration number. -1 for no checkpoint to load.
         """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
         if not os.path.isfile(checkpoint_record):
             return -1
@@ -45,8 +173,7 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
         iteration = int(latest_checkpoint.split(":")[-1])
         return iteration

-def _save_record(checkpoint_dir: str, iteration: int):
+    def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
         """Save the iteration number of the latest model to be checkpoint record.
         Args:
             checkpoint_dir (str): the directory where checkpoint is saved.
@@ -54,17 +181,24 @@ def _save_record(checkpoint_dir: str, iteration: int):
         Returns:
             None
         """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
-    # Update the latest checkpoint index.
-    with open(checkpoint_record, "a+") as handle:
-        handle.write("model_checkpoint_path:{}\n".format(iteration))
+        checkpoint_record_latest = os.path.join(checkpoint_dir,
+                                                "checkpoint_latest")
+        checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
+
+        with open(checkpoint_record_best, "w") as handle:
+            for i in self.best_records.keys():
+                handle.write("model_checkpoint_path:{}\n".format(i))
+        with open(checkpoint_record_latest, "w") as handle:
+            for i in self.latest_records:
+                handle.write("model_checkpoint_path:{}\n".format(i))

-def load_parameters(model,
-                    optimizer=None,
-                    checkpoint_dir=None,
-                    checkpoint_path=None):
-    """Load a specific model checkpoint from disk.
+    def _load_parameters(self,
+                         model,
+                         optimizer=None,
+                         checkpoint_dir=None,
+                         checkpoint_path=None,
+                         checkpoint_file=None):
+        """Load a last model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
             optimizer (Optimizer, optional): optimizer to load states if needed.
@@ -73,6 +207,7 @@ def load_parameters(model,
             checkpoint_path (str, optional): if specified, load the checkpoint
                 stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                 be ignored. Defaults to None.
+            checkpoint_file "checkpoint_latest" or "checkpoint_best"
         Returns:
             configs (dict): epoch or step, lr and other meta info should be saved.
         """
@@ -80,14 +215,16 @@ def load_parameters(model,
         if checkpoint_path is not None:
             tag = os.path.basename(checkpoint_path).split(":")[-1]
-    elif checkpoint_dir is not None:
-        iteration = _load_latest_checkpoint(checkpoint_dir)
+        elif checkpoint_dir is not None and checkpoint_file is not None:
+            checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file)
+            iteration = self._load_checkpoint_idx(checkpoint_record)
             if iteration == -1:
                 return configs
-        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
+            checkpoint_path = os.path.join(checkpoint_dir,
+                                           "{}".format(iteration))
         else:
             raise ValueError(
-            "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+                "At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!"
             )

         rank = dist.get_rank()
@@ -110,9 +247,9 @@ def load_parameters(model,
             configs = json.load(fin)
         return configs

-@mp_tools.rank_zero_only
-def save_parameters(checkpoint_dir: str,
+    @mp_tools.rank_zero_only
+    def _save_parameters(self,
+                         checkpoint_dir: str,
                          tag_or_iteration: Union[int, str],
                          model: paddle.nn.Layer,
                          optimizer: Optimizer=None,

@@ -147,6 +284,3 @@ def save_parameters(checkpoint_dir: str,
         with open(info_path, 'w') as fout:
             data = json.dumps(infos)
             fout.write(data)
-
-    if isinstance(tag_or_iteration, int):
-        _save_record(checkpoint_dir, tag_or_iteration)
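The disk saving comes entirely from the bookkeeping in _save_best_checkpoint_and_update and _save_latest_checkpoint_and_update above. The following framework-free sketch reproduces just that retention logic, with no file I/O and class/variable names of my own, to show which tags survive: the kbest_n lowest-metric checkpoints plus the latest_n most recent ones, with anything in neither set deleted.

# Simplified rendition of the retention bookkeeping in the Checkpoint class
# above; it only tracks tags and reports which ones would be removed from disk.
class RetentionPolicy:
    def __init__(self, kbest_n=5, latest_n=1):
        self.best = {}    # tag -> metric (e.g. val_loss); lower is better
        self.latest = []  # tags in arrival order
        self.kbest_n = kbest_n
        self.latest_n = latest_n

    def add(self, tag, metric):
        deleted = []
        # k-best set: once full, a new tag only enters by evicting the current worst
        if len(self.best) < self.kbest_n:
            self.best[tag] = metric
        elif metric < max(self.best.values()):
            worst = max(self.best, key=self.best.get)
            self.best.pop(worst)
            if worst not in self.latest:
                deleted.append(worst)
            self.best[tag] = metric
        # latest window: keep only the most recent latest_n tags
        if len(self.latest) == self.latest_n:
            old = self.latest.pop(0)
            if old not in self.best:
                deleted.append(old)
        self.latest.append(tag)
        return deleted  # tags whose files the real class would glob and remove


policy = RetentionPolicy(kbest_n=2, latest_n=1)
for step, loss in enumerate([3.0, 2.0, 4.0, 1.0]):
    removed = policy.add(step, loss)
    print(step, sorted(policy.best), policy.latest, "removed:", removed)

With the config values below (kbest_n: 50, latest_n: 5), on the order of kbest_n + latest_n checkpoints remain on disk no matter how long training runs; the sketch leaves out the kbest_n = -1 escape hatch, which the real class uses to keep everything.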

@@ -48,6 +48,9 @@ training:
   weight_decay: 1e-06
   global_grad_clip: 3.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:
   batch_size: 128
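The new training.checkpoint block is what feeds the Checkpoint constructor added in the trainer hunk near the top of this diff. A small sketch of the mapping, using an inline yacs CfgNode as a stand-in for the config a recipe would actually load from this YAML:

from yacs.config import CfgNode

from deepspeech.utils.checkpoint import Checkpoint

# Inline stand-in for the recipe config loaded from the YAML above.
config = CfgNode({"training": {"checkpoint": {"kbest_n": 50, "latest_n": 5}}})

ckpt = Checkpoint(
    kbest_n=config.training.checkpoint.kbest_n,    # keep the 50 best checkpoints by val_loss
    latest_n=config.training.checkpoint.latest_n)  # plus the 5 most recent ones

Setting kbest_n to -1 makes the class keep every checkpoint (its _save_all flag), which reproduces the old never-pruning behaviour.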

@@ -93,6 +93,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -88,6 +88,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -48,6 +48,9 @@ training:
   weight_decay: 1e-06
   global_grad_clip: 5.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:
   batch_size: 128

@@ -93,6 +93,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -86,6 +86,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -89,6 +89,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -84,6 +84,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -43,12 +43,16 @@ model:
   share_rnn_weights: True

 training:
-  n_epoch: 24
+  n_epoch: 10
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
   global_grad_clip: 5.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 3
+    latest_n: 2

 decoding:
   batch_size: 128

@@ -91,6 +91,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -84,6 +84,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -87,6 +87,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -84,6 +84,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:
