pull/680/head
Haoxin Ma 4 years ago
parent 91e70a2857
commit 16210c0587

@ -64,7 +64,7 @@ class Trainer():
The parsed command line arguments. The parsed command line arguments.
Examples Examples
-------- --------
>>> def p(config, args): >>> def main_sp(config, args):
>>> exp = Trainer(config, args) >>> exp = Trainer(config, args)
>>> exp.setup() >>> exp.setup()
>>> exp.run() >>> exp.run()

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import glob
import json import json
import os import os
import re import re
from pathlib import Path
from typing import Union from typing import Union
import paddle import paddle
@ -22,19 +24,15 @@ from paddle.optimizer import Optimizer
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
import glob
# import operator # import operator
from pathlib import Path
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["Checkpoint"] __all__ = ["Checkpoint"]
class Checkpoint(object): class Checkpoint(object):
def __init__(self, def __init__(self, kbest_n: int=5, latest_n: int=1):
kbest_n: int=5,
latest_n: int=1):
self.best_records: Mapping[Path, float] = {} self.best_records: Mapping[Path, float] = {}
self.latest_records = [] self.latest_records = []
self.kbest_n = kbest_n self.kbest_n = kbest_n
@ -57,65 +55,69 @@ class Checkpoint(object):
def latest_full(self): def latest_full(self):
return len(self.latest_records) == self.latest_n return len(self.latest_records) == self.latest_n
def add_checkpoint(self, checkpoint_dir, tag_or_iteration, def add_checkpoint(self,
model, optimizer, infos, metric_type = "val_loss"): checkpoint_dir,
if(metric_type not in infos.keys()): tag_or_iteration,
self.save_parameters(checkpoint_dir, tag_or_iteration, model,
model, optimizer, infos) optimizer,
infos,
metric_type="val_loss"):
if (metric_type not in infos.keys()):
self.save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
return return
#save best #save best
if self.should_save_best(infos[metric_type]): if self.should_save_best(infos[metric_type]):
self.save_best_checkpoint_and_update(infos[metric_type], self.save_best_checkpoint_and_update(
checkpoint_dir, tag_or_iteration, infos[metric_type], checkpoint_dir, tag_or_iteration, model,
model, optimizer, infos) optimizer, infos)
#save latest #save latest
self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
model, optimizer, infos) model, optimizer, infos)
if isinstance(tag_or_iteration, int): if isinstance(tag_or_iteration, int):
self.save_checkpoint_record(checkpoint_dir, tag_or_iteration) self.save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def save_best_checkpoint_and_update(self, metric, def save_best_checkpoint_and_update(self, metric, checkpoint_dir,
checkpoint_dir, tag_or_iteration, tag_or_iteration, model, optimizer,
model, optimizer, infos): infos):
# remove the worst # remove the worst
if self.best_full(): if self.best_full():
worst_record_path = max(self.best_records, worst_record_path = max(self.best_records,
key=self.best_records.get) key=self.best_records.get)
self.best_records.pop(worst_record_path) self.best_records.pop(worst_record_path)
if(worst_record_path not in self.latest_records): if (worst_record_path not in self.latest_records):
logger.info("remove the worst checkpoint: {}".format(worst_record_path)) logger.info(
"remove the worst checkpoint: {}".format(worst_record_path))
self.del_checkpoint(checkpoint_dir, worst_record_path) self.del_checkpoint(checkpoint_dir, worst_record_path)
# add the new one # add the new one
self.save_parameters(checkpoint_dir, tag_or_iteration, self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
model, optimizer, infos) infos)
self.best_records[tag_or_iteration] = metric self.best_records[tag_or_iteration] = metric
def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration, def save_latest_checkpoint_and_update(
model, optimizer, infos): self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
# remove the old # remove the old
if self.latest_full(): if self.latest_full():
to_del_fn = self.latest_records.pop(0) to_del_fn = self.latest_records.pop(0)
if(to_del_fn not in self.best_records.keys()): if (to_del_fn not in self.best_records.keys()):
logger.info("remove the latest checkpoint: {}".format(to_del_fn)) logger.info(
"remove the latest checkpoint: {}".format(to_del_fn))
self.del_checkpoint(checkpoint_dir, to_del_fn) self.del_checkpoint(checkpoint_dir, to_del_fn)
self.latest_records.append(tag_or_iteration) self.latest_records.append(tag_or_iteration)
self.save_parameters(checkpoint_dir, tag_or_iteration, self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
model, optimizer, infos) infos)
def del_checkpoint(self, checkpoint_dir, tag_or_iteration): def del_checkpoint(self, checkpoint_dir, tag_or_iteration):
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration)) "{}".format(tag_or_iteration))
for filename in glob.glob(checkpoint_path+".*"): for filename in glob.glob(checkpoint_path + ".*"):
os.remove(filename) os.remove(filename)
logger.info("delete file: {}".format(filename)) logger.info("delete file: {}".format(filename))
def load_checkpoint_idx(self, checkpoint_record: str) -> int: def load_checkpoint_idx(self, checkpoint_record: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint. """Get the iteration number corresponding to the latest saved checkpoint.
Args: Args:
@ -132,7 +134,6 @@ class Checkpoint(object):
iteration = int(latest_checkpoint.split(":")[-1]) iteration = int(latest_checkpoint.split(":")[-1])
return iteration return iteration
def save_checkpoint_record(self, checkpoint_dir: str, iteration: int): def save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model as the checkpoint record. """Save the iteration number of the latest model as the checkpoint record.
Args: Args:
@ -141,7 +142,8 @@ class Checkpoint(object):
Returns: Returns:
None None
""" """
checkpoint_record_latest = os.path.join(checkpoint_dir, "checkpoint_latest") checkpoint_record_latest = os.path.join(checkpoint_dir,
"checkpoint_latest")
checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
with open(checkpoint_record_best, "w") as handle: with open(checkpoint_record_best, "w") as handle:
@ -151,11 +153,11 @@ class Checkpoint(object):
for i in self.latest_records: for i in self.latest_records:
handle.write("model_checkpoint_path:{}\n".format(i)) handle.write("model_checkpoint_path:{}\n".format(i))
def load_last_parameters(self,
def load_last_parameters(self, model, model,
optimizer=None, optimizer=None,
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None): checkpoint_path=None):
"""Load the latest model checkpoint from disk. """Load the latest model checkpoint from disk.
Args: Args:
model (Layer): model to load parameters. model (Layer): model to load parameters.
@ -173,11 +175,13 @@ class Checkpoint(object):
if checkpoint_path is not None: if checkpoint_path is not None:
tag = os.path.basename(checkpoint_path).split(":")[-1] tag = os.path.basename(checkpoint_path).split(":")[-1]
elif checkpoint_dir is not None: elif checkpoint_dir is not None:
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_latest") checkpoint_record = os.path.join(checkpoint_dir,
"checkpoint_latest")
iteration = self.load_checkpoint_idx(checkpoint_record) iteration = self.load_checkpoint_idx(checkpoint_record)
if iteration == -1: if iteration == -1:
return configs return configs
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else: else:
raise ValueError( raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
@ -203,11 +207,11 @@ class Checkpoint(object):
configs = json.load(fin) configs = json.load(fin)
return configs return configs
def load_best_parameters(self,
def load_best_parameters(self, model, model,
optimizer=None, optimizer=None,
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None): checkpoint_path=None):
"""Load the best model checkpoint from disk. """Load the best model checkpoint from disk.
Args: Args:
model (Layer): model to load parameters. model (Layer): model to load parameters.
@ -229,7 +233,8 @@ class Checkpoint(object):
iteration = self.load_checkpoint_idx(checkpoint_record) iteration = self.load_checkpoint_idx(checkpoint_record)
if iteration == -1: if iteration == -1:
return configs return configs
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else: else:
raise ValueError( raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
@ -255,10 +260,9 @@ class Checkpoint(object):
configs = json.load(fin) configs = json.load(fin)
return configs return configs
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
def save_parameters(self, checkpoint_dir: str, def save_parameters(self,
checkpoint_dir: str,
tag_or_iteration: Union[int, str], tag_or_iteration: Union[int, str],
model: paddle.nn.Layer, model: paddle.nn.Layer,
optimizer: Optimizer=None, optimizer: Optimizer=None,
@ -275,7 +279,7 @@ class Checkpoint(object):
None None
""" """
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration)) "{}".format(tag_or_iteration))
model_dict = model.state_dict() model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams" params_path = checkpoint_path + ".pdparams"
@ -293,4 +297,3 @@ class Checkpoint(object):
with open(info_path, 'w') as fout: with open(info_path, 'w') as fout:
data = json.dumps(infos) data = json.dumps(infos)
fout.write(data) fout.write(data)

Loading…
Cancel
Save