parent 787fa9d91f
commit cfdca210ff
@@ -0,0 +1,28 @@
from typing import Callable

from .extension import Extension


def make_extension(trigger: Callable=None,
                   default_name: str=None,
                   priority: int=None,
                   finalizer: Callable=None,
                   initializer: Callable=None,
                   on_error: Callable=None):
    """Make an Extension-like object by injecting the required attributes
    into it."""
    if trigger is None:
        trigger = Extension.trigger
    if priority is None:
        priority = Extension.priority

    def decorator(ext):
        ext.trigger = trigger
        ext.default_name = default_name or ext.__name__
        ext.priority = priority
        ext.finalize = finalizer
        ext.on_error = on_error
        ext.initialize = initializer
        return ext

    return decorator
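A hedged usage sketch of the decorator above: decorating a plain function injects the attributes a Trainer expects from an Extension-like object. The function body and trigger values here are illustrative, not part of this commit.

@make_extension(trigger=(100, 'iteration'), priority=100)
def print_progress(trainer):
    # the decorated callable itself becomes the extension's main action
    print(f"iteration: {trainer.updater.state.iteration}")

assert print_progress.default_name == "print_progress"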
@@ -0,0 +1,58 @@
from typing import Dict

import paddle
from paddle.io import DataLoader
from paddle.nn import Layer

from . import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope


class StandardEvaluator(extension.Extension):

    trigger = (1, 'epoch')
    default_name = 'validation'
    priority = extension.PRIORITY_WRITER

    name = None

    def __init__(self, model: Layer, dataloader: DataLoader):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # dataloaders
        self.dataloader = dataloader

    def evaluate_core(self, batch):
        # compute
        self.model(batch)  # you may report here

    def evaluate(self):
        # switch to eval mode
        for model in self.models.values():
            model.eval()

        # to average evaluation metrics
        summary = DictSummary()
        for batch in self.dataloader:
            observation = {}
            with scope(observation):
                # main evaluation computation here.
                with paddle.no_grad():
                    self.evaluate_core(batch)
            summary.add(observation)
        summary = summary.compute_mean()
        return summary

    def __call__(self, trainer=None):
        # evaluate and report the averaged metrics to the current observation.
        # if it is used to extend a trainer, the metrics are reported to the
        # trainer's observation; otherwise, you can use your own observation.
        summary = self.evaluate()
        for k, v in summary.items():
            report(k, v)
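A sketch of how this evaluator might be specialized; the (inputs, targets) batch layout and the reported key are assumptions for illustration.

class MyEvaluator(StandardEvaluator):
    def evaluate_core(self, batch):
        # assume the model returns a scalar loss for an (inputs, targets) pair
        inputs, targets = batch
        loss = self.model(inputs, targets)
        report("eval/loss", float(loss))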
@@ -0,0 +1,41 @@
from typing import Callable

PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100


class Extension():
    """Extension to customize the behavior of Trainer."""
    trigger = (1, 'iteration')
    priority = PRIORITY_READER
    name = None

    @property
    def default_name(self):
        """Default name of the extension, class name by default."""
        return type(self).__name__

    def __call__(self, trainer):
        """Main action of the extension. After each update, it is executed
        when the trigger fires."""
        raise NotImplementedError(
            'Extension implementation must override __call__.')

    def initialize(self, trainer):
        """Action that is executed once to get the correct trainer state.
        It is called before training normally, but if the trainer restores
        states with a Snapshot extension, this method should also be called.
        """
        pass

    def on_error(self, trainer, exc, tb):
        """Handles the error raised during training before finalization.
        """
        pass

    def finalize(self, trainer):
        """Action that is executed when training is done.
        For example, visualizers would need to be closed.
        """
        pass
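A minimal subclass sketch against the interface above; the timing logic is illustrative, not part of this commit.

import time

class StopWatch(Extension):
    trigger = (1, 'epoch')
    priority = PRIORITY_READER

    def initialize(self, trainer):
        # record the start time once, before training
        self._start = time.time()

    def __call__(self, trainer):
        print(f"elapsed: {time.time() - self._start:.1f}s")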
@@ -0,0 +1,102 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines

from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


def load_records(records_fp):
    """Load record files (json lines)."""
    with jsonlines.open(records_fp, 'r') as reader:
        records = list(reader)
    return records


class Snapshot(extension.Extension):
    """An extension to make snapshots of the updater object inside
    the trainer. It is done by calling the updater's `save` method.
    An Updater saves its state_dict by default, which contains the
    updater state (i.e. epoch and iteration) and all the model
    parameters and optimizer states. If the updater inside the trainer
    subclasses StandardUpdater, everything is good to go.

    Parameters
    ----------
    max_size : int
        The maximum number of snapshots to keep; -1 keeps all of them.
    snapshot_on_error : bool
        Whether to take a snapshot if an error occurs during training.
    """

    trigger = (1, 'epoch')
    priority = -100
    default_name = "snapshot"

    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
        self.records: List[Dict[str, Any]] = []
        self.max_size = max_size
        self._snapshot_on_error = snapshot_on_error
        self._save_all = (max_size == -1)
        self.checkpoint_dir = None

    def initialize(self, trainer: Trainer):
        """Set up this extension."""
        self.checkpoint_dir = trainer.out / "checkpoints"

        # load existing records
        record_path: Path = self.checkpoint_dir / "records.jsonl"
        if record_path.exists():
            logger.debug("Loading from an existing checkpoint dir")
            self.records = load_records(record_path)
            trainer.updater.load(self.records[-1]['path'])

    def on_error(self, trainer, exc, tb):
        if self._snapshot_on_error:
            self.save_checkpoint_and_update(trainer)

    def __call__(self, trainer: Trainer):
        self.save_checkpoint_and_update(trainer)

    def full(self):
        """Whether the number of snapshots it keeps track of is greater
        than the max_size."""
        return (not self._save_all) and len(self.records) > self.max_size

    @rank_zero_only
    def save_checkpoint_and_update(self, trainer: Trainer):
        """Save a new snapshot and remove the oldest snapshot if needed."""
        iteration = trainer.updater.state.iteration
        epoch = trainer.updater.state.epoch
        num = epoch if self.trigger[1] == 'epoch' else iteration
        path = self.checkpoint_dir / f"{num}.pdz"

        # add the new one
        trainer.updater.save(path)
        record = {
            "time": str(datetime.now()),
            'path': str(path.resolve()),  # use absolute path
            'iteration': iteration,
            'epoch': epoch,
        }
        self.records.append(record)

        # remove the earliest
        if self.full():
            earliest_record = self.records[0]
            os.remove(earliest_record["path"])
            self.records.pop(0)

        # update the record file
        record_path = self.checkpoint_dir / "records.jsonl"
        with jsonlines.open(record_path, 'w') as writer:
            for record in self.records:
                # jsonlines.open may return a Writer or a Reader
                writer.write(record)  # pylint: disable=no-member
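Given the record layout above, a line in records.jsonl would then look roughly like this (all values are made up for illustration):

{"time": "2021-06-01 12:34:56.789012", "path": "/exp/checkpoints/10.pdz", "iteration": 5000, "epoch": 10}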
@@ -0,0 +1,24 @@
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer


class VisualDL(extension.Extension):
    """A wrapper of the visualdl log writer. It assumes that the metrics to be
    visualized are all scalars which are recorded into the `.observation`
    dictionary of the trainer object. The dictionary is created for each step,
    thus the visualdl log writer uses the updater's `iteration` as the global
    step to add records.
    """
    trigger = (1, 'iteration')
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, writer):
        self.writer = writer

    def __call__(self, trainer: Trainer):
        for k, v in trainer.observation.items():
            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)

    def finalize(self, trainer):
        self.writer.close()
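A wiring sketch, assuming visualdl's `LogWriter` and a `trainer` built elsewhere; the logdir is arbitrary.

from visualdl import LogWriter

writer = LogWriter(logdir="exp/default/visualdl")
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))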
@@ -0,0 +1,131 @@
import contextlib
import math
from collections import defaultdict

OBSERVATIONS = None


@contextlib.contextmanager
def scope(observations):
    # make `observations` the target to report to.
    # it is basically a dictionary that stores temporary observations
    global OBSERVATIONS
    old = OBSERVATIONS
    OBSERVATIONS = observations

    try:
        yield
    finally:
        OBSERVATIONS = old


def get_observations():
    global OBSERVATIONS
    return OBSERVATIONS


def report(name, value):
    # a simple function to report a named value.
    # you can use it everywhere; it will get the default target and write to
    # it. you can think of it as stdout.
    observations = get_observations()
    if observations is None:
        return
    else:
        observations[name] = value


class Summary():
    """Online summarization of a sequence of scalars.
    Summary computes the statistics of given scalars online.
    """

    def __init__(self):
        self._x = 0.0
        self._x2 = 0.0
        self._n = 0

    def add(self, value, weight=1):
        """Adds a scalar value.

        Args:
            value: Scalar value to accumulate. It is either a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
            weight: An optional weight for the value. It is a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
                Default is 1 (integer).

        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Computes the mean."""
        x, n = self._x, self._n
        return x / n

    def make_statistics(self):
        """Computes and returns the mean and standard deviation values.

        Returns:
            tuple: Mean and standard deviation values.

        """
        x, n = self._x, self._n
        mean = x / n
        var = self._x2 / n - mean * mean
        std = math.sqrt(var)
        return mean, std


class DictSummary():
    """Online summarization of a sequence of dictionaries.

    ``DictSummary`` computes the statistics of a given set of scalars online.
    It only computes the statistics for scalar values and variables of scalar
    values in the dictionaries.

    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Adds a dictionary of scalars.

        Args:
            d (dict): Dictionary of scalars to accumulate. Only elements of
               scalars, zero-dimensional arrays, and variables of
               zero-dimensional arrays are accumulated. When the value
               is a tuple, the second element is interpreted as a weight.

        """
        summaries = self._summaries
        for k, v in d.items():
            w = 1
            if isinstance(v, tuple):
                # read the weight before unpacking the value
                w = v[1]
                v = v[0]
            summaries[k].add(v, weight=w)

    def compute_mean(self):
        """Creates a dictionary of mean values.

        It returns a single dictionary that holds a mean value for each entry
        added to the summary.

        Returns:
            dict: Dictionary of mean values.

        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Creates a dictionary of statistics.

        It returns a single dictionary that holds mean and standard deviation
        values for every entry added to the summary. For an entry of name
        ``'key'``, these values are added to the dictionary by names ``'key'``
        and ``'key.std'``, respectively.

        Returns:
            dict: Dictionary of statistics of all entries.

        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std

        return stats
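A self-contained sketch of the reporting workflow above, including the (value, weight) convention handled by `DictSummary.add`; the numbers are arbitrary.

summary = DictSummary()
for value, batch_size in [(0.5, 4), (0.3, 2)]:
    observation = {}
    with scope(observation):
        report("loss", (value, batch_size))  # weighted by batch size
    summary.add(observation)
print(summary.compute_mean())  # {'loss': (0.5 * 4 + 0.3 * 2) / 6}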
@@ -0,0 +1,13 @@
from .interval_trigger import IntervalTrigger


def never_fail_trigger(trainer):
    return False


def get_trigger(trigger):
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
@@ -0,0 +1,24 @@

class IntervalTrigger():
    """A predicate to do something every N cycles."""

    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        self.last_index = None

    def __call__(self, trainer):
        if self.last_index is None:
            last_index = getattr(trainer.updater.state, self.unit)
            self.last_index = last_index

        last_index = self.last_index
        index = getattr(trainer.updater.state, self.unit)
        fire = index // self.period != last_index // self.period

        self.last_index = index
        return fire
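A small worked example of the firing rule, using a stand-in trainer object (not part of this commit): the trigger fires whenever the index crosses a multiple of the period.

from types import SimpleNamespace

def fake_trainer(iteration):
    state = SimpleNamespace(iteration=iteration, epoch=0)
    return SimpleNamespace(updater=SimpleNamespace(state=state))

trigger = IntervalTrigger(3, 'iteration')
fired = [trigger(fake_trainer(i)) for i in range(1, 7)]
print(fired)  # [False, False, True, False, False, True]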
@@ -0,0 +1,17 @@

class LimitTrigger():
    """A predicate to decide whether to stop."""

    def __init__(self, limit: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if limit <= 0:
            raise ValueError("limit should be a positive integer.")
        self.limit = limit
        self.unit = unit

    def __call__(self, trainer):
        state = trainer.updater.state
        index = getattr(state, self.unit)
        fire = index >= self.limit
        return fire
@@ -0,0 +1,17 @@
class TimeTrigger():
    """Trigger based on a fixed time interval.

    This trigger accepts iterations with a given interval time.

    Args:
        period (float): Interval time. It is given in seconds.

    """

    def __init__(self, period):
        self._period = period
        self._next_time = self._period

    def __call__(self, trainer):
        if self._next_time < trainer.elapsed_time:
            self._next_time += self._period
            return True
        else:
            return False
@@ -0,0 +1,179 @@
from typing import Dict
from typing import Optional

from paddle import Tensor
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer

from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState

from deepspeech.utils.log import Log

__all__ = ["StandardUpdater"]

logger = Log(__name__).getlog()


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your need.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState]=None):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # it is designed to hold multiple optimizers
        optimizers = {"main": optimizer}
        self.optimizer = optimizer
        self.optimizers: Dict[str, Optimizer] = optimizers

        # dataloaders
        self.dataloader = dataloader

        # init state
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

        self.train_iterator = iter(dataloader)

    def update(self):
        # We increase the iteration index after updating and before the
        # extensions run. Here are the reasons.

        # 1. Snapshotting (as well as other extensions, like the visualizer)
        #    is executed after a step of updating.
        # 2. We decide to increase the iteration index after updating and
        #    before any extension is executed.
        # 3. We do not increase the iteration after the extensions because we
        #    prefer a consistent resume behavior: when loading from a
        #    `snapshot_iter_100.pdz`, the next step to train is `101`,
        #    naturally. But if the iteration were increased after the
        #    extensions (including snapshotting), then a `snapshot_iter_99`
        #    would be loaded. You would need an extra increase of the
        #    iteration index before training to avoid another iteration `99`,
        #    which had been done before snapshotting.
        # 4. Thus the iteration index represents "how many updates have been
        #    done so far."
        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate,
        # so if you want to report it, do it in the updater.

        # Then here comes the next question: when is the proper time to
        # increase the epoch index? Since all extensions are executed after
        # updating, after updating is the proper time to increase the epoch
        # index.
        # 1. If we increased the epoch index before updating, then an
        #    extension triggered by epoch would miss the correct timing. It
        #    could only be triggered after an extra update.
        # 2. Theoretically, when an epoch is done, the epoch index should be
        #    increased, so it is increased after updating.
        # 3. Thus, the epoch index represents "how many epochs have been done
        #    so far." So it starts from 0.

        # switch to training mode
        for model in self.models.values():
            model.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        if self.updates_per_epoch is not None:
            if self.state.iteration % self.updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses.
        Parameters are updated at every batch, no gradient accumulation.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]
            loss_dict = loss
            if "main" not in loss_dict:
                main_loss = 0
                for loss_item in loss.values():
                    main_loss += loss_item
                loss_dict["main"] = main_loss

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, determined by the length of the
        dataloader."""
        length_of_dataloader = None
        try:
            length_of_dataloader = len(self.dataloader)
        except TypeError:
            logger.debug("This dataloader has no __len__.")
        return length_of_dataloader

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch samplers for distributed training should
        # subclass DistributedBatchSampler and implement the `set_epoch` method
        if hasattr(self.dataloader, "batch_sampler"):
            batch_sampler = self.dataloader.batch_sampler
            if isinstance(batch_sampler, DistributedBatchSampler):
                batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader; auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                self.new_epoch()
                batch = next(self.train_iterator)
            logger.debug(f"Read a batch takes {t.elapse}s.")
        return batch

    def state_dict(self):
        """State dict of an Updater; model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, model in self.models.items():
            state_dict[f"{name}_params"] = model.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for an Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, model in self.models.items():
            model.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)
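A sketch of the intended customization point: overriding `update_core` for a batch that is an (inputs, targets) pair. The loss computation and the reported key are illustrative assumptions.

class MyUpdater(StandardUpdater):
    def update_core(self, batch):
        inputs, targets = batch
        loss = self.model(inputs, targets)  # assumed to return a scalar Tensor
        report("train/loss", float(loss))

        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()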
@@ -0,0 +1,171 @@
import sys
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Callable
from typing import List
from typing import Union

import six
import tqdm

from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope
from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase


class _ExtensionEntry():
    def __init__(self, extension, trigger, priority):
        self.extension = extension
        self.trigger = trigger
        self.priority = priority


class Trainer():
    def __init__(self,
                 updater: UpdaterBase,
                 stop_trigger: Callable=None,
                 out: Union[str, Path]='result',
                 extensions: List[Extension]=None):
        self.updater = updater
        self.extensions = OrderedDict()
        self.stop_trigger = LimitTrigger(*stop_trigger)
        self.out = Path(out)
        self.observation = None

        self._done = False
        if extensions:
            for ext in extensions:
                self.extend(ext)

    @property
    def is_before_training(self):
        return self.updater.state.iteration == 0

    def extend(self, extension, name=None, trigger=None, priority=None):
        # get the name for the extension:
        # the `name` argument
        # -> the extension's `name` attribute
        # -> its `default_name` (class name, when it is an object)
        # -> its function name, when it is a function
        # -> error

        if name is None:
            name = getattr(extension, 'name', None)
        if name is None:
            name = getattr(extension, 'default_name', None)
        if name is None:
            name = getattr(extension, '__name__', None)
        if name is None:
            raise ValueError("Name is not given for the extension.")
        if name == 'training':
            raise ValueError("training is a reserved name.")

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority', PRIORITY_READER)

        # add a suffix to avoid naming conflicts
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extension.name = modified_name

        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
                                                         priority)

    def get_extension(self, name):
        """Get extension by name."""
        extensions = self.extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError(f'extension {name} not found')

    def run(self):
        if self._done:
            raise RuntimeError("Training is already done!")

        self.out.mkdir(parents=True, exist_ok=True)

        # sort extensions by priorities once
        extension_order = sorted(
            self.extensions.keys(),
            key=lambda name: self.extensions[name].priority,
            reverse=True)
        extensions = [(name, self.extensions[name])
                      for name in extension_order]

        # initializing all extensions
        for name, entry in extensions:
            if hasattr(entry.extension, "initialize"):
                entry.extension.initialize(self)

        update = self.updater.update  # training step
        stop_trigger = self.stop_trigger

        # display only one progress bar
        max_iteration = None
        if isinstance(stop_trigger, LimitTrigger):
            if stop_trigger.unit == 'epoch':
                max_epoch = self.stop_trigger.limit
                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
                                            None)
                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
            else:
                max_iteration = self.stop_trigger.limit

        p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)

        try:
            while not stop_trigger(self):
                self.observation = {}
                # set observation as the report target
                # you can use report freely in Updater.update()

                # updating parameters and state
                with scope(self.observation):
                    update()
                    p.update()

                # execute extensions when necessary
                for name, entry in extensions:
                    if entry.trigger(self):
                        entry.extension(self)

                # print("###", self.observation)
        except Exception as e:
            f = sys.stderr
            f.write(f"Exception in main training loop: {e}\n")
            f.write("Traceback (most recent call last):\n")
            traceback.print_tb(sys.exc_info()[2])
            f.write(
                "Trainer extensions will try to handle the exception. Then all extensions will finalize."
            )

            # capture the exception in the main training loop
            exc_info = sys.exc_info()

            # try to handle it
            for name, entry in extensions:
                if hasattr(entry.extension, "on_error"):
                    try:
                        entry.extension.on_error(self, e, sys.exc_info()[2])
                    except Exception as ee:
                        f.write(f"Exception in error handler: {ee}\n")
                        f.write('Traceback (most recent call last):\n')
                        traceback.print_tb(sys.exc_info()[2])

            # re-raise the exception from the main training loop
            six.reraise(*exc_info)
        finally:
            for name, entry in extensions:
                if hasattr(entry.extension, "finalize"):
                    entry.extension.finalize(self)
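How the pieces might be assembled end to end; `model`, `optimizer`, and `dataloader` are placeholders built elsewhere, and the output directory is arbitrary.

updater = StandardUpdater(model, optimizer, dataloader)
trainer = Trainer(updater, stop_trigger=(10, 'epoch'), out='exp/default')
trainer.extend(Snapshot(max_size=5), trigger=(1, 'epoch'))
trainer.run()

Note that despite the `Callable` annotation, `stop_trigger` is unpacked into a `LimitTrigger`, so a `(limit, unit)` tuple is what the constructor actually accepts.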
@@ -0,0 +1,82 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

import paddle

from deepspeech.utils.log import Log

__all__ = ["UpdaterBase", "UpdaterState"]

logger = Log(__name__).getlog()


@dataclass
class UpdaterState:
    iteration: int = 0
    epoch: int = 0


class UpdaterBase():
    """An updater is the abstraction of how a model is trained given the
    dataloader and the optimizer.

    The `update_core` method is a step in the training loop with only the
    necessary operations (get a batch, forward and backward, update the
    parameters). Other stuff is implemented as extensions. Visualization,
    saving, loading and periodical validation and evaluation are not
    considered here.

    But even in such a simple case, things are not that simple. There is an
    attempt to standardize this process that requires only the model and the
    dataset and does all the stuff automatically. But this may hurt
    flexibility.

    If we assume a batch yielded from the dataloader is just the input to the
    model, we will find that some models require more arguments, or just some
    keyword arguments. This prevents us from over-simplifying it.

    From another perspective, the batch may include not just the input, but
    also the target. But the model's forward method may just need the input.
    We can pass a dict or a super-long tuple to the model and let it pick what
    it really needs. But this is an abuse of a lazy interface.

    After all, we care about how a model is trained, not just how the model is
    used for inference. We want to control how a model is trained. We just
    don't want to be messed up with other auxiliary code.

    So the best practice is to define a model and define an updater for it.
    """

    def __init__(self, init_state=None):
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

    def update(self, batch):
        raise NotImplementedError(
            "Implement your own `update` method for training a step.")

    def state_dict(self):
        state_dict = {
            "epoch": self.state.epoch,
            "iteration": self.state.iteration,
        }
        return state_dict

    def set_state_dict(self, state_dict):
        self.state.epoch = state_dict["epoch"]
        self.state.iteration = state_dict["iteration"]

    def save(self, path):
        logger.debug(f"Saving to {path}.")
        archive = self.state_dict()
        paddle.save(archive, str(path))

    def load(self, path):
        logger.debug(f"Loading from {path}.")
        archive = paddle.load(str(path))
        self.set_state_dict(archive)
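A minimal roundtrip sketch of the state protocol above; the checkpoint path is arbitrary.

updater = UpdaterBase()
updater.state.iteration = 100
updater.save("checkpoint.pdz")

restored = UpdaterBase()
restored.load("checkpoint.pdz")
assert restored.state.iteration == 100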