From 91e70a2857c62b7db1db958d9b0528beb2bf0b77 Mon Sep 17 00:00:00 2001
From: Haoxin Ma <745165806@qq.com>
Date: Fri, 25 Jun 2021 09:02:59 +0000
Subject: [PATCH] multi gpus

---
 deepspeech/training/trainer.py         |  18 ++--
 deepspeech/utils/checkpoint.py         | 144 ++++++++++++++++---------
 examples/tiny/s0/conf/deepspeech2.yaml |   2 +-
 3 files changed, 105 insertions(+), 59 deletions(-)

diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 6563e7c4..7f68e67c 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -18,8 +18,8 @@ import paddle
 from paddle import distributed as dist
 from tensorboardX import SummaryWriter
 
-from deepspeech.utils.checkpoint import KBestCheckpoint
 from deepspeech.utils import mp_tools
+from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log
 
 __all__ = ["Trainer"]
@@ -64,7 +64,7 @@ class Trainer():
         The parsed command line arguments.
     Examples
     --------
-    >>> def main_sp(config, args):
+    >>> def p(config, args):
     >>>     exp = Trainer(config, args)
     >>>     exp.setup()
     >>>     exp.run()
@@ -140,11 +140,8 @@ class Trainer():
                 "lr": self.optimizer.get_lr()
             })
         self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
-                                   if tag is None else tag, self.model,
-                                   self.optimizer, infos)
-        # checkpoint.save_parameters(self.checkpoint_dir, self.iteration
-        #                           if tag is None else tag, self.model,
-        #                           self.optimizer, infos)
+                                       if tag is None else tag, self.model,
+                                       self.optimizer, infos)
 
     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output
@@ -154,7 +151,7 @@ class Trainer():
             resume training.
         """
         scratch = None
-        infos = self.checkpoint.load_parameters(
+        infos = self.checkpoint.load_last_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
@@ -266,8 +263,9 @@ class Trainer():
 
         self.checkpoint_dir = checkpoint_dir
 
-        self.checkpoint = KBestCheckpoint(max_size=self.config.training.checkpoint.kbest_n,
-                              last_size=self.config.training.checkpoint.latest_n)
+        self.checkpoint = Checkpoint(
+            kbest_n=self.config.training.checkpoint.kbest_n,
+            latest_n=self.config.training.checkpoint.latest_n)
 
     @mp_tools.rank_zero_only
     def destory(self):
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index ef73eb70..52eccb67 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -24,20 +24,22 @@ from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
 
 import glob
+# import operator
+from pathlib import Path
 
 logger = Log(__name__).getlog()
 
-__all__ = ["load_parameters", "save_parameters"]
+__all__ = ["Checkpoint"]
 
 
-class KBestCheckpoint(object):
+class Checkpoint(object):
     def __init__(self,
-                 max_size: int=5,
-                 last_size: int=1):
+                 kbest_n: int=5,
+                 latest_n: int=1):
         self.best_records: Mapping[Path, float] = {}
-        self.last_records = []
-        self.max_size = max_size
-        self.last_size = last_size
-        self._save_all = (max_size == -1)
+        self.latest_records = []
+        self.kbest_n = kbest_n
+        self.latest_n = latest_n
+        self._save_all = (kbest_n == -1)
 
     def should_save_best(self, metric: float) -> bool:
         if not self.best_full():
@@ -45,36 +47,36 @@ class KBestCheckpoint(object):
 
         # already full
         worst_record_path = max(self.best_records,
                                 key=self.best_records.get)
+        # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
         worst_metric = self.best_records[worst_record_path]
         return metric < worst_metric
 
     def best_full(self):
-        return (not self._save_all) and len(self.best_records) == self.max_size
+        return (not self._save_all) and len(self.best_records) == self.kbest_n
 
-    def last_full(self):
-        return len(self.last_records) == self.last_size
+    def latest_full(self):
+        return len(self.latest_records) == self.latest_n
 
-    def add_checkpoint(self,
-                       checkpoint_dir, tag_or_iteration,
-                       model, optimizer, infos):
-        if("val_loss" not in infos.keys()):
+    def add_checkpoint(self, checkpoint_dir, tag_or_iteration,
+                       model, optimizer, infos, metric_type="val_loss"):
+        if metric_type not in infos:
             self.save_parameters(checkpoint_dir, tag_or_iteration, model,
                                  optimizer, infos)
             return
 
         #save best
-        if self.should_save_best(infos["val_loss"]):
-            self.save_checkpoint_and_update(infos["val_loss"],
+        if self.should_save_best(infos[metric_type]):
+            self.save_best_checkpoint_and_update(infos[metric_type],
                                             checkpoint_dir, tag_or_iteration,
                                             model, optimizer, infos)
 
-        #save last
-        self.save_last_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
+        #save latest
+        self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
                                            model, optimizer, infos)
 
         if isinstance(tag_or_iteration, int):
-            self._save_record(checkpoint_dir, tag_or_iteration)
+            self.save_checkpoint_record(checkpoint_dir, tag_or_iteration)
 
-    def save_checkpoint_and_update(self, metric,
+    def save_best_checkpoint_and_update(self, metric,
                                    checkpoint_dir, tag_or_iteration,
                                    model, optimizer, infos):
         # remove the worst
@@ -82,9 +84,8 @@ class KBestCheckpoint(object):
             worst_record_path = max(self.best_records,
                                     key=self.best_records.get)
             self.best_records.pop(worst_record_path)
-            if(worst_record_path not in self.last_records):
-                print('----to remove (best)----')
-                print(worst_record_path)
+            if worst_record_path not in self.latest_records:
+                logger.info("remove the worst checkpoint: {}".format(worst_record_path))
                 self.del_checkpoint(checkpoint_dir, worst_record_path)
 
         # add the new one
@@ -92,22 +93,18 @@ class KBestCheckpoint(object):
                              model, optimizer, infos)
         self.best_records[tag_or_iteration] = metric
 
-    def save_last_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration,
+    def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration,
                                  model, optimizer, infos):
         # remove the old
-        if self.last_full():
-            to_del_fn = self.last_records.pop(0)
+        if self.latest_full():
+            to_del_fn = self.latest_records.pop(0)
             if(to_del_fn not in self.best_records.keys()):
-                print('----to remove (last)----')
-                print(to_del_fn)
+                logger.info("remove the old checkpoint: {}".format(to_del_fn))
                 self.del_checkpoint(checkpoint_dir, to_del_fn)
-        self.last_records.append(tag_or_iteration)
+        self.latest_records.append(tag_or_iteration)
 
         self.save_parameters(checkpoint_dir, tag_or_iteration, model,
                              optimizer, infos)
-        # with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as handle:
-        #     for iteration in self.best_records
-        #         handle.write("model_checkpoint_path:{}\n".format(iteration))
 
 
     def del_checkpoint(self, checkpoint_dir, tag_or_iteration):
@@ -115,18 +112,17 @@ class KBestCheckpoint(object):
                                        "{}".format(tag_or_iteration))
         for filename in glob.glob(checkpoint_path+".*"):
             os.remove(filename)
-            print("delete file: "+filename)
+            logger.info("delete file: {}".format(filename))
 
-    def _load_latest_checkpoint(self, checkpoint_dir: str) -> int:
+    def load_checkpoint_idx(self, checkpoint_record: str) -> int:
         """Get the iteration number corresponding to the latest saved checkpoint.
         Args:
