Merge pull request #697 from PaddlePaddle/ckpt

fix ckpt load
commit a37192c809
Hui Zhang, 3 years ago, committed by GitHub

@@ -599,26 +599,26 @@ class U2BaseModel(nn.Module):
                 best_index = i
         return hyps[best_index][0]

-    @jit.export
+    #@jit.export
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
             model
         """
         return self.encoder.embed.subsampling_rate

-    @jit.export
+    #@jit.export
     def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context

-    @jit.export
+    #@jit.export
     def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos

-    @jit.export
+    #@jit.export
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """
@@ -654,12 +654,14 @@ class U2BaseModel(nn.Module):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)

-    @jit.export
+    # @jit.export([
+    #     paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
+    # ])
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
         """ Export interface for c++ call, apply linear transform and log
             softmax before ctc
         Args:
-            xs (paddle.Tensor): encoder output
+            xs (paddle.Tensor): encoder output, (B, T, D)
         Returns:
             paddle.Tensor: activation before ctc
         """
@@ -894,7 +896,7 @@ class U2Model(U2BaseModel):
         model = cls.from_config(config)

         if checkpoint_path:
-            infos = checkpoint.load_parameters(
+            infos = checkpoint.Checkpoint().load_parameters(
                 model, checkpoint_path=checkpoint_path)
             logger.info(f"checkpoint info: {infos}")
         layer_tools.summary(model)
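This is the fix the PR title refers to: load_parameters is now a method of the Checkpoint class rather than a module-level function, so call sites must build an instance first. A minimal usage sketch (the import path and checkpoint prefix are assumptions):

    from deepspeech.utils import checkpoint  # assumed module path

    ckpt = checkpoint.Checkpoint()
    infos = ckpt.load_parameters(
        model, checkpoint_path="exp/conformer/checkpoints/avg_20")  # assumed prefix
    print(infos)  # meta info (epoch, lr, ...) read from the matching .json, if any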

@@ -17,6 +17,7 @@ import os
 import re
 from pathlib import Path
 from typing import Union
+from typing import Text

 import paddle
 from paddle import distributed as dist
@@ -30,7 +31,7 @@ logger = Log(__name__).getlog()
 __all__ = ["Checkpoint"]

-class Checkpoint(object):
+class Checkpoint():
     def __init__(self, kbest_n: int=5, latest_n: int=1):
         self.best_records: Mapping[Path, float] = {}
         self.latest_records = []
@@ -40,11 +41,21 @@ class Checkpoint(object):
     def add_checkpoint(self,
                        checkpoint_dir,
-                       tag_or_iteration,
-                       model,
-                       optimizer,
-                       infos,
+                       tag_or_iteration: Union[int, Text],
+                       model: paddle.nn.Layer,
+                       optimizer: Optimizer=None,
+                       infos: dict=None,
                        metric_type="val_loss"):
+        """Save checkpoint in best_n and latest_n.
+
+        Args:
+            checkpoint_dir (str): the directory where the checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration (step or epoch) number or a tag.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+            infos (dict or None): any info you want to save.
+            metric_type (str, optional): metric type. Defaults to "val_loss".
+        """
         if (metric_type not in infos.keys()):
             self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                                   optimizer, infos)
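Going by the new signature and docstring, a hedged sketch of how a training loop would record a checkpoint (the directory and the epoch/lr/val_loss variables are assumptions):

    ckpt = Checkpoint(kbest_n=5, latest_n=1)
    ckpt.add_checkpoint(
        checkpoint_dir="exp/conformer/checkpoints",  # assumed path
        tag_or_iteration=epoch,  # an int also updates the checkpoint record files
        model=model,
        optimizer=optimizer,
        infos={"epoch": epoch, "lr": lr, "val_loss": val_loss},
        metric_type="val_loss")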
@@ -61,6 +72,62 @@ class Checkpoint(object):
         if isinstance(tag_or_iteration, int):
             self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)

+    def load_parameters(self,
+                        model,
+                        optimizer=None,
+                        checkpoint_dir=None,
+                        checkpoint_path=None,
+                        record_file="checkpoint_latest"):
+        """Load the latest model checkpoint from disk.
+
+        Args:
+            model (Layer): model to load parameters into.
+            optimizer (Optimizer, optional): optimizer to load states into if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where the checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored at the checkpoint_path (prefix) and ignore the
+                'checkpoint_dir' argument. Defaults to None.
+            record_file (str, optional): "checkpoint_latest" or "checkpoint_best".
+        Returns:
+            configs (dict): epoch or step, lr and other saved meta info.
+        """
+        configs = {}
+
+        if checkpoint_path is not None:
+            pass
+        elif checkpoint_dir is not None and record_file is not None:
+            # load checkpoint from record file
+            checkpoint_record = os.path.join(checkpoint_dir, record_file)
+            iteration = self._load_checkpoint_idx(checkpoint_record)
+            if iteration == -1:
+                return configs
+            checkpoint_path = os.path.join(checkpoint_dir,
+                                           "{}".format(iteration))
+        else:
+            raise ValueError(
+                "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
+            )
+
+        rank = dist.get_rank()
+
+        params_path = checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        model.set_state_dict(model_dict)
+        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        if optimizer and os.path.isfile(optimizer_path):
+            optimizer_dict = paddle.load(optimizer_path)
+            optimizer.set_state_dict(optimizer_dict)
+            logger.info("Rank {}: loaded optimizer state from {}".format(
+                rank, optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        if os.path.exists(info_path):
+            with open(info_path, 'r') as fin:
+                configs = json.load(fin)
+        return configs
+
     def load_latest_parameters(self,
                                model,
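The record_file argument selects which record to follow when only a directory is given. A hedged sketch of the two common cases (the directory path and the model/optimizer objects are assumptions):

    ckpt = Checkpoint()
    # Resume training: follow the latest record and restore optimizer state too.
    infos = ckpt.load_parameters(
        model, optimizer=optimizer,
        checkpoint_dir="exp/conformer/checkpoints",
        record_file="checkpoint_latest")
    # Evaluation: follow the best record, weights only.
    infos = ckpt.load_parameters(
        model,
        checkpoint_dir="exp/conformer/checkpoints",
        record_file="checkpoint_best")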
@@ -192,61 +259,6 @@ class Checkpoint(object):
         for i in self.latest_records:
             handle.write("model_checkpoint_path:{}\n".format(i))

-    def _load_parameters(self,
-                         model,
-                         optimizer=None,
-                         checkpoint_dir=None,
-                         checkpoint_path=None,
-                         checkpoint_file=None):
-        """Load a last model checkpoint from disk.
-        Args:
-            model (Layer): model to load parameters.
-            optimizer (Optimizer, optional): optimizer to load states if needed.
-                Defaults to None.
-            checkpoint_dir (str, optional): the directory where checkpoint is saved.
-            checkpoint_path (str, optional): if specified, load the checkpoint
-                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
-                be ignored. Defaults to None.
-            checkpoint_file "checkpoint_latest" or "checkpoint_best"
-        Returns:
-            configs (dict): epoch or step, lr and other meta info should be saved.
-        """
-        configs = {}
-
-        if checkpoint_path is not None:
-            tag = os.path.basename(checkpoint_path).split(":")[-1]
-        elif checkpoint_dir is not None and checkpoint_file is not None:
-            checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file)
-            iteration = self._load_checkpoint_idx(checkpoint_record)
-            if iteration == -1:
-                return configs
-            checkpoint_path = os.path.join(checkpoint_dir,
-                                           "{}".format(iteration))
-        else:
-            raise ValueError(
-                "At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!"
-            )
-
-        rank = dist.get_rank()
-
-        params_path = checkpoint_path + ".pdparams"
-        model_dict = paddle.load(params_path)
-        model.set_state_dict(model_dict)
-        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
-
-        optimizer_path = checkpoint_path + ".pdopt"
-        if optimizer and os.path.isfile(optimizer_path):
-            optimizer_dict = paddle.load(optimizer_path)
-            optimizer.set_state_dict(optimizer_dict)
-            logger.info("Rank {}: loaded optimizer state from {}".format(
-                rank, optimizer_path))
-
-        info_path = re.sub('.pdparams$', '.json', params_path)
-        if os.path.exists(info_path):
-            with open(info_path, 'r') as fin:
-                configs = json.load(fin)
-        return configs
-
     @mp_tools.rank_zero_only
     def _save_parameters(self,
                          checkpoint_dir: str,

@@ -40,5 +40,5 @@ fi

 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
 fi
