From cfdca210ff243a45afa96a64c6ba42bf2586d5eb Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 24 Aug 2021 12:24:59 +0000
Subject: [PATCH] chainer style updater

---
 deepspeech/training/extensions/__init__.py    |  28 +++
 deepspeech/training/extensions/evaluator.py   |  58 ++++++
 deepspeech/training/extensions/extension.py   |  41 ++++
 deepspeech/training/extensions/snapshot.py    | 102 ++++++++++
 deepspeech/training/extensions/visualizer.py  |  24 +++
 deepspeech/training/reporter.py               | 131 +++++++++++++
 deepspeech/training/triggers/__init__.py      |  13 ++
 .../training/triggers/interval_trigger.py     |  24 +++
 deepspeech/training/triggers/limit_trigger.py |  17 ++
 deepspeech/training/triggers/time_trigger.py  |  17 ++
 deepspeech/training/updaters/__init__.py      |   0
 .../training/updaters/standard_updater.py     | 179 ++++++++++++++++++
 deepspeech/training/updaters/trainer.py       | 171 +++++++++++++++++
 deepspeech/training/updaters/updater.py       |  82 ++++++++
 requirements.txt                              |   1 +
 15 files changed, 888 insertions(+)
 create mode 100644 deepspeech/training/extensions/__init__.py
 create mode 100644 deepspeech/training/extensions/evaluator.py
 create mode 100644 deepspeech/training/extensions/extension.py
 create mode 100644 deepspeech/training/extensions/snapshot.py
 create mode 100644 deepspeech/training/extensions/visualizer.py
 create mode 100644 deepspeech/training/reporter.py
 create mode 100644 deepspeech/training/triggers/__init__.py
 create mode 100644 deepspeech/training/triggers/interval_trigger.py
 create mode 100644 deepspeech/training/triggers/limit_trigger.py
 create mode 100644 deepspeech/training/triggers/time_trigger.py
 create mode 100644 deepspeech/training/updaters/__init__.py
 create mode 100644 deepspeech/training/updaters/standard_updater.py
 create mode 100644 deepspeech/training/updaters/trainer.py
 create mode 100644 deepspeech/training/updaters/updater.py

diff --git a/deepspeech/training/extensions/__init__.py b/deepspeech/training/extensions/__init__.py
new file mode 100644
index 00000000..7ea7470e
--- /dev/null
+++ b/deepspeech/training/extensions/__init__.py
@@ -0,0 +1,28 @@
+
+from typing import Callable
+
+from .extension import Extension
+
+def make_extension(trigger: Callable=None,
+                   default_name: str=None,
+                   priority: int=None,
+                   finalizer: Callable=None,
+                   initializer: Callable=None,
+                   on_error: Callable=None):
+    """Make an Extension-like object by injecting the required attributes into it.
+    """
+    if trigger is None:
+        trigger = Extension.trigger
+    if priority is None:
+        priority = Extension.priority
+
+    def decorator(ext):
+        ext.trigger = trigger
+        ext.default_name = default_name or ext.__name__
+        ext.priority = priority
+        ext.finalize = finalizer
+        ext.on_error = on_error
+        ext.initialize = initializer
+        return ext
+
+    return decorator
\ No newline at end of file
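For readers new to this Chainer-style design, a minimal sketch (not part of the patch; `print_loss` and its trigger are made up for illustration) of how `make_extension` turns a plain function into an extension-like object that a trainer can schedule:

    # illustrative only: a function-based extension
    from deepspeech.training.extensions import make_extension

    @make_extension(trigger=(100, 'iteration'), priority=100)
    def print_loss(trainer):
        # the trainer's current observation holds the values reported this step
        print(trainer.observation)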
diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py
new file mode 100644
index 00000000..ffb7b3a2
--- /dev/null
+++ b/deepspeech/training/extensions/evaluator.py
@@ -0,0 +1,58 @@
+from typing import Dict
+
+import paddle
+from paddle.io import DataLoader
+from paddle.nn import Layer
+
+from deepspeech.training.extensions import extension
+from ..reporter import DictSummary
+from ..reporter import report
+from ..reporter import scope
+
+
+class StandardEvaluator(extension.Extension):
+
+    trigger = (1, 'epoch')
+    default_name = 'validation'
+    priority = extension.PRIORITY_WRITER
+
+    name = None
+
+    def __init__(self, model: Layer, dataloader: DataLoader):
+        # it is designed to hold multiple models
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        self.model = model
+
+        # dataloaders
+        self.dataloader = dataloader
+
+    def evaluate_core(self, batch):
+        # compute
+        self.model(batch)  # you may report here
+
+    def evaluate(self):
+        # switch to eval mode
+        for model in self.models.values():
+            model.eval()
+
+        # to average evaluation metrics
+        summary = DictSummary()
+        for batch in self.dataloader:
+            observation = {}
+            with scope(observation):
+                # main evaluation computation here.
+                with paddle.no_grad():
+                    self.evaluate_core(batch)
+            summary.add(observation)
+        summary = summary.compute_mean()
+        return summary
+
+    def __call__(self, trainer=None):
+        # evaluate and report the averaged metrics to the current observation
+        # if it is used to extend a trainer, the metrics are reported to
+        # the observation of the trainer
+        # otherwise, you can use your own observation
+        summary = self.evaluate()
+        for k, v in summary.items():
+            report(k, v)
\ No newline at end of file
diff --git a/deepspeech/training/extensions/extension.py b/deepspeech/training/extensions/extension.py
new file mode 100644
index 00000000..f8fcede3
--- /dev/null
+++ b/deepspeech/training/extensions/extension.py
@@ -0,0 +1,41 @@
+from typing import Callable
+
+PRIORITY_WRITER = 300
+PRIORITY_EDITOR = 200
+PRIORITY_READER = 100
+
+
+class Extension():
+    """Extension to customize the behavior of Trainer."""
+    trigger = (1, 'iteration')
+    priority = PRIORITY_READER
+    name = None
+
+    @property
+    def default_name(self):
+        """Default name of the extension, class name by default."""
+        return type(self).__name__
+
+    def __call__(self, trainer):
+        """Main action of the extension. After each update, it is executed
+        when the trigger fires."""
+        raise NotImplementedError(
+            'Extension implementation must override __call__.')
+
+    def initialize(self, trainer):
+        """Action that is executed once to get the correct trainer state.
+        It is called before training normally, but if the trainer restores
+        states with a Snapshot extension, this method should also be called.
+        """
+        pass
+
+    def on_error(self, trainer, exc, tb):
+        """Handles the error raised during training before finalization.
+        """
+        pass
+
+    def finalize(self, trainer):
+        """Action that is executed when training is done.
+        For example, visualizers would need to be closed.
+        """
+        pass
\ No newline at end of file
diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py
new file mode 100644
index 00000000..a15537a0
--- /dev/null
+++ b/deepspeech/training/extensions/snapshot.py
@@ -0,0 +1,102 @@
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+
+from deepspeech.training.updaters.trainer import Trainer
+from deepspeech.training.extensions import extension
+from deepspeech.utils.mp_tools import rank_zero_only
+
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+def load_records(records_fp):
+    """Load record files (json lines.)"""
+    with jsonlines.open(records_fp, 'r') as reader:
+        records = list(reader)
+    return records
+
+
+class Snapshot(extension.Extension):
+    """An extension to take snapshots of the updater object inside
+    the trainer, by calling the updater's `save` method.
+    An Updater saves its state_dict by default, which contains the
+    updater state (i.e. epoch and iteration) and all the model
+    parameters and optimizer states. If the updater inside the trainer
+    subclasses StandardUpdater, everything is good to go.
+    Parameters
+    ----------
+    max_size : int
+        The maximum number of snapshots to keep on disk; -1 keeps all of them.
+    snapshot_on_error : bool
+        Whether to take a snapshot when an error is raised during training.
+    """
+
+    trigger = (1, 'epoch')
+    priority = -100
+    default_name = "snapshot"
+
+    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
+        self.records: List[Dict[str, Any]] = []
+        self.max_size = max_size
+        self._snapshot_on_error = snapshot_on_error
+        self._save_all = (max_size == -1)
+        self.checkpoint_dir = None
+
+    def initialize(self, trainer: Trainer):
+        """Set up this extension."""
+        self.checkpoint_dir = trainer.out / "checkpoints"
+
+        # load existing records
+        record_path: Path = self.checkpoint_dir / "records.jsonl"
+        if record_path.exists():
+            logger.debug("Loading from an existing checkpoint dir")
+            self.records = load_records(record_path)
+            trainer.updater.load(self.records[-1]['path'])
+
+    def on_error(self, trainer, exc, tb):
+        if self._snapshot_on_error:
+            self.save_checkpoint_and_update(trainer)
+
+    def __call__(self, trainer: Trainer):
+        self.save_checkpoint_and_update(trainer)
+
+    def full(self):
+        """Whether the number of snapshots it keeps track of is greater
+        than the max_size."""
+        return (not self._save_all) and len(self.records) > self.max_size
+
+    @rank_zero_only
+    def save_checkpoint_and_update(self, trainer: Trainer):
+        """Save a new snapshot and remove the oldest one if needed."""
+        iteration = trainer.updater.state.iteration
+        epoch = trainer.updater.state.epoch
+        num = epoch if self.trigger[1] == 'epoch' else iteration
+        path = self.checkpoint_dir / f"{num}.pdz"
+
+        # add the new one
+        trainer.updater.save(path)
+        record = {
+            "time": str(datetime.now()),
+            'path': str(path.resolve()),  # use absolute path
+            'iteration': iteration,
+            'epoch': epoch,
+        }
+        self.records.append(record)
+
+        # remove the earliest one
+        if self.full():
+            earliest_record = self.records[0]
+            os.remove(earliest_record["path"])
+            self.records.pop(0)
+
+        # update the record file
+        record_path = self.checkpoint_dir / "records.jsonl"
+        with jsonlines.open(record_path, 'w') as writer:
+            for record in self.records:
+                # jsonlines.open may return a Writer or a Reader
+                writer.write(record)  # pylint: disable=no-member
\ No newline at end of file
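A sketch of how the snapshot extension is meant to be attached (the `trainer` object is assumed to already exist; see trainer.py below): it fires once per epoch and keeps the newest `max_size` checkpoints under `<trainer.out>/checkpoints`, tracking them in `records.jsonl`.

    # illustrative wiring; model/updater construction omitted
    from deepspeech.training.extensions.snapshot import Snapshot

    trainer.extend(Snapshot(max_size=5), trigger=(1, 'epoch'))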
diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py
new file mode 100644
index 00000000..92e07704
--- /dev/null
+++ b/deepspeech/training/extensions/visualizer.py
@@ -0,0 +1,24 @@
+from deepspeech.training.extensions import extension
+from deepspeech.training.updaters.trainer import Trainer
+
+
+class VisualDL(extension.Extension):
+    """A wrapper of the visualdl log writer. It assumes that the metrics to be
+    visualized are all scalars which are recorded into the `.observation`
+    dictionary of the trainer object. The dictionary is created for each step,
+    and the iteration from the updater's state is used as the global step when
+    adding records.
+    """
+    trigger = (1, 'iteration')
+    default_name = 'visualdl'
+    priority = extension.PRIORITY_READER
+
+    def __init__(self, writer):
+        self.writer = writer
+
+    def __call__(self, trainer: Trainer):
+        for k, v in trainer.observation.items():
+            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
+
+    def finalize(self, trainer):
+        self.writer.close()
\ No newline at end of file
diff --git a/deepspeech/training/reporter.py b/deepspeech/training/reporter.py
new file mode 100644
index 00000000..a5f79fb0
--- /dev/null
+++ b/deepspeech/training/reporter.py
@@ -0,0 +1,131 @@
+import contextlib
+import math
+from collections import defaultdict
+
+OBSERVATIONS = None
+
+
+@contextlib.contextmanager
+def scope(observations):
+    # make `observations` the target to report to.
+    # it is basically a dictionary that stores temporary observations
+    global OBSERVATIONS
+    old = OBSERVATIONS
+    OBSERVATIONS = observations
+
+    try:
+        yield
+    finally:
+        OBSERVATIONS = old
+
+
+def get_observations():
+    global OBSERVATIONS
+    return OBSERVATIONS
+
+
+def report(name, value):
+    # a simple function to report a named value
+    # you can use it everywhere; it gets the current default target and
+    # writes to it. You can think of it as std.out
+    observations = get_observations()
+    if observations is None:
+        return
+    else:
+        observations[name] = value
+
+
+class Summary():
+    """Online summarization of a sequence of scalars.
+    Summary computes the statistics of given scalars online.
+    """
+
+    def __init__(self):
+        self._x = 0.0
+        self._x2 = 0.0
+        self._n = 0
+
+    def add(self, value, weight=1):
+        """Adds a scalar value.
+        Args:
+            value: Scalar value to accumulate. It is either a NumPy scalar or
+                a zero-dimensional array (on CPU or GPU).
+            weight: An optional weight for the value. It is a NumPy scalar or
+                a zero-dimensional array (on CPU or GPU).
+                Default is 1 (integer).
+        """
+        self._x += weight * value
+        self._x2 += weight * value * value
+        self._n += weight
+
+    def compute_mean(self):
+        """Computes the mean."""
+        x, n = self._x, self._n
+        return x / n
+
+    def make_statistics(self):
+        """Computes and returns the mean and standard deviation values.
+        Returns:
+            tuple: Mean and standard deviation values.
+        """
+        x, n = self._x, self._n
+        mean = x / n
+        var = self._x2 / n - mean * mean
+        std = math.sqrt(var)
+        return mean, std
+
+
+class DictSummary():
+    """Online summarization of a sequence of dictionaries.
+    ``DictSummary`` computes the statistics of a given set of scalars online.
+    It only computes the statistics for scalar values and variables of scalar
+    values in the dictionaries.
+    """
+
+    def __init__(self):
+        self._summaries = defaultdict(Summary)
+
+    def add(self, d):
+        """Adds a dictionary of scalars.
+        Args:
+            d (dict): Dictionary of scalars to accumulate. Only elements of
+                scalars, zero-dimensional arrays, and variables of
+                zero-dimensional arrays are accumulated. When the value
+                is a tuple, the second element is interpreted as a weight.
+        """
+        summaries = self._summaries
+        for k, v in d.items():
+            w = 1
+            if isinstance(v, tuple):
+                w = v[1]
+                v = v[0]
+            summaries[k].add(v, weight=w)
+
+    def compute_mean(self):
+        """Creates a dictionary of mean values.
+        It returns a single dictionary that holds a mean value for each entry
+        added to the summary.
+        Returns:
+            dict: Dictionary of mean values.
+        """
+        return {
+            name: summary.compute_mean()
+            for name, summary in self._summaries.items()
+        }
+
+    def make_statistics(self):
+        """Creates a dictionary of statistics.
+        It returns a single dictionary that holds mean and standard deviation
+        values for every entry added to the summary. For an entry of name
+        ``'key'``, these values are added to the dictionary by names ``'key'``
+        and ``'key.std'``, respectively.
+        Returns:
+            dict: Dictionary of statistics of all entries.
+        """
+        stats = {}
+        for name, summary in self._summaries.items():
+            mean, std = summary.make_statistics()
+            stats[name] = mean
+            stats[name + '.std'] = std
+
+        return stats
\ No newline at end of file
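The reporter contract in a nutshell, as a small self-contained sketch (the values are made up): anything passed to `report` inside a `scope` lands in that observation dict, and `DictSummary` averages the observations across batches.

    from deepspeech.training.reporter import DictSummary, report, scope

    summary = DictSummary()
    for value in [0.9, 0.7]:        # stand-ins for per-batch losses
        observation = {}
        with scope(observation):
            report("loss", value)
        summary.add(observation)
    print(summary.compute_mean())   # {'loss': 0.8}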
diff --git a/deepspeech/training/triggers/__init__.py b/deepspeech/training/triggers/__init__.py
new file mode 100644
index 00000000..9da7e615
--- /dev/null
+++ b/deepspeech/training/triggers/__init__.py
@@ -0,0 +1,13 @@
+from .interval_trigger import IntervalTrigger
+
+def never_fail_trigger(trainer):
+    return False
+
+def get_trigger(trigger):
+    if trigger is None:
+        return never_fail_trigger
+    if callable(trigger):
+        return trigger
+    else:
+        trigger = IntervalTrigger(*trigger)
+        return trigger
\ No newline at end of file
diff --git a/deepspeech/training/triggers/interval_trigger.py b/deepspeech/training/triggers/interval_trigger.py
new file mode 100644
index 00000000..ef80379c
--- /dev/null
+++ b/deepspeech/training/triggers/interval_trigger.py
@@ -0,0 +1,24 @@
+
+class IntervalTrigger():
+    """A predicate to do something every N cycles."""
+
+    def __init__(self, period: int, unit: str):
+        if unit not in ("iteration", "epoch"):
+            raise ValueError("unit should be 'iteration' or 'epoch'")
+        if period <= 0:
+            raise ValueError("period should be a positive integer.")
+        self.period = period
+        self.unit = unit
+        self.last_index = None
+
+    def __call__(self, trainer):
+        if self.last_index is None:
+            last_index = getattr(trainer.updater.state, self.unit)
+            self.last_index = last_index
+
+        last_index = self.last_index
+        index = getattr(trainer.updater.state, self.unit)
+        fire = index // self.period != last_index // self.period
+
+        self.last_index = index
+        return fire
\ No newline at end of file
diff --git a/deepspeech/training/triggers/limit_trigger.py b/deepspeech/training/triggers/limit_trigger.py
new file mode 100644
index 00000000..ce13f940
--- /dev/null
+++ b/deepspeech/training/triggers/limit_trigger.py
@@ -0,0 +1,17 @@
+
+class LimitTrigger():
+    """A predicate to decide whether to stop."""
+
+    def __init__(self, limit: int, unit: str):
+        if unit not in ("iteration", "epoch"):
+            raise ValueError("unit should be 'iteration' or 'epoch'")
+        if limit <= 0:
+            raise ValueError("limit should be a positive integer.")
+        self.limit = limit
+        self.unit = unit
+
+    def __call__(self, trainer):
+        state = trainer.updater.state
+        index = getattr(state, self.unit)
+        fire = index >= self.limit
+        return fire
\ No newline at end of file
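How these triggers are consumed (a sketch; nothing beyond the code above): `get_trigger` passes callables through, maps a `(period, unit)` tuple to an `IntervalTrigger`, and returns a never-firing trigger for `None`.

    from deepspeech.training.triggers import get_trigger

    trigger = get_trigger((1, 'epoch'))   # -> IntervalTrigger(1, 'epoch')
    assert callable(trigger)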
diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py
new file mode 100644
index 00000000..6232a12d
--- /dev/null
+++ b/deepspeech/training/triggers/time_trigger.py
@@ -0,0 +1,17 @@
+class TimeTrigger():
+    """Trigger based on a fixed time interval.
+    This trigger accepts iterations with a given interval time.
+    Args:
+        period (float): Interval time. It is given in seconds.
+    """
+
+    def __init__(self, period):
+        self._period = period
+        self._next_time = self._period
+
+    def __call__(self, trainer):
+        if self._next_time < trainer.elapsed_time:
+            self._next_time += self._period
+            return True
+        else:
+            return False
\ No newline at end of file
diff --git a/deepspeech/training/updaters/__init__.py b/deepspeech/training/updaters/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py
new file mode 100644
index 00000000..062029ff
--- /dev/null
+++ b/deepspeech/training/updaters/standard_updater.py
@@ -0,0 +1,179 @@
+from typing import Dict
+from typing import Optional
+
+from paddle import Tensor
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from timer import timer
+
+from deepspeech.training.reporter import report
+from deepspeech.training.updaters.updater import UpdaterBase
+from deepspeech.training.updaters.updater import UpdaterState
+
+from deepspeech.utils.log import Log
+
+__all__ = ["StandardUpdater"]
+
+logger = Log(__name__).getlog()
+
+class StandardUpdater(UpdaterBase):
+    """An example of over-simplification. Things may not be that simple, but
+    you can subclass it to fit your need.
+    """
+
+    def __init__(self,
+                 model: Layer,
+                 optimizer: Optimizer,
+                 dataloader: DataLoader,
+                 init_state: Optional[UpdaterState]=None):
+        # it is designed to hold multiple models
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        self.model = model
+
+        # it is designed to hold multiple optimizers
+        optimizers = {"main": optimizer}
+        self.optimizer = optimizer
+        self.optimizers: Dict[str, Optimizer] = optimizers
+
+        # dataloaders
+        self.dataloader = dataloader
+
+        # init state
+        if init_state is None:
+            self.state = UpdaterState()
+        else:
+            self.state = init_state
+
+        self.train_iterator = iter(dataloader)
+
+    def update(self):
+        # We increase the iteration index after updating and before any
+        # extension is executed. Here are the reasons.
+        # 1. Snapshotting (as well as other extensions, like the visualizer)
+        #    is executed after a step of updating.
+        # 2. If the iteration index were increased only after the extensions
+        #    (including the snapshot) had run, resuming would be awkward:
+        #    loading `snapshot_iter_100.pdz` should naturally continue with
+        #    step 101, but the saved index would still point at a step that
+        #    has already been done, so an extra increment would be needed
+        #    before training to avoid repeating iteration 99/100.
+        # 3. Thus the iteration index represents "how many iterations have
+        #    been done so far."
+        # NOTE: use report to capture the correct value. If you want to
+        # report the learning rate used for a step, you must report it before
+        # the learning rate scheduler's step() has been called. In paddle's
+        # convention, we do not use an extension to change the learning rate,
+        # so if you want to report it, do it in the updater.
+
+        # Then here comes the next question: when is the proper time to
+        # increase the epoch index? Since all extensions are executed after
+        # updating, right after updating is also the proper time to increase
+        # the epoch index.
+        # 1. If we increased the epoch index before updating, an extension
+        #    triggered by epoch would miss the correct timing; it could only
+        #    fire after one extra update.
+        # 2. Theoretically, when an epoch is done, the epoch index should be
+        #    increased, so it is increased after updating.
+        # 3. Thus the epoch index represents "how many epochs have been done
+        #    so far," and it starts from 0.
+
+        # switch to training mode
+        for model in self.models.values():
+            model.train()
+
+        # training for a step is implemented here
+        batch = self.read_batch()
+        self.update_core(batch)
+
+        self.state.iteration += 1
+        if self.updates_per_epoch is not None:
+            if self.state.iteration % self.updates_per_epoch == 0:
+                self.state.epoch += 1
+
+    def update_core(self, batch):
+        """A simple case for a training step. Basic assumptions are:
+        Single model;
+        Single optimizer;
+        A batch from the dataloader is just the input of the model;
+        The model returns a single loss, or a dict containing several losses.
+        Parameters are updated at every batch, no gradient accumulation.
+        """
+        loss = self.model(*batch)
+
+        if isinstance(loss, Tensor):
+            loss_dict = {"main": loss}
+        else:
+            # Dict[str, Tensor]
+            loss_dict = loss
+            if "main" not in loss_dict:
+                main_loss = 0
+                for loss_item in loss.values():
+                    main_loss += loss_item
+                loss_dict["main"] = main_loss
+
+        for name, loss_item in loss_dict.items():
+            report(name, float(loss_item))
+
+        self.optimizer.clear_grad()
+        loss_dict["main"].backward()
+        self.optimizer.step()
+
+    @property
+    def updates_per_epoch(self):
+        """Number of updates per epoch, determined by the length of the
+        dataloader."""
+        length_of_dataloader = None
+        try:
+            length_of_dataloader = len(self.dataloader)
+        except TypeError:
+            logger.debug("This dataloader has no __len__.")
+        finally:
+            return length_of_dataloader
+
+    def new_epoch(self):
+        """Start a new epoch."""
+        # NOTE: all batch samplers for distributed training should
+        # subclass DistributedBatchSampler and implement the `set_epoch` method
+        if hasattr(self.dataloader, "batch_sampler"):
+            batch_sampler = self.dataloader.batch_sampler
+            if isinstance(batch_sampler, DistributedBatchSampler):
+                batch_sampler.set_epoch(self.state.epoch)
+        self.train_iterator = iter(self.dataloader)
+
+    def read_batch(self):
+        """Read a batch from the data loader, auto renew when data is exhausted."""
+        with timer() as t:
+            try:
+                batch = next(self.train_iterator)
+            except StopIteration:
+                self.new_epoch()
+                batch = next(self.train_iterator)
+            logger.debug(f"Read a batch takes {t.elapse}s.")
+        return batch
+
+    def state_dict(self):
+        """State dict of an Updater; model, optimizer and updater state are included."""
+        state_dict = super().state_dict()
+        for name, model in self.models.items():
+            state_dict[f"{name}_params"] = model.state_dict()
+        for name, optim in self.optimizers.items():
+            state_dict[f"{name}_optimizer"] = optim.state_dict()
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        """Set state dict for an Updater. Parameters of models, states for
+        optimizers and UpdaterState are restored."""
+        for name, model in self.models.items():
+            model.set_state_dict(state_dict[f"{name}_params"])
+        for name, optim in self.optimizers.items():
+            optim.set_state_dict(state_dict[f"{name}_optimizer"])
+        super().set_state_dict(state_dict)
\ No newline at end of file
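When a batch is not simply `*args` for the model, the intended extension point is `update_core`. A hedged sketch (the batch layout and loss name below are assumptions, not part of the patch):

    from deepspeech.training.reporter import report
    from deepspeech.training.updaters.standard_updater import StandardUpdater

    class MyUpdater(StandardUpdater):
        def update_core(self, batch):
            audio, audio_len, text, text_len = batch   # hypothetical layout
            loss = self.model(audio, audio_len, text, text_len)
            report("train/loss", float(loss))

            self.optimizer.clear_grad()
            loss.backward()
            self.optimizer.step()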
diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py
new file mode 100644
index 00000000..c7562ff0
--- /dev/null
+++ b/deepspeech/training/updaters/trainer.py
@@ -0,0 +1,171 @@
+import sys
+import traceback
+from collections import OrderedDict
+from pathlib import Path
+from typing import Callable
+from typing import List
+from typing import Union
+
+import six
+import tqdm
+
+from deepspeech.training.extensions.extension import Extension
+from deepspeech.training.extensions.extension import PRIORITY_READER
+from deepspeech.training.reporter import scope
+from deepspeech.training.triggers import get_trigger
+from deepspeech.training.triggers.limit_trigger import LimitTrigger
+from deepspeech.training.updaters.updater import UpdaterBase
+
+
+class _ExtensionEntry():
+    def __init__(self, extension, trigger, priority):
+        self.extension = extension
+        self.trigger = trigger
+        self.priority = priority
+
+
+class Trainer():
+    def __init__(self,
+                 updater: UpdaterBase,
+                 stop_trigger: Callable=None,
+                 out: Union[str, Path]='result',
+                 extensions: List[Extension]=None):
+        self.updater = updater
+        self.extensions = OrderedDict()
+        self.stop_trigger = LimitTrigger(*stop_trigger)
+        self.out = Path(out)
+        self.observation = None
+
+        self._done = False
+        if extensions:
+            for ext in extensions:
+                self.extend(ext)
+
+    @property
+    def is_before_training(self):
+        return self.updater.state.iteration == 0
+
+    def extend(self, extension, name=None, trigger=None, priority=None):
+        # Resolve a name for the extension, in order of precedence:
+        # the `name` argument
+        # -> the extension's `name` attribute
+        # -> its `default_name` (class name, when it is an object)
+        # -> its `__name__` (when it is a function)
+        # -> error
+        if name is None:
+            name = getattr(extension, 'name', None)
+        if name is None:
+            name = getattr(extension, 'default_name', None)
+        if name is None:
+            name = getattr(extension, '__name__', None)
+        if name is None:
+            raise ValueError("Name is not given for the extension.")
+        if name == 'training':
+            raise ValueError("training is a reserved name.")
+
+        if trigger is None:
+            trigger = getattr(extension, 'trigger', (1, 'iteration'))
+        trigger = get_trigger(trigger)
+
+        if priority is None:
+            priority = getattr(extension, 'priority', PRIORITY_READER)
+
+        # add a suffix to avoid naming conflicts
+        ordinal = 0
+        modified_name = name
+        while modified_name in self.extensions:
+            ordinal += 1
+            modified_name = f"{name}_{ordinal}"
+        extension.name = modified_name
+
+        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
+                                                         priority)
+
+    def get_extension(self, name):
+        """get extension by name."""
+        extensions = self.extensions
+        if name in extensions:
+            return extensions[name].extension
+        else:
+            raise ValueError(f'extension {name} not found')
+
+    def run(self):
+        if self._done:
+            raise RuntimeError("Training is already done!")
+
+        self.out.mkdir(parents=True, exist_ok=True)
+
+        # sort extensions by priorities once
+        extension_order = sorted(
+            self.extensions.keys(),
+            key=lambda name: self.extensions[name].priority,
+            reverse=True)
+        extensions = [(name, self.extensions[name])
+                      for name in extension_order]
+
+        # initializing all extensions
+        for name, entry in extensions:
+            if hasattr(entry.extension, "initialize"):
+                entry.extension.initialize(self)
+
+        update = self.updater.update  # training step
+        stop_trigger = self.stop_trigger
+
+        # display only one progress bar
+        max_iteration = None
+        if isinstance(stop_trigger, LimitTrigger):
+            if stop_trigger.unit == 'epoch':
+                max_epoch = self.stop_trigger.limit
+                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
+                                            None)
+                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
+            else:
+                max_iteration = self.stop_trigger.limit
+
+        p = tqdm.tqdm(
+            initial=self.updater.state.iteration, total=max_iteration)
+
+        try:
+            while not stop_trigger(self):
+                self.observation = {}
+                # set observation as the report target
+                # you can use report freely in Updater.update()
+
+                # updating parameters and state
+                with scope(self.observation):
+                    update()
+                    p.update()
+
+                # execute extensions when necessary
+                for name, entry in extensions:
+                    if entry.trigger(self):
+                        entry.extension(self)
+
+                # print("###", self.observation)
+        except Exception as e:
+            f = sys.stderr
+            f.write(f"Exception in main training loop: {e}\n")
+            f.write("Traceback (most recent call last):\n")
+            traceback.print_tb(sys.exc_info()[2])
+            f.write(
+                "Trainer extensions will try to handle the exception. "
+                "Then all extensions will finalize.\n")
+
+            # capture the exception from the main training loop
+            exc_info = sys.exc_info()
+
+            # try to handle it
+            for name, entry in extensions:
+                if hasattr(entry.extension, "on_error"):
+                    try:
+                        entry.extension.on_error(self, e, sys.exc_info()[2])
+                    except Exception as ee:
+                        f.write(f"Exception in error handler: {ee}\n")
+                        f.write('Traceback (most recent call last):\n')
+                        traceback.print_tb(sys.exc_info()[2])
+
+            # reraise the exception from the main training loop
+            six.reraise(*exc_info)
+        finally:
+            for name, entry in extensions:
+                if hasattr(entry.extension, "finalize"):
+                    entry.extension.finalize(self)
\ No newline at end of file
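Putting the pieces together, a sketch of the intended wiring (the model, optimizer and dataloaders are assumed to exist; names such as `exp/default` are placeholders): the updater owns one training step, the trainer owns the loop, and extensions hook in by trigger.

    from deepspeech.training.extensions.evaluator import StandardEvaluator
    from deepspeech.training.extensions.snapshot import Snapshot
    from deepspeech.training.updaters.standard_updater import StandardUpdater
    from deepspeech.training.updaters.trainer import Trainer

    updater = StandardUpdater(model, optimizer, dataloader=train_loader)
    trainer = Trainer(updater, stop_trigger=(50, 'epoch'), out='exp/default')
    trainer.extend(StandardEvaluator(model, dev_loader), trigger=(1, 'epoch'))
    trainer.extend(Snapshot(max_size=5), trigger=(1, 'epoch'))
    trainer.run()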
diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py
new file mode 100644
index 00000000..548042d6
--- /dev/null
+++ b/deepspeech/training/updaters/updater.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+import paddle
+
+from deepspeech.utils.log import Log
+
+__all__ = ["UpdaterBase", "UpdaterState"]
+
+logger = Log(__name__).getlog()
+
+
+@dataclass
+class UpdaterState:
+    iteration: int = 0
+    epoch: int = 0
+
+
+class UpdaterBase():
+    """An updater is the abstraction of how a model is trained given the
+    dataloader and the optimizer.
+    The `update_core` method is a step in the training loop with only the
+    necessary operations (get a batch, forward and backward, update the
+    parameters); everything else is left to extensions. Visualization,
+    saving, loading and periodical validation and evaluation are not
+    considered here.
+    But even in such a simple case, things are not that simple. One tempting
+    approach is to standardize the whole process, requiring only the model
+    and the dataset and doing everything else automatically, but that hurts
+    flexibility. If we assume that a batch yielded by the dataloader is just
+    the positional input of the model, we find that some models require more
+    arguments, or keyword arguments, so that assumption over-simplifies
+    things. From another perspective, the batch may include not just the
+    input but also the target, while the model's forward method may need
+    only the input. We could pass a dict or a very long tuple to the model
+    and let it pick what it really needs, but that is an abuse of a lazy
+    interface.
+    After all, we care about how a model is trained, not only how it is used
+    for inference. We want to control how a model is trained, without getting
+    tangled up with auxiliary code. So the best practice is to define a model
+    and define an updater for it.
+    """
+
+    def __init__(self, init_state=None):
+        if init_state is None:
+            self.state = UpdaterState()
+        else:
+            self.state = init_state
+
+    def update(self, batch):
+        raise NotImplementedError(
+            "Implement your own `update` method for training a step.")
+
+    def state_dict(self):
+        state_dict = {
+            "epoch": self.state.epoch,
+            "iteration": self.state.iteration,
+        }
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        self.state.epoch = state_dict["epoch"]
+        self.state.iteration = state_dict["iteration"]
+
+    def save(self, path):
+        logger.debug(f"Saving to {path}.")
+        archive = self.state_dict()
+        paddle.save(archive, str(path))
+
+    def load(self, path):
+        logger.debug(f"Loading from {path}.")
+        archive = paddle.load(str(path))
+        self.set_state_dict(archive)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 08f2f258..1ed5525e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ tensorboardX
 textgrid
 typeguard
 yacs
+jsonlines
\ No newline at end of file