diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 56de32617..246175e3f 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -18,7 +18,7 @@ import paddle
 from paddle import distributed as dist
 from tensorboardX import SummaryWriter

-from deepspeech.utils import checkpoint
+from deepspeech.utils.checkpoint import KBestCheckpoint
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log

@@ -139,9 +139,9 @@ class Trainer():
             "epoch": self.epoch,
             "lr": self.optimizer.get_lr()
         })
-        checkpoint.save_parameters(self.checkpoint_dir, self.iteration
+        self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
                                    if tag is None else tag, self.model,
                                    self.optimizer, infos)

     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output
@@ -151,7 +151,7 @@ class Trainer():
         resume training.
         """
         scratch = None
-        infos = checkpoint.load_parameters(
+        infos = self.checkpoint.load_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
@@ -180,7 +180,7 @@ class Trainer():
         from_scratch = self.resume_or_scratch()
         if from_scratch:
             # save init model, i.e. 0 epoch
-            self.save(tag='init')
+            self.save(tag='init', infos=None)
             self.lr_scheduler.step(self.iteration)

         if self.parallel:
@@ -263,6 +263,9 @@ class Trainer():

         self.checkpoint_dir = checkpoint_dir

+        self.checkpoint = KBestCheckpoint(max_size=self.config.training.max_epoch,
+                                          last_size=self.config.training.last_epoch)
+
     @mp_tools.rank_zero_only
     def destory(self):
         """Close visualizer to avoid hanging after training"""
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index 8ede6b8fd..ef73eb705 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -23,130 +23,226 @@ from paddle.optimizer import Optimizer
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log

+import glob
+
 logger = Log(__name__).getlog()

-__all__ = ["load_parameters", "save_parameters"]
+__all__ = ["KBestCheckpoint"]

+
+class KBestCheckpoint(object):
+    def __init__(self,
+                 max_size: int=5,
+                 last_size: int=1):
+        self.best_records = {}  # {tag_or_iteration: val_loss}
+        self.last_records = []
+        self.max_size = max_size
+        self.last_size = last_size
+        self._save_all = (max_size == -1)
+
+    def should_save_best(self, metric: float) -> bool:
+        if not self.best_full():
+            return True
+
+        # already full: save only if the new metric beats the current worst
+        worst_record_path = max(self.best_records, key=self.best_records.get)
+        worst_metric = self.best_records[worst_record_path]
+        return metric < worst_metric
+
+    def best_full(self):
+        return (not self._save_all) and len(self.best_records) == self.max_size
+
+    def last_full(self):
+        return len(self.last_records) == self.last_size
+
+    def add_checkpoint(self, checkpoint_dir, tag_or_iteration,
+                       model, optimizer, infos):
+        if infos is None or "val_loss" not in infos:
+            self.save_parameters(checkpoint_dir, tag_or_iteration,
+                                 model, optimizer, infos)
+            return
+
+        # save the k best checkpoints ranked by val_loss
+        if self.should_save_best(infos["val_loss"]):
+            self.save_checkpoint_and_update(infos["val_loss"],
+                                            checkpoint_dir, tag_or_iteration,
+                                            model, optimizer, infos)
+        # always keep the most recent checkpoints
+        self.save_last_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
+                                             model, optimizer, infos)
+
+        if isinstance(tag_or_iteration, int):
+            self._save_record(checkpoint_dir, tag_or_iteration)
+
+    def save_checkpoint_and_update(self, metric,
+                                   checkpoint_dir, tag_or_iteration,
+                                   model, optimizer, infos):
+        # evict the worst of the k best checkpoints
+        if self.best_full():
+            worst_record_path = max(self.best_records,
+                                    key=self.best_records.get)
+            self.best_records.pop(worst_record_path)
+            if worst_record_path not in self.last_records:
+                logger.info("remove the worst checkpoint: {}".format(
+                    worst_record_path))
+                self.del_checkpoint(checkpoint_dir, worst_record_path)
+
+        # add the new one
+        self.save_parameters(checkpoint_dir, tag_or_iteration,
+                             model, optimizer, infos)
+        self.best_records[tag_or_iteration] = metric
+
+    def save_last_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration,
+                                        model, optimizer, infos):
+        # drop the oldest of the last checkpoints
+        if self.last_full():
+            to_del_fn = self.last_records.pop(0)
+            if to_del_fn not in self.best_records:
+                logger.info("remove the oldest checkpoint: {}".format(to_del_fn))
+                self.del_checkpoint(checkpoint_dir, to_del_fn)
+        self.last_records.append(tag_or_iteration)
+
+        self.save_parameters(checkpoint_dir, tag_or_iteration,
+                             model, optimizer, infos)
+
+    def del_checkpoint(self, checkpoint_dir, tag_or_iteration):
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+        for filename in glob.glob(checkpoint_path + ".*"):
+            os.remove(filename)
+            logger.info("delete checkpoint file: {}".format(filename))
+
+    def _load_latest_checkpoint(self, checkpoint_dir: str) -> int:
+        """Get the iteration number corresponding to the latest saved checkpoint.
+        Args:
+            checkpoint_dir (str): the directory where checkpoint is saved.
+        Returns:
+            int: the latest iteration number. -1 for no checkpoint to load.
+        """
+        checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_last")
+        if not os.path.isfile(checkpoint_record):
+            return -1
+
+        # Fetch the latest checkpoint index.
+        with open(checkpoint_record, "rt") as handle:
+            latest_checkpoint = handle.readlines()[-1].strip()
+            iteration = int(latest_checkpoint.split(":")[-1])
+        return iteration
+
+    def _save_record(self, checkpoint_dir: str, iteration: int):
+        """Record the checkpoints (best and last) that are currently kept.
+        Args:
+            checkpoint_dir (str): the directory where checkpoint is saved.
+            iteration (int): the latest iteration number.
+        Returns:
+            None
+        """
+        checkpoint_record_last = os.path.join(checkpoint_dir, "checkpoint_last")
+        checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
+        # Rewrite both record files with the checkpoints that are kept now.
+        with open(checkpoint_record_best, "w") as handle:
+            for i in self.best_records.keys():
+                handle.write("model_checkpoint_path:{}\n".format(i))
+        with open(checkpoint_record_last, "w") as handle:
+            for i in self.last_records:
+                handle.write("model_checkpoint_path:{}\n".format(i))
+
+    def load_parameters(self, model,
+                        optimizer=None,
+                        checkpoint_dir=None,
+                        checkpoint_path=None):
+        """Load a specific model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored at this path prefix and ignore 'checkpoint_dir'.
+                Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info saved with the checkpoint.
+        """
+        configs = {}
+
+        if checkpoint_path is not None:
+            tag = os.path.basename(checkpoint_path).split(":")[-1]
+        elif checkpoint_dir is not None:
+            iteration = self._load_latest_checkpoint(checkpoint_dir)
+            if iteration == -1:
+                return configs
+            checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
+        else:
+            raise ValueError(
+                "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+            )
+
+        rank = dist.get_rank()
+
+        params_path = checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        model.set_state_dict(model_dict)
+        logger.info("Rank {}: loaded model from {}".format(rank, params_path))

-def _load_latest_checkpoint(checkpoint_dir: str) -> int:
-    """Get the iteration number corresponding to the latest saved checkpoint.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-    Returns:
-        int: the latest iteration number. -1 for no checkpoint to load.
-    """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
-    if not os.path.isfile(checkpoint_record):
-        return -1
-
-    # Fetch the latest checkpoint index.
-    with open(checkpoint_record, "rt") as handle:
-        latest_checkpoint = handle.readlines()[-1].strip()
-        iteration = int(latest_checkpoint.split(":")[-1])
-    return iteration
-
-
-def _save_record(checkpoint_dir: str, iteration: int):
-    """Save the iteration number of the latest model to be checkpoint record.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration number.
-    Returns:
-        None
-    """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
-    # Update the latest checkpoint index.
-    with open(checkpoint_record, "a+") as handle:
-        handle.write("model_checkpoint_path:{}\n".format(iteration))
-
-
-def load_parameters(model,
-                    optimizer=None,
-                    checkpoint_dir=None,
-                    checkpoint_path=None):
-    """Load a specific model checkpoint from disk.
-    Args:
-        model (Layer): model to load parameters.
-        optimizer (Optimizer, optional): optimizer to load states if needed.
-            Defaults to None.
-        checkpoint_dir (str, optional): the directory where checkpoint is saved.
-        checkpoint_path (str, optional): if specified, load the checkpoint
-            stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
-            be ignored. Defaults to None.
-    Returns:
-        configs (dict): epoch or step, lr and other meta info should be saved.
-    """
-    configs = {}
-
-    if checkpoint_path is not None:
-        tag = os.path.basename(checkpoint_path).split(":")[-1]
-    elif checkpoint_dir is not None:
-        iteration = _load_latest_checkpoint(checkpoint_dir)
-        if iteration == -1:
-            return configs
-        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
-    else:
-        raise ValueError(
-            "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
-    )
-
-    rank = dist.get_rank()
-
-    params_path = checkpoint_path + ".pdparams"
-    model_dict = paddle.load(params_path)
-    model.set_state_dict(model_dict)
-    logger.info("Rank {}: loaded model from {}".format(rank, params_path))
-
-    optimizer_path = checkpoint_path + ".pdopt"
-    if optimizer and os.path.isfile(optimizer_path):
-        optimizer_dict = paddle.load(optimizer_path)
-        optimizer.set_state_dict(optimizer_dict)
-        logger.info("Rank {}: loaded optimizer state from {}".format(
-            rank, optimizer_path))
-
-    info_path = re.sub('.pdparams$', '.json', params_path)
-    if os.path.exists(info_path):
-        with open(info_path, 'r') as fin:
-            configs = json.load(fin)
-    return configs
-
-
-@mp_tools.rank_zero_only
-def save_parameters(checkpoint_dir: str,
-                    tag_or_iteration: Union[int, str],
-                    model: paddle.nn.Layer,
-                    optimizer: Optimizer=None,
-                    infos: dict=None):
-    """Checkpoint the latest trained model parameters.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        tag_or_iteration (int or str): the latest iteration(step or epoch) number.
-        model (Layer): model to be checkpointed.
-        optimizer (Optimizer, optional): optimizer to be checkpointed.
-            Defaults to None.
-        infos (dict or None): any info you want to save.
-    Returns:
-        None
-    """
-    checkpoint_path = os.path.join(checkpoint_dir,
-                                   "{}".format(tag_or_iteration))
-
-    model_dict = model.state_dict()
-    params_path = checkpoint_path + ".pdparams"
-    paddle.save(model_dict, params_path)
-    logger.info("Saved model to {}".format(params_path))
-
-    if optimizer:
-        opt_dict = optimizer.state_dict()
         optimizer_path = checkpoint_path + ".pdopt"
-        paddle.save(opt_dict, optimizer_path)
-        logger.info("Saved optimzier state to {}".format(optimizer_path))
-
-    info_path = re.sub('.pdparams$', '.json', params_path)
-    infos = {} if infos is None else infos
-    with open(info_path, 'w') as fout:
-        data = json.dumps(infos)
-        fout.write(data)
+        if optimizer and os.path.isfile(optimizer_path):
+            optimizer_dict = paddle.load(optimizer_path)
+            optimizer.set_state_dict(optimizer_dict)
+            logger.info("Rank {}: loaded optimizer state from {}".format(
+                rank, optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        if os.path.exists(info_path):
+            with open(info_path, 'r') as fin:
+                configs = json.load(fin)
+        return configs
+
+    @mp_tools.rank_zero_only
+    def save_parameters(self, checkpoint_dir: str,
+                        tag_or_iteration: Union[int, str],
+                        model: paddle.nn.Layer,
+                        optimizer: Optimizer=None,
+                        infos: dict=None):
+        """Checkpoint the latest trained model parameters.
+        Args:
+            checkpoint_dir (str): the directory where checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration(step or epoch) number.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+                Defaults to None.
+            infos (dict or None): any info you want to save.
+        Returns:
+            None
+        """
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+
+        model_dict = model.state_dict()
+        params_path = checkpoint_path + ".pdparams"
+        paddle.save(model_dict, params_path)
+        logger.info("Saved model to {}".format(params_path))
+
+        if optimizer:
+            opt_dict = optimizer.state_dict()
+            optimizer_path = checkpoint_path + ".pdopt"
+            paddle.save(opt_dict, optimizer_path)
+            logger.info("Saved optimizer state to {}".format(optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        infos = {} if infos is None else infos
+        with open(info_path, 'w') as fout:
+            data = json.dumps(infos)
+            fout.write(data)

-    if isinstance(tag_or_iteration, int):
-        _save_record(checkpoint_dir, tag_or_iteration)
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index 6737d1b75..9ff6803d8 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -43,12 +43,15 @@ model:
   share_rnn_weights: True

 training:
-  n_epoch: 24
+  n_epoch: 6
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
   global_grad_clip: 5.0
   log_interval: 1
+  max_epoch: 3     # number of best checkpoints (lowest val_loss) to keep
+  last_epoch: 2    # number of most recent checkpoints to keep
+

 decoding:
   batch_size: 128
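Note (illustration, not part of the patch): the retention policy that KBestCheckpoint implements is to keep the max_size checkpoints with the lowest val_loss plus the last_size most recent ones, and to delete a checkpoint only once it has fallen out of both sets. The sketch below mirrors only that bookkeeping, with no paddle or file-system calls, so the policy can be checked in isolation; the class name RetentionPolicySketch and the epoch/loss values are made up for the example.

class RetentionPolicySketch:
    def __init__(self, max_size=3, last_size=2):
        self.best = {}       # tag -> val_loss for the k best checkpoints
        self.last = []       # tags of the most recent checkpoints, oldest first
        self.max_size = max_size
        self.last_size = last_size
        self.deleted = []    # tags whose files would be removed from disk

    def add(self, tag, val_loss):
        # best-k bookkeeping: admit the new tag, evicting the worst if full
        if len(self.best) < self.max_size or val_loss < max(self.best.values()):
            if len(self.best) == self.max_size:
                worst = max(self.best, key=self.best.get)
                del self.best[worst]
                if worst not in self.last:
                    self.deleted.append(worst)
            self.best[tag] = val_loss
        # last-k bookkeeping: drop the oldest tag unless it is also a best one
        if len(self.last) == self.last_size:
            oldest = self.last.pop(0)
            if oldest not in self.best:
                self.deleted.append(oldest)
        self.last.append(tag)


if __name__ == "__main__":
    policy = RetentionPolicySketch(max_size=3, last_size=2)
    for epoch, loss in enumerate([5.0, 4.0, 4.5, 3.0, 6.0, 2.5]):
        policy.add(epoch, loss)
    print("best kept:", sorted(policy.best))   # [1, 3, 5]  lowest val_loss
    print("last kept:", policy.last)           # [4, 5]     most recent
    print("deleted  :", policy.deleted)        # [0, 2]     removed from disk

For this made-up loss sequence the sketch keeps epochs 1, 3 and 5 as the three best, epochs 4 and 5 as the two most recent, and deletes epochs 0 and 2, which is the set of files KBestCheckpoint(max_size=3, last_size=2) should leave on disk for the same val_loss values.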