From eadf64536cf4d158062316a96b8a507fbca14c9a Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Tue, 22 Jun 2021 07:46:50 +0000 Subject: [PATCH 1/7] save best and test on tiny/s0 --- deepspeech/training/trainer.py | 14 +- deepspeech/utils/checkpoint.py | 336 ++++++++++++++++--------- examples/tiny/s0/conf/deepspeech2.yaml | 5 +- 3 files changed, 230 insertions(+), 125 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 56de32617..246175e3f 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,7 +18,7 @@ import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter -from deepspeech.utils import checkpoint +from deepspeech.utils.checkpoint import KBestCheckpoint from deepspeech.utils import mp_tools from deepspeech.utils.log import Log @@ -139,9 +139,12 @@ class Trainer(): "epoch": self.epoch, "lr": self.optimizer.get_lr() }) - checkpoint.save_parameters(self.checkpoint_dir, self.iteration + self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration if tag is None else tag, self.model, self.optimizer, infos) + # checkpoint.save_parameters(self.checkpoint_dir, self.iteration + # if tag is None else tag, self.model, + # self.optimizer, infos) def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output @@ -151,7 +154,7 @@ class Trainer(): resume training. """ scratch = None - infos = checkpoint.load_parameters( + infos = self.checkpoint.load_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, @@ -180,7 +183,7 @@ class Trainer(): from_scratch = self.resume_or_scratch() if from_scratch: # save init model, i.e. 0 epoch - self.save(tag='init') + self.save(tag='init', infos=None) self.lr_scheduler.step(self.iteration) if self.parallel: @@ -263,6 +266,9 @@ class Trainer(): self.checkpoint_dir = checkpoint_dir + self.checkpoint = KBestCheckpoint(max_size=self.config.training.max_epoch, + last_size=self.config.training.last_epoch) + @mp_tools.rank_zero_only def destory(self): """Close visualizer to avoid hanging after training""" diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 8ede6b8fd..ef73eb705 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -23,130 +23,226 @@ from paddle.optimizer import Optimizer from deepspeech.utils import mp_tools from deepspeech.utils.log import Log +import glob + logger = Log(__name__).getlog() __all__ = ["load_parameters", "save_parameters"] +class KBestCheckpoint(object): + def __init__(self, + max_size: int=5, + last_size: int=1): + self.best_records: Mapping[Path, float] = {} + self.last_records = [] + self.max_size = max_size + self.last_size = last_size + self._save_all = (max_size == -1) + + def should_save_best(self, metric: float) -> bool: + if not self.best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def best_full(self): + return (not self._save_all) and len(self.best_records) == self.max_size + + def last_full(self): + return len(self.last_records) == self.last_size + + def add_checkpoint(self, + checkpoint_dir, tag_or_iteration, + model, optimizer, infos): + if("val_loss" not in infos.keys()): + self.save_parameters(checkpoint_dir, tag_or_iteration, + model, optimizer, infos) + return + + #save best + if self.should_save_best(infos["val_loss"]): + 
self.save_checkpoint_and_update(infos["val_loss"], + checkpoint_dir, tag_or_iteration, + model, optimizer, infos) + #save last + self.save_last_checkpoint_and_update(checkpoint_dir, tag_or_iteration, + model, optimizer, infos) + + if isinstance(tag_or_iteration, int): + self._save_record(checkpoint_dir, tag_or_iteration) + + def save_checkpoint_and_update(self, metric, + checkpoint_dir, tag_or_iteration, + model, optimizer, infos): + # remove the worst + if self.best_full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + if(worst_record_path not in self.last_records): + print('----to remove (best)----') + print(worst_record_path) + self.del_checkpoint(checkpoint_dir, worst_record_path) + + # add the new one + self.save_parameters(checkpoint_dir, tag_or_iteration, + model, optimizer, infos) + self.best_records[tag_or_iteration] = metric + + def save_last_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration, + model, optimizer, infos): + # remove the old + if self.last_full(): + to_del_fn = self.last_records.pop(0) + if(to_del_fn not in self.best_records.keys()): + print('----to remove (last)----') + print(to_del_fn) + self.del_checkpoint(checkpoint_dir, to_del_fn) + self.last_records.append(tag_or_iteration) + + self.save_parameters(checkpoint_dir, tag_or_iteration, + model, optimizer, infos) + # with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as handle: + # for iteration in self.best_records + # handle.write("model_checkpoint_path:{}\n".format(iteration)) + + + def del_checkpoint(self, checkpoint_dir, tag_or_iteration): + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path+".*"): + os.remove(filename) + print("delete file: "+filename) + + + + def _load_latest_checkpoint(self, checkpoint_dir: str) -> int: + """Get the iteration number corresponding to the latest saved checkpoint. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + Returns: + int: the latest iteration number. -1 for no checkpoint to load. + """ + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_last") + if not os.path.isfile(checkpoint_record): + return -1 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "rt") as handle: + latest_checkpoint = handle.readlines()[-1].strip() + iteration = int(latest_checkpoint.split(":")[-1]) + return iteration + + + def _save_record(self, checkpoint_dir: str, iteration: int): + """Save the iteration number of the latest model to be checkpoint record. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + Returns: + None + """ + checkpoint_record_last = os.path.join(checkpoint_dir, "checkpoint_last") + checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") + # Update the latest checkpoint index. + # with open(checkpoint_record, "a+") as handle: + # handle.write("model_checkpoint_path:{}\n".format(iteration)) + with open(checkpoint_record_best, "w") as handle: + for i in self.best_records.keys(): + handle.write("model_checkpoint_path:{}\n".format(i)) + with open(checkpoint_record_last, "w") as handle: + for i in self.last_records: + handle.write("model_checkpoint_path:{}\n".format(i)) + + + def load_parameters(self, model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a specific model checkpoint from disk. + Args: + model (Layer): model to load parameters. 
+ optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + configs = {} + + if checkpoint_path is not None: + tag = os.path.basename(checkpoint_path).split(":")[-1] + elif checkpoint_dir is not None: + iteration = self._load_latest_checkpoint(checkpoint_dir) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" + ) + + rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + logger.info("Rank {}: loaded model from {}".format(rank, params_path)) -def _load_latest_checkpoint(checkpoint_dir: str) -> int: - """Get the iteration number corresponding to the latest saved checkpoint. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - Returns: - int: the latest iteration number. -1 for no checkpoint to load. - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - if not os.path.isfile(checkpoint_record): - return -1 - - # Fetch the latest checkpoint index. - with open(checkpoint_record, "rt") as handle: - latest_checkpoint = handle.readlines()[-1].strip() - iteration = int(latest_checkpoint.split(":")[-1]) - return iteration - - -def _save_record(checkpoint_dir: str, iteration: int): - """Save the iteration number of the latest model to be checkpoint record. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - iteration (int): the latest iteration number. - Returns: - None - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - # Update the latest checkpoint index. - with open(checkpoint_record, "a+") as handle: - handle.write("model_checkpoint_path:{}\n".format(iteration)) - - -def load_parameters(model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): - """Load a specific model checkpoint from disk. - Args: - model (Layer): model to load parameters. - optimizer (Optimizer, optional): optimizer to load states if needed. - Defaults to None. - checkpoint_dir (str, optional): the directory where checkpoint is saved. - checkpoint_path (str, optional): if specified, load the checkpoint - stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will - be ignored. Defaults to None. - Returns: - configs (dict): epoch or step, lr and other meta info should be saved. - """ - configs = {} - - if checkpoint_path is not None: - tag = os.path.basename(checkpoint_path).split(":")[-1] - elif checkpoint_dir is not None: - iteration = _load_latest_checkpoint(checkpoint_dir) - if iteration == -1: - return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) - else: - raise ValueError( - "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" 
- ) - - rank = dist.get_rank() - - params_path = checkpoint_path + ".pdparams" - model_dict = paddle.load(params_path) - model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) - - optimizer_path = checkpoint_path + ".pdopt" - if optimizer and os.path.isfile(optimizer_path): - optimizer_dict = paddle.load(optimizer_path) - optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( - rank, optimizer_path)) - - info_path = re.sub('.pdparams$', '.json', params_path) - if os.path.exists(info_path): - with open(info_path, 'r') as fin: - configs = json.load(fin) - return configs - - -@mp_tools.rank_zero_only -def save_parameters(checkpoint_dir: str, - tag_or_iteration: Union[int, str], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None): - """Checkpoint the latest trained model parameters. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - tag_or_iteration (int or str): the latest iteration(step or epoch) number. - model (Layer): model to be checkpointed. - optimizer (Optimizer, optional): optimizer to be checkpointed. - Defaults to None. - infos (dict or None): any info you want to save. - Returns: - None - """ - checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) - - model_dict = model.state_dict() - params_path = checkpoint_path + ".pdparams" - paddle.save(model_dict, params_path) - logger.info("Saved model to {}".format(params_path)) - - if optimizer: - opt_dict = optimizer.state_dict() optimizer_path = checkpoint_path + ".pdopt" - paddle.save(opt_dict, optimizer_path) - logger.info("Saved optimzier state to {}".format(optimizer_path)) - - info_path = re.sub('.pdparams$', '.json', params_path) - infos = {} if infos is None else infos - with open(info_path, 'w') as fout: - data = json.dumps(infos) - fout.write(data) + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + logger.info("Rank {}: loaded optimizer state from {}".format( + rank, optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs + + + @mp_tools.rank_zero_only + def save_parameters(self, checkpoint_dir: str, + tag_or_iteration: Union[int, str], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): + """Checkpoint the latest trained model parameters. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + tag_or_iteration (int or str): the latest iteration(step or epoch) number. + model (Layer): model to be checkpointed. + optimizer (Optimizer, optional): optimizer to be checkpointed. + Defaults to None. + infos (dict or None): any info you want to save. 
+ Returns: + None + """ + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + + model_dict = model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + logger.info("Saved model to {}".format(params_path)) + + if optimizer: + opt_dict = optimizer.state_dict() + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + logger.info("Saved optimzier state to {}".format(optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + infos = {} if infos is None else infos + with open(info_path, 'w') as fout: + data = json.dumps(infos) + fout.write(data) - if isinstance(tag_or_iteration, int): - _save_record(checkpoint_dir, tag_or_iteration) diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 6737d1b75..9ff6803d8 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -43,12 +43,15 @@ model: share_rnn_weights: True training: - n_epoch: 24 + n_epoch: 6 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 1 + max_epoch: 3 + last_epoch: 2 + decoding: batch_size: 128 From f4401ac9a4f2480a4206882baa38d3a0ad43007a Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Tue, 22 Jun 2021 11:36:27 +0000 Subject: [PATCH 2/7] revise config --- deepspeech/training/trainer.py | 4 ++-- examples/tiny/s0/conf/deepspeech2.yaml | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 246175e3f..6563e7c4d 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -266,8 +266,8 @@ class Trainer(): self.checkpoint_dir = checkpoint_dir - self.checkpoint = KBestCheckpoint(max_size=self.config.training.max_epoch, - last_size=self.config.training.last_epoch) + self.checkpoint = KBestCheckpoint(max_size=self.config.training.checkpoint.kbest_n, + last_size=self.config.training.checkpoint.latest_n) @mp_tools.rank_zero_only def destory(self): diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 9ff6803d8..b9c2556c7 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -49,8 +49,9 @@ training: weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 1 - max_epoch: 3 - last_epoch: 2 + checkpoint: + kbest_n: 3 + latest_n: 2 decoding: From 6e2079ab7b189a15b293ea5973bebd42c84e1f92 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 25 Jun 2021 09:02:59 +0000 Subject: [PATCH 3/7] multi gpus --- deepspeech/training/trainer.py | 18 ++-- deepspeech/utils/checkpoint.py | 144 ++++++++++++++++--------- examples/tiny/s0/conf/deepspeech2.yaml | 2 +- 3 files changed, 105 insertions(+), 59 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 6563e7c4d..7f68e67cb 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,8 +18,8 @@ import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter -from deepspeech.utils.checkpoint import KBestCheckpoint from deepspeech.utils import mp_tools +from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log __all__ = ["Trainer"] @@ -64,7 +64,7 @@ class Trainer(): The parsed command line arguments. 
Examples -------- - >>> def main_sp(config, args): + >>> def p(config, args): >>> exp = Trainer(config, args) >>> exp.setup() >>> exp.run() @@ -140,11 +140,8 @@ class Trainer(): "lr": self.optimizer.get_lr() }) self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration - if tag is None else tag, self.model, - self.optimizer, infos) - # checkpoint.save_parameters(self.checkpoint_dir, self.iteration - # if tag is None else tag, self.model, - # self.optimizer, infos) + if tag is None else tag, self.model, + self.optimizer, infos) def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output @@ -154,7 +151,7 @@ class Trainer(): resume training. """ scratch = None - infos = self.checkpoint.load_parameters( + infos = self.checkpoint.load_last_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, @@ -266,8 +263,9 @@ class Trainer(): self.checkpoint_dir = checkpoint_dir - self.checkpoint = KBestCheckpoint(max_size=self.config.training.checkpoint.kbest_n, - last_size=self.config.training.checkpoint.latest_n) + self.checkpoint = Checkpoint( + kbest_n=self.config.training.checkpoint.kbest_n, + latest_n=self.config.training.checkpoint.latest_n) @mp_tools.rank_zero_only def destory(self): diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index ef73eb705..52eccb673 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -24,20 +24,22 @@ from deepspeech.utils import mp_tools from deepspeech.utils.log import Log import glob +# import operator +from pathlib import Path logger = Log(__name__).getlog() -__all__ = ["load_parameters", "save_parameters"] +__all__ = ["Checkpoint"] -class KBestCheckpoint(object): +class Checkpoint(object): def __init__(self, - max_size: int=5, - last_size: int=1): + kbest_n: int=5, + latest_n: int=1): self.best_records: Mapping[Path, float] = {} - self.last_records = [] - self.max_size = max_size - self.last_size = last_size - self._save_all = (max_size == -1) + self.latest_records = [] + self.kbest_n = kbest_n + self.latest_n = latest_n + self._save_all = (kbest_n == -1) def should_save_best(self, metric: float) -> bool: if not self.best_full(): @@ -45,36 +47,36 @@ class KBestCheckpoint(object): # already full worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] worst_metric = self.best_records[worst_record_path] return metric < worst_metric def best_full(self): - return (not self._save_all) and len(self.best_records) == self.max_size + return (not self._save_all) and len(self.best_records) == self.kbest_n - def last_full(self): - return len(self.last_records) == self.last_size + def latest_full(self): + return len(self.latest_records) == self.latest_n - def add_checkpoint(self, - checkpoint_dir, tag_or_iteration, - model, optimizer, infos): - if("val_loss" not in infos.keys()): + def add_checkpoint(self, checkpoint_dir, tag_or_iteration, + model, optimizer, infos, metric_type = "val_loss"): + if(metric_type not in infos.keys()): self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, infos) return #save best - if self.should_save_best(infos["val_loss"]): - self.save_checkpoint_and_update(infos["val_loss"], + if self.should_save_best(infos[metric_type]): + self.save_best_checkpoint_and_update(infos[metric_type], checkpoint_dir, tag_or_iteration, model, optimizer, infos) - #save last - self.save_last_checkpoint_and_update(checkpoint_dir, 
tag_or_iteration, + #save latest + self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, model, optimizer, infos) if isinstance(tag_or_iteration, int): - self._save_record(checkpoint_dir, tag_or_iteration) + self.save_checkpoint_record(checkpoint_dir, tag_or_iteration) - def save_checkpoint_and_update(self, metric, + def save_best_checkpoint_and_update(self, metric, checkpoint_dir, tag_or_iteration, model, optimizer, infos): # remove the worst @@ -82,9 +84,8 @@ class KBestCheckpoint(object): worst_record_path = max(self.best_records, key=self.best_records.get) self.best_records.pop(worst_record_path) - if(worst_record_path not in self.last_records): - print('----to remove (best)----') - print(worst_record_path) + if(worst_record_path not in self.latest_records): + logger.info("remove the worst checkpoint: {}".format(worst_record_path)) self.del_checkpoint(checkpoint_dir, worst_record_path) # add the new one @@ -92,22 +93,18 @@ class KBestCheckpoint(object): model, optimizer, infos) self.best_records[tag_or_iteration] = metric - def save_last_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration, + def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): # remove the old - if self.last_full(): - to_del_fn = self.last_records.pop(0) + if self.latest_full(): + to_del_fn = self.latest_records.pop(0) if(to_del_fn not in self.best_records.keys()): - print('----to remove (last)----') - print(to_del_fn) + logger.info("remove the latest checkpoint: {}".format(to_del_fn)) self.del_checkpoint(checkpoint_dir, to_del_fn) - self.last_records.append(tag_or_iteration) + self.latest_records.append(tag_or_iteration) self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, infos) - # with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as handle: - # for iteration in self.best_records - # handle.write("model_checkpoint_path:{}\n".format(iteration)) def del_checkpoint(self, checkpoint_dir, tag_or_iteration): @@ -115,18 +112,17 @@ class KBestCheckpoint(object): "{}".format(tag_or_iteration)) for filename in glob.glob(checkpoint_path+".*"): os.remove(filename) - print("delete file: "+filename) + logger.info("delete file: {}".format(filename)) - def _load_latest_checkpoint(self, checkpoint_dir: str) -> int: + def load_checkpoint_idx(self, checkpoint_record: str) -> int: """Get the iteration number corresponding to the latest saved checkpoint. Args: - checkpoint_dir (str): the directory where checkpoint is saved. + checkpoint_path (str): the saved path of checkpoint. Returns: int: the latest iteration number. -1 for no checkpoint to load. """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_last") if not os.path.isfile(checkpoint_record): return -1 @@ -135,9 +131,9 @@ class KBestCheckpoint(object): latest_checkpoint = handle.readlines()[-1].strip() iteration = int(latest_checkpoint.split(":")[-1]) return iteration + - - def _save_record(self, checkpoint_dir: str, iteration: int): + def save_checkpoint_record(self, checkpoint_dir: str, iteration: int): """Save the iteration number of the latest model to be checkpoint record. Args: checkpoint_dir (str): the directory where checkpoint is saved. 
@@ -145,24 +141,22 @@ class KBestCheckpoint(object): Returns: None """ - checkpoint_record_last = os.path.join(checkpoint_dir, "checkpoint_last") + checkpoint_record_latest = os.path.join(checkpoint_dir, "checkpoint_latest") checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") - # Update the latest checkpoint index. - # with open(checkpoint_record, "a+") as handle: - # handle.write("model_checkpoint_path:{}\n".format(iteration)) + with open(checkpoint_record_best, "w") as handle: for i in self.best_records.keys(): handle.write("model_checkpoint_path:{}\n".format(i)) - with open(checkpoint_record_last, "w") as handle: - for i in self.last_records: + with open(checkpoint_record_latest, "w") as handle: + for i in self.latest_records: handle.write("model_checkpoint_path:{}\n".format(i)) - def load_parameters(self, model, + def load_last_parameters(self, model, optimizer=None, checkpoint_dir=None, checkpoint_path=None): - """Load a specific model checkpoint from disk. + """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. optimizer (Optimizer, optional): optimizer to load states if needed. @@ -179,7 +173,8 @@ class KBestCheckpoint(object): if checkpoint_path is not None: tag = os.path.basename(checkpoint_path).split(":")[-1] elif checkpoint_dir is not None: - iteration = self._load_latest_checkpoint(checkpoint_dir) + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_latest") + iteration = self.load_checkpoint_idx(checkpoint_record) if iteration == -1: return configs checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) @@ -209,6 +204,59 @@ class KBestCheckpoint(object): return configs + def load_best_parameters(self, model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + configs = {} + + if checkpoint_path is not None: + tag = os.path.basename(checkpoint_path).split(":")[-1] + elif checkpoint_dir is not None: + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_best") + iteration = self.load_checkpoint_idx(checkpoint_record) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" 
+ ) + + rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + logger.info("Rank {}: loaded model from {}".format(rank, params_path)) + + optimizer_path = checkpoint_path + ".pdopt" + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + logger.info("Rank {}: loaded optimizer state from {}".format( + rank, optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs + + + @mp_tools.rank_zero_only def save_parameters(self, checkpoint_dir: str, tag_or_iteration: Union[int, str], diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index b9c2556c7..ea433f341 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -43,7 +43,7 @@ model: share_rnn_weights: True training: - n_epoch: 6 + n_epoch: 10 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 From c3786fab0e22e1d2be1b50217901dcd9881c74f0 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 25 Jun 2021 09:08:30 +0000 Subject: [PATCH 4/7] fix bug --- deepspeech/training/trainer.py | 2 +- deepspeech/utils/checkpoint.py | 121 +++++++++++++++++---------------- 2 files changed, 63 insertions(+), 60 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 7f68e67cb..f8668370a 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -64,7 +64,7 @@ class Trainer(): The parsed command line arguments. Examples -------- - >>> def p(config, args): + >>> def main_sp(config, args): >>> exp = Trainer(config, args) >>> exp.setup() >>> exp.run() diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 52eccb673..b29ef2ab5 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import glob import json import os import re +from pathlib import Path from typing import Union import paddle @@ -22,25 +24,21 @@ from paddle.optimizer import Optimizer from deepspeech.utils import mp_tools from deepspeech.utils.log import Log - -import glob # import operator -from pathlib import Path logger = Log(__name__).getlog() __all__ = ["Checkpoint"] + class Checkpoint(object): - def __init__(self, - kbest_n: int=5, - latest_n: int=1): + def __init__(self, kbest_n: int=5, latest_n: int=1): self.best_records: Mapping[Path, float] = {} self.latest_records = [] self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - + def should_save_best(self, metric: float) -> bool: if not self.best_full(): return True @@ -53,68 +51,72 @@ class Checkpoint(object): def best_full(self): return (not self._save_all) and len(self.best_records) == self.kbest_n - + def latest_full(self): return len(self.latest_records) == self.latest_n - def add_checkpoint(self, checkpoint_dir, tag_or_iteration, - model, optimizer, infos, metric_type = "val_loss"): - if(metric_type not in infos.keys()): - self.save_parameters(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + def add_checkpoint(self, + checkpoint_dir, + tag_or_iteration, + model, + optimizer, + infos, + metric_type="val_loss"): + if (metric_type not in infos.keys()): + self.save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) return #save best if self.should_save_best(infos[metric_type]): - self.save_best_checkpoint_and_update(infos[metric_type], - checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + self.save_best_checkpoint_and_update( + infos[metric_type], checkpoint_dir, tag_or_iteration, model, + optimizer, infos) #save latest self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) - + model, optimizer, infos) + if isinstance(tag_or_iteration, int): self.save_checkpoint_record(checkpoint_dir, tag_or_iteration) - - def save_best_checkpoint_and_update(self, metric, - checkpoint_dir, tag_or_iteration, - model, optimizer, infos): + + def save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): # remove the worst if self.best_full(): worst_record_path = max(self.best_records, key=self.best_records.get) self.best_records.pop(worst_record_path) - if(worst_record_path not in self.latest_records): - logger.info("remove the worst checkpoint: {}".format(worst_record_path)) + if (worst_record_path not in self.latest_records): + logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) self.del_checkpoint(checkpoint_dir, worst_record_path) # add the new one - self.save_parameters(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, + infos) self.best_records[tag_or_iteration] = metric - - def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration, - model, optimizer, infos): + + def save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): # remove the old if self.latest_full(): to_del_fn = self.latest_records.pop(0) - if(to_del_fn not in self.best_records.keys()): - logger.info("remove the latest checkpoint: {}".format(to_del_fn)) + if (to_del_fn not in self.best_records.keys()): + logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) self.del_checkpoint(checkpoint_dir, to_del_fn) 
self.latest_records.append(tag_or_iteration) - self.save_parameters(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) - + self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, + infos) def del_checkpoint(self, checkpoint_dir, tag_or_iteration): checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) - for filename in glob.glob(checkpoint_path+".*"): + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): os.remove(filename) logger.info("delete file: {}".format(filename)) - - def load_checkpoint_idx(self, checkpoint_record: str) -> int: """Get the iteration number corresponding to the latest saved checkpoint. @@ -131,7 +133,6 @@ class Checkpoint(object): latest_checkpoint = handle.readlines()[-1].strip() iteration = int(latest_checkpoint.split(":")[-1]) return iteration - def save_checkpoint_record(self, checkpoint_dir: str, iteration: int): """Save the iteration number of the latest model to be checkpoint record. @@ -141,9 +142,10 @@ class Checkpoint(object): Returns: None """ - checkpoint_record_latest = os.path.join(checkpoint_dir, "checkpoint_latest") + checkpoint_record_latest = os.path.join(checkpoint_dir, + "checkpoint_latest") checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") - + with open(checkpoint_record_best, "w") as handle: for i in self.best_records.keys(): handle.write("model_checkpoint_path:{}\n".format(i)) @@ -151,11 +153,11 @@ class Checkpoint(object): for i in self.latest_records: handle.write("model_checkpoint_path:{}\n".format(i)) - - def load_last_parameters(self, model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + def load_last_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -173,11 +175,13 @@ class Checkpoint(object): if checkpoint_path is not None: tag = os.path.basename(checkpoint_path).split(":")[-1] elif checkpoint_dir is not None: - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_latest") + checkpoint_record = os.path.join(checkpoint_dir, + "checkpoint_latest") iteration = self.load_checkpoint_idx(checkpoint_record) if iteration == -1: return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) else: raise ValueError( "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" @@ -203,11 +207,11 @@ class Checkpoint(object): configs = json.load(fin) return configs - - def load_best_parameters(self, model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -229,7 +233,8 @@ class Checkpoint(object): iteration = self.load_checkpoint_idx(checkpoint_record) if iteration == -1: return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) else: raise ValueError( "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" 
@@ -255,10 +260,9 @@ class Checkpoint(object): configs = json.load(fin) return configs - - @mp_tools.rank_zero_only - def save_parameters(self, checkpoint_dir: str, + def save_parameters(self, + checkpoint_dir: str, tag_or_iteration: Union[int, str], model: paddle.nn.Layer, optimizer: Optimizer=None, @@ -275,7 +279,7 @@ class Checkpoint(object): None """ checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) + "{}".format(tag_or_iteration)) model_dict = model.state_dict() params_path = checkpoint_path + ".pdparams" @@ -293,4 +297,3 @@ class Checkpoint(object): with open(info_path, 'w') as fout: data = json.dumps(infos) fout.write(data) - From b41f70ddd2d9c3fc245b42569539ed73ff911fa4 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Tue, 29 Jun 2021 06:05:26 +0000 Subject: [PATCH 5/7] optimize the function --- deepspeech/training/trainer.py | 5 +- deepspeech/utils/checkpoint.py | 109 +++++++++------------------------ 2 files changed, 32 insertions(+), 82 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index f8668370a..cd915760d 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -151,11 +151,12 @@ class Trainer(): resume training. """ scratch = None - infos = self.checkpoint.load_last_parameters( + infos = self.checkpoint._load_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, - checkpoint_path=self.args.checkpoint_path) + checkpoint_path=self.args.checkpoint_path, + checkpoint_file='checkpoint_latest') if infos: # restore from ckpt self.iteration = infos["step"] diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index b29ef2ab5..be36fdbb2 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -39,8 +39,8 @@ class Checkpoint(object): self.latest_n = latest_n self._save_all = (kbest_n == -1) - def should_save_best(self, metric: float) -> bool: - if not self.best_full(): + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): return True # already full @@ -49,10 +49,10 @@ class Checkpoint(object): worst_metric = self.best_records[worst_record_path] return metric < worst_metric - def best_full(self): + def _best_full(self): return (not self._save_all) and len(self.best_records) == self.kbest_n - def latest_full(self): + def _latest_full(self): return len(self.latest_records) == self.latest_n def add_checkpoint(self, @@ -63,62 +63,62 @@ class Checkpoint(object): infos, metric_type="val_loss"): if (metric_type not in infos.keys()): - self.save_parameters(checkpoint_dir, tag_or_iteration, model, + self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, infos) return #save best - if self.should_save_best(infos[metric_type]): - self.save_best_checkpoint_and_update( + if self._should_save_best(infos[metric_type]): + self._save_best_checkpoint_and_update( infos[metric_type], checkpoint_dir, tag_or_iteration, model, optimizer, infos) #save latest - self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, + self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, model, optimizer, infos) if isinstance(tag_or_iteration, int): - self.save_checkpoint_record(checkpoint_dir, tag_or_iteration) + self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) - def save_best_checkpoint_and_update(self, metric, checkpoint_dir, + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, tag_or_iteration, model, optimizer, infos): # 
remove the worst - if self.best_full(): + if self._best_full(): worst_record_path = max(self.best_records, key=self.best_records.get) self.best_records.pop(worst_record_path) if (worst_record_path not in self.latest_records): logger.info( "remove the worst checkpoint: {}".format(worst_record_path)) - self.del_checkpoint(checkpoint_dir, worst_record_path) + self._del_checkpoint(checkpoint_dir, worst_record_path) # add the new one - self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, + self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, infos) self.best_records[tag_or_iteration] = metric - def save_latest_checkpoint_and_update( + def _save_latest_checkpoint_and_update( self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): # remove the old - if self.latest_full(): + if self._latest_full(): to_del_fn = self.latest_records.pop(0) if (to_del_fn not in self.best_records.keys()): logger.info( "remove the latest checkpoint: {}".format(to_del_fn)) - self.del_checkpoint(checkpoint_dir, to_del_fn) + self._del_checkpoint(checkpoint_dir, to_del_fn) self.latest_records.append(tag_or_iteration) - self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, + self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, infos) - def del_checkpoint(self, checkpoint_dir, tag_or_iteration): + def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): checkpoint_path = os.path.join(checkpoint_dir, "{}".format(tag_or_iteration)) for filename in glob.glob(checkpoint_path + ".*"): os.remove(filename) logger.info("delete file: {}".format(filename)) - def load_checkpoint_idx(self, checkpoint_record: str) -> int: + def _load_checkpoint_idx(self, checkpoint_record: str) -> int: """Get the iteration number corresponding to the latest saved checkpoint. Args: checkpoint_path (str): the saved path of checkpoint. @@ -134,7 +134,7 @@ class Checkpoint(object): iteration = int(latest_checkpoint.split(":")[-1]) return iteration - def save_checkpoint_record(self, checkpoint_dir: str, iteration: int): + def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int): """Save the iteration number of the latest model to be checkpoint record. Args: checkpoint_dir (str): the directory where checkpoint is saved. @@ -153,65 +153,13 @@ class Checkpoint(object): for i in self.latest_records: handle.write("model_checkpoint_path:{}\n".format(i)) - def load_last_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): - """Load a last model checkpoint from disk. - Args: - model (Layer): model to load parameters. - optimizer (Optimizer, optional): optimizer to load states if needed. - Defaults to None. - checkpoint_dir (str, optional): the directory where checkpoint is saved. - checkpoint_path (str, optional): if specified, load the checkpoint - stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will - be ignored. Defaults to None. - Returns: - configs (dict): epoch or step, lr and other meta info should be saved. 
- """ - configs = {} - - if checkpoint_path is not None: - tag = os.path.basename(checkpoint_path).split(":")[-1] - elif checkpoint_dir is not None: - checkpoint_record = os.path.join(checkpoint_dir, - "checkpoint_latest") - iteration = self.load_checkpoint_idx(checkpoint_record) - if iteration == -1: - return configs - checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(iteration)) - else: - raise ValueError( - "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" - ) - - rank = dist.get_rank() - - params_path = checkpoint_path + ".pdparams" - model_dict = paddle.load(params_path) - model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) - - optimizer_path = checkpoint_path + ".pdopt" - if optimizer and os.path.isfile(optimizer_path): - optimizer_dict = paddle.load(optimizer_path) - optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( - rank, optimizer_path)) - - info_path = re.sub('.pdparams$', '.json', params_path) - if os.path.exists(info_path): - with open(info_path, 'r') as fin: - configs = json.load(fin) - return configs - def load_best_parameters(self, + def _load_parameters(self, model, optimizer=None, checkpoint_dir=None, - checkpoint_path=None): + checkpoint_path=None, + checkpoint_file=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -221,6 +169,7 @@ class Checkpoint(object): checkpoint_path (str, optional): if specified, load the checkpoint stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will be ignored. Defaults to None. + checkpoint_file "checkpoint_latest" or "checkpoint_best" Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ @@ -228,16 +177,16 @@ class Checkpoint(object): if checkpoint_path is not None: tag = os.path.basename(checkpoint_path).split(":")[-1] - elif checkpoint_dir is not None: - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_best") - iteration = self.load_checkpoint_idx(checkpoint_record) + elif checkpoint_dir is not None and checkpoint_file is not None: + checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file) + iteration = self._load_checkpoint_idx(checkpoint_record) if iteration == -1: return configs checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) else: raise ValueError( - "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" + "At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!" ) rank = dist.get_rank() @@ -261,7 +210,7 @@ class Checkpoint(object): return configs @mp_tools.rank_zero_only - def save_parameters(self, + def _save_parameters(self, checkpoint_dir: str, tag_or_iteration: Union[int, str], model: paddle.nn.Layer, From 08118a0349a23fcf071c3cec0a422c090be0179c Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Wed, 30 Jun 2021 03:00:18 +0000 Subject: [PATCH 6/7] fix private function --- deepspeech/training/trainer.py | 5 +- deepspeech/utils/checkpoint.py | 114 ++++++++++++++++++++++----------- 2 files changed, 79 insertions(+), 40 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index cd915760d..5ebba1a98 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -151,12 +151,11 @@ class Trainer(): resume training. 
""" scratch = None - infos = self.checkpoint._load_parameters( + infos = self.checkpoint.load_latest_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, - checkpoint_path=self.args.checkpoint_path, - checkpoint_file='checkpoint_latest') + checkpoint_path=self.args.checkpoint_path) if infos: # restore from ckpt self.iteration = infos["step"] diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index be36fdbb2..000fa87ba 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -38,23 +38,7 @@ class Checkpoint(object): self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - - def _should_save_best(self, metric: float) -> bool: - if not self._best_full(): - return True - - # already full - worst_record_path = max(self.best_records, key=self.best_records.get) - # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] - worst_metric = self.best_records[worst_record_path] - return metric < worst_metric - - def _best_full(self): - return (not self._save_all) and len(self.best_records) == self.kbest_n - - def _latest_full(self): - return len(self.latest_records) == self.latest_n - + def add_checkpoint(self, checkpoint_dir, tag_or_iteration, @@ -64,7 +48,7 @@ class Checkpoint(object): metric_type="val_loss"): if (metric_type not in infos.keys()): self._save_parameters(checkpoint_dir, tag_or_iteration, model, - optimizer, infos) + optimizer, infos) return #save best @@ -73,15 +57,71 @@ class Checkpoint(object): infos[metric_type], checkpoint_dir, tag_or_iteration, model, optimizer, infos) #save latest - self._save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + self._save_latest_checkpoint_and_update( + checkpoint_dir, tag_or_iteration, model, optimizer, infos) if isinstance(tag_or_iteration, int): self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) + def load_latest_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, + "checkpoint_latest") + + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. 
+ """ + return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, + "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, - tag_or_iteration, model, optimizer, - infos): + tag_or_iteration, model, optimizer, + infos): # remove the worst if self._best_full(): worst_record_path = max(self.best_records, @@ -93,8 +133,8 @@ class Checkpoint(object): self._del_checkpoint(checkpoint_dir, worst_record_path) # add the new one - self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, - infos) + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) self.best_records[tag_or_iteration] = metric def _save_latest_checkpoint_and_update( @@ -108,8 +148,8 @@ class Checkpoint(object): self._del_checkpoint(checkpoint_dir, to_del_fn) self.latest_records.append(tag_or_iteration) - self._save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, - infos) + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): checkpoint_path = os.path.join(checkpoint_dir, @@ -153,13 +193,12 @@ class Checkpoint(object): for i in self.latest_records: handle.write("model_checkpoint_path:{}\n".format(i)) - def _load_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None, - checkpoint_file=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None, + checkpoint_file=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -209,13 +248,14 @@ class Checkpoint(object): configs = json.load(fin) return configs + @mp_tools.rank_zero_only def _save_parameters(self, - checkpoint_dir: str, - tag_or_iteration: Union[int, str], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None): + checkpoint_dir: str, + tag_or_iteration: Union[int, str], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): """Checkpoint the latest trained model parameters. Args: checkpoint_dir (str): the directory where checkpoint is saved. 
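At this point in the series the retention policy is complete; PATCH 7/7 below only adds the training.checkpoint block to the remaining example configs and reformats checkpoint.py. The sketch that follows restates that policy in isolation, as plain Python with no Paddle and no file I/O: the tags and metrics are made-up placeholders, and the "kept" set stands in for the checkpoint files the real class writes and deletes. It keeps the kbest_n tags with the lowest val_loss plus the latest_n most recent tags, and a tag is only dropped once neither list still references it, mirroring _save_best_checkpoint_and_update and _save_latest_checkpoint_and_update above.

    class RetentionPolicy:
        """Keep the kbest_n best tags (lowest metric) plus the latest_n newest tags."""

        def __init__(self, kbest_n=5, latest_n=1):
            self.kbest_n = kbest_n
            self.latest_n = latest_n
            self.best_records = {}    # tag -> metric, lower is better
            self.latest_records = []  # tags in arrival order
            self.kept = set()         # stands in for checkpoint files on disk

        def add(self, tag, metric):
            self.kept.add(tag)
            # best list: evict the current worst entry if the new metric beats it
            if len(self.best_records) < self.kbest_n or metric < max(self.best_records.values()):
                if len(self.best_records) == self.kbest_n:
                    worst = max(self.best_records, key=self.best_records.get)
                    del self.best_records[worst]
                    if worst not in self.latest_records:
                        self.kept.discard(worst)  # _del_checkpoint in the real class
                self.best_records[tag] = metric
            # latest list: evict the oldest entry once latest_n tags are held
            if len(self.latest_records) == self.latest_n:
                oldest = self.latest_records.pop(0)
                if oldest not in self.best_records:
                    self.kept.discard(oldest)
            self.latest_records.append(tag)


    policy = RetentionPolicy(kbest_n=2, latest_n=1)
    for epoch, val_loss in enumerate([2.0, 1.5, 1.8, 1.2], start=1):
        policy.add(epoch, val_loss)
    print(sorted(policy.kept))  # [2, 4]: the two best epochs; epoch 4 is also the latest
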
From 7850bcd423dd66f2aa79c8e361cf0187f38ed75f Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Wed, 30 Jun 2021 03:10:34 +0000 Subject: [PATCH 7/7] revise conf/*.yaml --- deepspeech/utils/checkpoint.py | 28 +++++++++---------- examples/aishell/s0/conf/deepspeech2.yaml | 3 ++ examples/aishell/s1/conf/chunk_conformer.yaml | 3 ++ examples/aishell/s1/conf/conformer.yaml | 3 ++ examples/librispeech/s0/conf/deepspeech2.yaml | 3 ++ .../librispeech/s1/conf/chunk_confermer.yaml | 3 ++ .../s1/conf/chunk_transformer.yaml | 3 ++ examples/librispeech/s1/conf/conformer.yaml | 3 ++ examples/librispeech/s1/conf/transformer.yaml | 3 ++ examples/tiny/s1/conf/chunk_confermer.yaml | 3 ++ examples/tiny/s1/conf/chunk_transformer.yaml | 3 ++ examples/tiny/s1/conf/conformer.yaml | 3 ++ examples/tiny/s1/conf/transformer.yaml | 3 ++ 13 files changed, 49 insertions(+), 15 deletions(-) diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 000fa87ba..8c5d8d605 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -24,7 +24,6 @@ from paddle.optimizer import Optimizer from deepspeech.utils import mp_tools from deepspeech.utils.log import Log -# import operator logger = Log(__name__).getlog() @@ -38,7 +37,7 @@ class Checkpoint(object): self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - + def add_checkpoint(self, checkpoint_dir, tag_or_iteration, @@ -64,10 +63,10 @@ class Checkpoint(object): self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) def load_latest_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -80,14 +79,14 @@ class Checkpoint(object): Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ - return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, - "checkpoint_latest") + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") def load_best_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -100,8 +99,8 @@ class Checkpoint(object): Returns: configs (dict): epoch or step, lr and other meta info should be saved. 
""" - return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, - "checkpoint_best") + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") def _should_save_best(self, metric: float) -> bool: if not self._best_full(): @@ -248,7 +247,6 @@ class Checkpoint(object): configs = json.load(fin) return configs - @mp_tools.rank_zero_only def _save_parameters(self, checkpoint_dir: str, diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 54ce240e7..27ede01bc 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -48,6 +48,9 @@ training: weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 904624c3c..1065dcb03 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -90,6 +90,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 116c91927..4b1430c58 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -88,6 +88,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index d1746bff3..9f06a3802 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -43,6 +43,9 @@ training: weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ec945a188..979121639 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 3939ffc68..dc2a51f92 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 8f8bf4539..989af22a0 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index a094b0fba..931d7524b 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -82,6 +82,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git 
a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 790066264..606300bdf 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index aa2b145a6..72d368485 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 3813daa04..a6f730501 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 250995faa..71cbdde7f 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding:
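Taken together, the series replaces the module-level save_parameters / load_parameters helpers with a Checkpoint object configured from the new training.checkpoint block (kbest_n / latest_n) in each example yaml. The usage sketch below shows how that final API is driven, roughly as Trainer.save() and Trainer.resume_or_scratch() do; it assumes the patched deepspeech.utils.checkpoint module and Paddle are importable, and the Linear model, Adam optimizer, directory path and loss values are placeholders for illustration, not the real DeepSpeech2 setup.

    import os

    import paddle

    from deepspeech.utils.checkpoint import Checkpoint

    model = paddle.nn.Linear(8, 8)  # placeholder for the ASR model
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-5, parameters=model.parameters())

    checkpoint_dir = "exp/checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # kbest_n / latest_n mirror training.checkpoint.kbest_n / latest_n in the yaml
    ckpt = Checkpoint(kbest_n=3, latest_n=2)

    for epoch, val_loss in enumerate([2.0, 1.5, 1.8, 1.2], start=1):
        infos = {"step": epoch, "epoch": epoch,
                 "lr": optimizer.get_lr(), "val_loss": val_loss}
        # keeps the 3 best checkpoints by val_loss plus the 2 most recent;
        # integer tags also rewrite the checkpoint_best / checkpoint_latest records
        ckpt.add_checkpoint(checkpoint_dir, epoch, model, optimizer, infos)

    # resume from the newest checkpoint, as Trainer.resume_or_scratch() does
    infos = ckpt.load_latest_parameters(
        model, optimizer, checkpoint_dir=checkpoint_dir)

    # for evaluation or export, restore the best checkpoint instead
    ckpt.load_best_parameters(model, checkpoint_dir=checkpoint_dir)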