-            checkpoint_dir (str): the directory where checkpoint is saved.
+            checkpoint_record (str): the path of the checkpoint record file.
         Returns:
             int: the latest iteration number. -1 for no checkpoint to load.
         """
-        checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_last")
         if not os.path.isfile(checkpoint_record):
             return -1
 
@@ -135,9 +131,9 @@ class KBestCheckpoint(object):
             latest_checkpoint = handle.readlines()[-1].strip()
             iteration = int(latest_checkpoint.split(":")[-1])
         return iteration
+
 
-
-    def _save_record(self, checkpoint_dir: str, iteration: int):
+    def save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
         """Save the iteration number of the latest model to be checkpoint record.
         Args:
             checkpoint_dir (str): the directory where checkpoint is saved.
@@ -145,24 +141,22 @@ class KBestCheckpoint(object):
         Returns:
             None
         """
-        checkpoint_record_last = os.path.join(checkpoint_dir, "checkpoint_last")
+        checkpoint_record_latest = os.path.join(checkpoint_dir, "checkpoint_latest")
         checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
 
-        # Update the latest checkpoint index.
-        # with open(checkpoint_record, "a+") as handle:
-        #     handle.write("model_checkpoint_path:{}\n".format(iteration))
+
         with open(checkpoint_record_best, "w") as handle:
             for i in self.best_records.keys():
                 handle.write("model_checkpoint_path:{}\n".format(i))
-        with open(checkpoint_record_last, "w") as handle:
-            for i in self.last_records:
+        with open(checkpoint_record_latest, "w") as handle:
+            for i in self.latest_records:
                 handle.write("model_checkpoint_path:{}\n".format(i))
 
-    def load_parameters(self, model,
+    def load_last_parameters(self, model,
                         optimizer=None,
                         checkpoint_dir=None,
                         checkpoint_path=None):
-        """Load a specific model checkpoint from disk.
+        """Load the latest model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
             optimizer (Optimizer, optional): optimizer to load states if needed.
@@ -179,7 +173,8 @@ class KBestCheckpoint(object):
         if checkpoint_path is not None:
             tag = os.path.basename(checkpoint_path).split(":")[-1]
         elif checkpoint_dir is not None:
-            iteration = self._load_latest_checkpoint(checkpoint_dir)
+            checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_latest")
+            iteration = self.load_checkpoint_idx(checkpoint_record)
             if iteration == -1:
                 return configs
             checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
@@ -209,6 +204,59 @@ class KBestCheckpoint(object):
 
         return configs
 
+    def load_best_parameters(self, model,
+                       optimizer=None,
+                       checkpoint_dir=None,
+                       checkpoint_path=None):
+        """Load the best model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        configs = {}
+
+        if checkpoint_path is not None:
+            tag = os.path.basename(checkpoint_path).split(":")[-1]
+        elif checkpoint_dir is not None:
+            checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_best")
+            iteration = self.load_checkpoint_idx(checkpoint_record)
+            if iteration == -1:
+                return configs
+            checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
+        else:
+            raise ValueError(
+                "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+            )
+
+        rank = dist.get_rank()
+
+        params_path = checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        model.set_state_dict(model_dict)
+        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        if optimizer and os.path.isfile(optimizer_path):
+            optimizer_dict = paddle.load(optimizer_path)
+            optimizer.set_state_dict(optimizer_dict)
+            logger.info("Rank {}: loaded optimizer state from {}".format(
+                rank, optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        if os.path.exists(info_path):
+            with open(info_path, 'r') as fin:
+                configs = json.load(fin)
+        return configs
+
+
     @mp_tools.rank_zero_only
     def save_parameters(self, checkpoint_dir: str,
                         tag_or_iteration: Union[int, str],
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index b9c2556c..ea433f34 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -43,7 +43,7 @@ model:
   share_rnn_weights: True
 
 training:
-  n_epoch: 6
+  n_epoch: 10
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06