diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 7f68e67c..f8668370 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -64,7 +64,7 @@ class Trainer(): The parsed command line arguments. Examples -------- - >>> def p(config, args): + >>> def main_sp(config, args): >>> exp = Trainer(config, args) >>> exp.setup() >>> exp.run() diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 52eccb67..b29ef2ab 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob import json import os import re +from pathlib import Path from typing import Union import paddle @@ -22,25 +24,21 @@ from paddle.optimizer import Optimizer from deepspeech.utils import mp_tools from deepspeech.utils.log import Log - -import glob # import operator -from pathlib import Path logger = Log(__name__).getlog() __all__ = ["Checkpoint"] + class Checkpoint(object): - def __init__(self, - kbest_n: int=5, - latest_n: int=1): + def __init__(self, kbest_n: int=5, latest_n: int=1): self.best_records: Mapping[Path, float] = {} self.latest_records = [] self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - + def should_save_best(self, metric: float) -> bool: if not self.best_full(): return True @@ -53,68 +51,72 @@ class Checkpoint(object): def best_full(self): return (not self._save_all) and len(self.best_records) == self.kbest_n - + def latest_full(self): return len(self.latest_records) == self.latest_n - def add_checkpoint(self, checkpoint_dir, tag_or_iteration, - model, optimizer, infos, metric_type = "val_loss"): - if(metric_type not in infos.keys()): - self.save_parameters(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + def add_checkpoint(self, + checkpoint_dir, + tag_or_iteration, + model, + optimizer, + infos, + metric_type="val_loss"): + if (metric_type not in infos.keys()): + self.save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) return #save best if self.should_save_best(infos[metric_type]): - self.save_best_checkpoint_and_update(infos[metric_type], - checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + self.save_best_checkpoint_and_update( + infos[metric_type], checkpoint_dir, tag_or_iteration, model, + optimizer, infos) #save latest self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) - + model, optimizer, infos) + if isinstance(tag_or_iteration, int): self.save_checkpoint_record(checkpoint_dir, tag_or_iteration) - - def save_best_checkpoint_and_update(self, metric, - checkpoint_dir, tag_or_iteration, - model, optimizer, infos): + + def save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): # remove the worst if self.best_full(): worst_record_path = max(self.best_records, key=self.best_records.get) self.best_records.pop(worst_record_path) - if(worst_record_path not in self.latest_records): - logger.info("remove the worst checkpoint: {}".format(worst_record_path)) + if (worst_record_path not in self.latest_records): + logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) self.del_checkpoint(checkpoint_dir, worst_record_path) # add the new one - self.save_parameters(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) + self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, + infos) self.best_records[tag_or_iteration] = metric - - def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration, - model, optimizer, infos): + + def save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): # remove the old if self.latest_full(): to_del_fn = self.latest_records.pop(0) - if(to_del_fn not in self.best_records.keys()): - logger.info("remove the latest checkpoint: {}".format(to_del_fn)) + if (to_del_fn not in self.best_records.keys()): + logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) self.del_checkpoint(checkpoint_dir, to_del_fn) self.latest_records.append(tag_or_iteration) - self.save_parameters(checkpoint_dir, tag_or_iteration, - model, optimizer, infos) - + self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer, + infos) def del_checkpoint(self, checkpoint_dir, tag_or_iteration): checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) - for filename in glob.glob(checkpoint_path+".*"): + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): os.remove(filename) logger.info("delete file: {}".format(filename)) - - def load_checkpoint_idx(self, checkpoint_record: str) -> int: """Get the iteration number corresponding to the latest saved checkpoint. @@ -131,7 +133,6 @@ class Checkpoint(object): latest_checkpoint = handle.readlines()[-1].strip() iteration = int(latest_checkpoint.split(":")[-1]) return iteration - def save_checkpoint_record(self, checkpoint_dir: str, iteration: int): """Save the iteration number of the latest model to be checkpoint record. @@ -141,9 +142,10 @@ class Checkpoint(object): Returns: None """ - checkpoint_record_latest = os.path.join(checkpoint_dir, "checkpoint_latest") + checkpoint_record_latest = os.path.join(checkpoint_dir, + "checkpoint_latest") checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") - + with open(checkpoint_record_best, "w") as handle: for i in self.best_records.keys(): handle.write("model_checkpoint_path:{}\n".format(i)) @@ -151,11 +153,11 @@ class Checkpoint(object): for i in self.latest_records: handle.write("model_checkpoint_path:{}\n".format(i)) - - def load_last_parameters(self, model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + def load_last_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -173,11 +175,13 @@ class Checkpoint(object): if checkpoint_path is not None: tag = os.path.basename(checkpoint_path).split(":")[-1] elif checkpoint_dir is not None: - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_latest") + checkpoint_record = os.path.join(checkpoint_dir, + "checkpoint_latest") iteration = self.load_checkpoint_idx(checkpoint_record) if iteration == -1: return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) else: raise ValueError( "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" @@ -203,11 +207,11 @@ class Checkpoint(object): configs = json.load(fin) return configs - - def load_best_parameters(self, model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -229,7 +233,8 @@ class Checkpoint(object): iteration = self.load_checkpoint_idx(checkpoint_record) if iteration == -1: return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) else: raise ValueError( "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" @@ -255,10 +260,9 @@ class Checkpoint(object): configs = json.load(fin) return configs - - @mp_tools.rank_zero_only - def save_parameters(self, checkpoint_dir: str, + def save_parameters(self, + checkpoint_dir: str, tag_or_iteration: Union[int, str], model: paddle.nn.Layer, optimizer: Optimizer=None, @@ -275,7 +279,7 @@ class Checkpoint(object): None """ checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) + "{}".format(tag_or_iteration)) model_dict = model.state_dict() params_path = checkpoint_path + ".pdparams" @@ -293,4 +297,3 @@ class Checkpoint(object): with open(info_path, 'w') as fout: data = json.dumps(infos) fout.write(data) -