chainer style updater

pull/787/head
Hui Zhang 3 years ago
parent 787fa9d91f
commit cfdca210ff

@ -0,0 +1,28 @@
from typing import Callable
from .extension import Extension
def make_extension(trigger: Callable=None,
                   default_name: str=None,
                   priority: int=None,
                   finalizer: Callable=None,
                   initializer: Callable=None,
                   on_error: Callable=None):
    """Make an Extension-like object by injecting required attributes to it.

    Args:
        trigger: Value stored as ``ext.trigger``. Defaults to
            ``Extension.trigger``.
        default_name: Value stored as ``ext.default_name``. Falls back to
            ``ext.__name__`` when it is not given (or empty).
        priority: Value stored as ``ext.priority``. Defaults to
            ``Extension.priority``.
        finalizer: Optional callable stored as ``ext.finalize``.
        initializer: Optional callable stored as ``ext.initialize``.
        on_error: Optional callable stored as ``ext.on_error``.

    Returns:
        A decorator that sets the attributes on the object it receives and
        returns that same object.
    """
    if trigger is None:
        trigger = Extension.trigger
    if priority is None:
        priority = Extension.priority

    # Optional hooks are attached only when they were supplied. Storing
    # ``None`` would make ``hasattr(ext, "finalize")`` true, and a caller
    # that probes hooks with ``hasattr`` would then try to call ``None``.
    hooks = {
        "finalize": finalizer,
        "on_error": on_error,
        "initialize": initializer,
    }

    def decorator(ext):
        ext.trigger = trigger
        ext.default_name = default_name or ext.__name__
        ext.priority = priority
        for attr_name, hook in hooks.items():
            if hook is not None:
                setattr(ext, attr_name, hook)
        return ext

    return decorator

@ -0,0 +1,58 @@
from typing import Dict
import paddle
from paddle.io import DataLoader
from paddle.nn import Layer
import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope
class StandardEvaluator(extension.Extension):
    """Extension that runs a model over a dataloader and reports the
    per-key mean of everything reported during that pass.
    """

    trigger = (1, 'epoch')
    default_name = 'validation'
    priority = extension.PRIORITY_WRITER

    name = None

    def __init__(self, model: Layer, dataloader: DataLoader):
        # dataloaders
        self.dataloader = dataloader

        # a name -> model mapping is kept so that several models can be held
        self.model = model
        self.models: Dict[str, Layer] = {"main": model}

    def evaluate_core(self, batch):
        """Process a single batch; values may be reported from in here."""
        self.model(batch)

    def evaluate(self):
        """Run over the whole dataloader and return a dict of mean values."""
        # switch to eval mode
        for layer in self.models.values():
            layer.eval()

        # collects one observation per batch so they can be averaged
        averager = DictSummary()
        for batch in self.dataloader:
            batch_observation = {}
            # reports made while evaluating this batch land in
            # `batch_observation`; gradients are disabled meanwhile
            with scope(batch_observation), paddle.no_grad():
                self.evaluate_core(batch)
            averager.add(batch_observation)
        return averager.compute_mean()

    def __call__(self, trainer=None):
        """Evaluate, then report every averaged value to the observation
        that is the current report target (the trainer's one when this
        runs as a trainer extension, or your own otherwise).
        """
        for key, value in self.evaluate().items():
            report(key, value)

@ -0,0 +1,41 @@
from typing import Callable
# Priority levels for extensions. The trainer sorts extensions by priority in
# descending order, so a larger value means the extension is invoked earlier.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100
class Extension():
    """Base class of extensions that customize the behavior of Trainer."""

    # class-level defaults; subclasses override them as needed
    trigger = (1, 'iteration')
    priority = PRIORITY_READER
    name = None

    @property
    def default_name(self):
        """Default name of the extension: the name of its class."""
        return self.__class__.__name__

    def __call__(self, trainer):
        """Main action of the extension. It is executed after an update
        whenever the trigger fires. Subclasses must override it.
        """
        raise NotImplementedError(
            'Extension implementation must override __call__.')

    def initialize(self, trainer):
        """Action that is executed once to get the correct trainer state.

        It is normally called before training, but it should also be called
        when the trainer restores its state through a Snapshot extension.
        The default implementation does nothing.
        """

    def on_error(self, trainer, exc, tb):
        """Handle an error raised during training, before finalization.
        The default implementation does nothing.
        """

    def finalize(self, trainer):
        """Action that is executed when training is done, e.g. closing a
        visualizer. The default implementation does nothing.
        """

@ -0,0 +1,102 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only
from deepspeech.utils.log import Log
# Module-level logger obtained through the project's Log helper.
logger = Log(__name__).getlog()
def load_records(records_fp):
    """Read a JSON-lines record file and return its entries as a list."""
    with jsonlines.open(records_fp, 'r') as reader:
        return list(reader)
class Snapshot(extension.Extension):
    """An extension to make snapshot of the updater object inside
    the trainer. It is done by calling the updater's `save` method.

    An Updater saves its state_dict by default, which contains the
    updater state (i.e. epoch and iteration) and all the model
    parameters and optimizer states. If the updater inside the trainer
    subclasses StandardUpdater, everything is good to go.

    Checkpoints are written to ``trainer.out / "checkpoints"`` and tracked
    in a ``records.jsonl`` file inside that directory.

    Parameters
    ----------
    max_size : int
        Maximum number of snapshots to keep. When it is exceeded the oldest
        snapshot is removed. ``-1`` keeps every snapshot.
    snapshot_on_error : bool
        Whether a snapshot is saved when the trainer reports an error.
    """

    trigger = (1, 'epoch')
    priority = -100
    default_name = "snapshot"

    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
        self.records: List[Dict[str, Any]] = []
        self.max_size = max_size
        self._snapshot_on_error = snapshot_on_error
        self._save_all = (max_size == -1)
        # set in `initialize`, once the trainer's output directory is known
        self.checkpoint_dir = None

    def initialize(self, trainer: Trainer):
        """Set up this extension and resume from the latest record, if any."""
        self.checkpoint_dir = trainer.out / "checkpoints"

        # load existing records
        record_path: Path = self.checkpoint_dir / "records.jsonl"
        if record_path.exists():
            logger.debug("Loading from an existing checkpoint dir")
            self.records = load_records(record_path)
            # an empty record file means there is nothing to resume from
            if self.records:
                trainer.updater.load(self.records[-1]['path'])

    def on_error(self, trainer, exc, tb):
        """Save a snapshot on error when `snapshot_on_error` was requested."""
        if self._snapshot_on_error:
            self.save_checkpoint_and_update(trainer)

    def __call__(self, trainer: Trainer):
        """Save a snapshot of the trainer's updater."""
        self.save_checkpoint_and_update(trainer)

    def full(self):
        """Whether the number of snapshots it keeps track of is greater
        than the max_size."""
        return (not self._save_all) and len(self.records) > self.max_size

    @rank_zero_only
    def save_checkpoint_and_update(self, trainer: Trainer):
        """Save a new snapshot and remove the oldest snapshot if needed."""
        iteration = trainer.updater.state.iteration
        epoch = trainer.updater.state.epoch
        # compare by value: identity comparison against a str literal is
        # implementation dependent
        num = epoch if self.trigger[1] == 'epoch' else iteration
        path = self.checkpoint_dir / f"{num}.pdz"

        # do not rely on the updater to create the directory
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # add the new one
        trainer.updater.save(path)
        self.records.append({
            "time": str(datetime.now()),
            'path': str(path.resolve()),  # use absolute path
            'iteration': iteration,
            'epoch': epoch,
        })

        # remove the earliest
        if self.full():
            earliest_record = self.records.pop(0)
            try:
                os.remove(earliest_record["path"])
            except FileNotFoundError:
                # the file is already gone; dropping the record is enough
                logger.debug(
                    f"Snapshot {earliest_record['path']} is already removed.")

        # update the record file
        record_path = self.checkpoint_dir / "records.jsonl"
        with jsonlines.open(record_path, 'w') as writer:
            for kept_record in self.records:
                # jsonlines.open may return a Writer or a Reader
                writer.write(kept_record)  # pylint: disable=no-member

@ -0,0 +1,24 @@
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
class VisualDL(extension.Extension):
    """A wrapper of a visualdl log writer.

    Every entry of the trainer's `.observation` dictionary is passed to the
    writer's `add_scalar`, so the metrics are assumed to be scalars. The
    dictionary is created for each step, thus the updater's `iteration` is
    used as the global step of the records.
    """

    trigger = (1, 'iteration')
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, writer):
        self.writer = writer

    def __call__(self, trainer: Trainer):
        """Write every observed value at the current iteration."""
        global_step = trainer.updater.state.iteration
        for tag, scalar in trainer.observation.items():
            self.writer.add_scalar(tag, scalar, step=global_step)

    def finalize(self, trainer):
        """Close the underlying writer."""
        self.writer.close()

@ -0,0 +1,131 @@
import contextlib
import math
from collections import defaultdict
# The dictionary that reported values are currently written to.
# None means there is no active target.
OBSERVATIONS = None
@contextlib.contextmanager
def scope(observations):
    """Make `observations` the report target for the duration of the block.

    `observations` is basically a dictionary that stores temporary
    observations. The previous target is restored on exit, even when the
    block raises.
    """
    global OBSERVATIONS
    previous, OBSERVATIONS = OBSERVATIONS, observations
    try:
        yield
    finally:
        OBSERVATIONS = previous
def get_observations():
    """Return the module-level ``OBSERVATIONS`` object, i.e. the current
    report target."""
    return OBSERVATIONS
def report(name, value):
    """Report a named value to the current target.

    It can be used everywhere, much like writing to a standard output: the
    value is stored under `name` in the current observation dictionary.
    When there is no current target the call does nothing.
    """
    target = get_observations()
    if target is not None:
        target[name] = value
class Summary():
    """Online summarization of a sequence of scalars.

    Summary computes the statistics of given scalars online: it keeps the
    weighted sum, the weighted sum of squares and the total weight.
    """

    def __init__(self):
        self._x = 0.0
        self._x2 = 0.0
        self._n = 0

    def add(self, value, weight=1):
        """Adds a scalar value.

        Args:
            value: Scalar value to accumulate.
            weight: An optional weight for the value.
                Default is 1 (integer).
        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Computes the mean.

        Raises:
            ZeroDivisionError: If the accumulated weight is zero, e.g. when
                nothing has been added yet.
        """
        x, n = self._x, self._n
        return x / n

    def make_statistics(self):
        """Computes and returns the mean and standard deviation values.

        Returns:
            tuple: Mean and standard deviation values.

        Raises:
            ZeroDivisionError: If the accumulated weight is zero.
        """
        x, n = self._x, self._n
        mean = x / n
        var = self._x2 / n - mean * mean
        # Rounding can push the variance of (nearly) constant data slightly
        # below zero, which would make math.sqrt raise ValueError.
        std = math.sqrt(max(var, 0.0))
        return mean, std
class DictSummary():
    """Online summarization of a sequence of dictionaries.

    ``DictSummary`` keeps one ``Summary`` per key and computes the
    statistics of the values added under that key online.
    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Adds a dictionary of scalars.

        Args:
            d (dict): Dictionary of scalars to accumulate. When a value is
                a tuple, its first element is the scalar and its second
                element is interpreted as a weight. Otherwise the weight
                is 1.
        """
        summaries = self._summaries
        for k, v in d.items():
            w = 1
            if isinstance(v, tuple):
                # read both elements before rebinding `v`, otherwise the
                # weight would be looked up on the unpacked scalar
                v, w = v[0], v[1]
            summaries[k].add(v, weight=w)

    def compute_mean(self):
        """Creates a dictionary of mean values.

        It returns a single dictionary that holds a mean value for each entry
        added to the summary.

        Returns:
            dict: Dictionary of mean values.
        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Creates a dictionary of statistics.

        It returns a single dictionary that holds mean and standard deviation
        values for every entry added to the summary. For an entry of name
        ``'key'``, these values are added to the dictionary by names ``'key'``
        and ``'key.std'``, respectively.

        Returns:
            dict: Dictionary of statistics of all entries.
        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std
        return stats

@ -0,0 +1,13 @@
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
    """A trigger that returns False for every trainer, i.e. never fires."""
    return False
def get_trigger(trigger):
    """Turn a trigger specification into a callable trigger.

    ``None`` gives ``never_fail_trigger``, a callable is returned unchanged,
    and anything else is unpacked as the arguments of ``IntervalTrigger``.
    """
    if trigger is None:
        return never_fail_trigger
    if not callable(trigger):
        return IntervalTrigger(*trigger)
    return trigger

@ -0,0 +1,24 @@
class IntervalTrigger():
    """A predicate that fires whenever the chosen counter of the updater
    state moves into a new multiple of `period`.
    """

    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        # counter value seen at the previous call; None before the first one
        self.last_index = None

    def __call__(self, trainer):
        """Return whether the counter crossed a period boundary since the
        previous call. The first call only records the counter and returns
        False.
        """
        current = getattr(trainer.updater.state, self.unit)
        previous = current if self.last_index is None else self.last_index
        self.last_index = current
        return current // self.period != previous // self.period

@ -0,0 +1,17 @@
class LimitTrigger():
    """A predicate to decide whether to stop: it fires once the chosen
    counter of the updater state reaches `limit`.
    """

    def __init__(self, limit: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if limit <= 0:
            raise ValueError("limit should be a positive integer.")
        self.limit = limit
        self.unit = unit

    def __call__(self, trainer):
        """Return whether the counter is greater than or equal to the limit."""
        return getattr(trainer.updater.state, self.unit) >= self.limit

@ -0,0 +1,17 @@
class TimeTrigger():
    """Trigger based on a fixed time interval.

    It fires when the trainer's elapsed time has passed the next scheduled
    time, then schedules the following one `period` later.

    Args:
        period (float): Interval time. It is given in seconds.
    """

    def __init__(self, period):
        self._period = period
        # the first firing is scheduled one period after the start
        self._next_time = self._period

    def __call__(self, trainer):
        if not (self._next_time < trainer.elapsed_time):
            return False
        self._next_time += self._period
        return True

@ -0,0 +1,179 @@
from typing import Dict
from typing import Optional
from paddle import Tensor
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer
from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log
# Names exported by a star-import of this module.
__all__ = ["StandardUpdater"]
# Module-level logger obtained through the project's Log helper.
logger = Log(__name__).getlog()
class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your need.

    Args:
        model: The model to train. It is also kept under the key ``"main"``
            of ``self.models``.
        optimizer: The optimizer that updates the model. It is also kept
            under the key ``"main"`` of ``self.optimizers``.
        dataloader: The source of training batches.
        init_state: Optional initial updater state. A fresh
            ``UpdaterState`` is used when it is not given.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState]=None):
        # init state (a fresh UpdaterState when `init_state` is None)
        super().__init__(init_state)

        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # it is designed to hold multiple optimizers
        optimizers = {"main": optimizer}
        self.optimizer = optimizer
        self.optimizers: Dict[str, Optimizer] = optimizers

        # dataloaders
        self.dataloader = dataloader

        self.train_iterator = iter(dataloader)

    def update(self):
        """Train for one step and advance the iteration/epoch counters."""
        # We increase the iteration index after updating and before extension.
        # Here are the reasons.

        # 1. Snapshotting (as well as other extensions, like visualizer) is
        #    executed after a step of updating;
        # 2. We decide to increase the iteration index after updating and
        #    before any extension is executed.
        # 3. We do not increase the iteration after extension because we
        #    prefer a consistent resume behavior: when loading from a
        #    `snapshot_iter_100.pdz`, the next step to train is `101`,
        #    naturally. But if the iteration is increased after the
        #    extensions (including snapshot), then a `snapshot_iter_99` is
        #    loaded. You would need an extra increment of the iteration
        #    index before training to avoid another iteration `99`, which
        #    has been done before snapshotting.
        # 4. Thus the iteration index represents "how many iterations have
        #    been done so far".

        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate.
        # So if you want to report it, do it in the updater.

        # Then here comes the next question. When is the proper time to
        # increase the epoch index? Since all extensions are executed after
        # updating, the time right after updating is the proper time to
        # increase the epoch index.
        # 1. If we increase the epoch index before updating, then an extension
        #    based on epoch would miss the correct timing. It could only be
        #    triggered after an extra update.
        # 2. Theoretically, when an epoch is done, the epoch index should be
        #    increased. So it is increased after updating.
        # 3. Thus, the epoch index represents "how many epochs have been done
        #    so far". So it starts from 0.

        # switch to training mode
        for model in self.models.values():
            model.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        # look the length up once; the property queries the dataloader
        updates_per_epoch = self.updates_per_epoch
        if updates_per_epoch is not None:
            if self.state.iteration % updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses.
        Parameters update at every batch, no gradient accumulation.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]; work on a copy so that the dictionary
            # returned by the model is not modified
            loss_dict = dict(loss)
            if "main" not in loss_dict:
                # the main loss defaults to the sum of all the losses
                main_loss = 0
                for loss_item in loss.values():
                    main_loss += loss_item
                loss_dict["main"] = main_loss

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        # paddle optimizers expose `clear_grad` and `step`
        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, determined by the length of the
        dataloader. It is None when the dataloader has no length."""
        try:
            return len(self.dataloader)
        except (TypeError, ValueError):
            # NOTE(review): TypeError is what `len` raises for an object
            # without `__len__`; ValueError is accepted too on the
            # assumption that some loaders signal "no length" that way.
            logger.debug("This dataloader has no __len__.")
            return None

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch sampler for distributed training should
        # subclass DistributedBatchSampler and implement `set_epoch` method
        if hasattr(self.dataloader, "batch_sampler"):
            batch_sampler = self.dataloader.batch_sampler
            if isinstance(batch_sampler, DistributedBatchSampler):
                batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader, auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                # the iterator is exhausted: start over and read again
                self.new_epoch()
                batch = next(self.train_iterator)
            logger.debug(f"Read a batch takes {t.elapse}s.")
        return batch

    def state_dict(self):
        """State dict of a Updater, model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, model in self.models.items():
            state_dict[f"{name}_params"] = model.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for a Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, model in self.models.items():
            model.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)

@ -0,0 +1,171 @@
import sys
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Callable
from typing import List
from typing import Union

import six
import tqdm

from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope
from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase
class _ExtensionEntry():
    """Record that groups an extension with the trigger and the priority it
    was registered with."""

    def __init__(self, extension, trigger, priority):
        self.priority = priority
        self.trigger = trigger
        self.extension = extension
class Trainer():
    """Run the training loop: call the updater step by step and execute
    the registered extensions whenever their triggers fire.

    Args:
        updater: The updater whose ``update`` method performs a training
            step and whose ``state`` holds the iteration/epoch counters.
        stop_trigger: When to stop. Either a callable that takes the
            trainer and returns True when training should stop, or the
            ``(limit, unit)`` arguments of ``LimitTrigger``. With ``None``
            there is no stop condition and the loop only ends on an error.
        out: Output directory. It is created when ``run`` is called.
        extensions: Extensions to register with their own default
            name, trigger and priority.
    """

    def __init__(self,
                 updater: UpdaterBase,
                 stop_trigger: Callable=None,
                 out: Union[str, Path]='result',
                 extensions: List[Extension]=None):
        self.updater = updater
        self.extensions = OrderedDict()
        if stop_trigger is None or callable(stop_trigger):
            # get_trigger returns a callable as it is, and a trigger that
            # never fires for None
            self.stop_trigger = get_trigger(stop_trigger)
        else:
            self.stop_trigger = LimitTrigger(*stop_trigger)
        self.out = Path(out)
        self.observation = None

        self._done = False
        # bookkeeping for `elapsed_time`
        self._start_at = None
        self._final_elapsed_time = 0.0
        if extensions:
            for ext in extensions:
                self.extend(ext)

    @property
    def is_before_training(self):
        """Whether no update has been done yet (iteration is 0)."""
        return self.updater.state.iteration == 0

    @property
    def elapsed_time(self):
        """Seconds spent in the training loop of `run`.

        It is 0.0 before `run` starts the loop and stays at its final value
        once `run` has returned.
        """
        if self._start_at is None:
            return self._final_elapsed_time
        return time.perf_counter() - self._start_at

    @staticmethod
    def _get_hook(extension, hook_name):
        """Return the named hook of an extension, or None when the
        extension has no such callable attribute."""
        hook = getattr(extension, hook_name, None)
        return hook if callable(hook) else None

    def extend(self, extension, name=None, trigger=None, priority=None):
        """Register an extension.

        Args:
            extension: The extension, a callable taking the trainer.
            name: Name to register it with. When omitted, the extension's
                ``name``, ``default_name`` and ``__name__`` attributes are
                tried in this order.
            trigger: Trigger specification passed to ``get_trigger``.
                Defaults to the extension's ``trigger`` attribute, or
                ``(1, 'iteration')``.
            priority: Defaults to the extension's ``priority`` attribute,
                or ``PRIORITY_READER``.

        Raises:
            ValueError: If no name can be found, or the name is
                ``'training'``, which is reserved.
        """
        # get name for the extension
        # argument \
        # -> extension's name \
        # -> default_name (class name, when it is an object) \
        # -> function name when it is a function \
        # -> error
        if name is None:
            name = getattr(extension, 'name', None)
            if name is None:
                name = getattr(extension, 'default_name', None)
                if name is None:
                    name = getattr(extension, '__name__', None)
                    if name is None:
                        raise ValueError(
                            "Name is not given for the extension.")
        if name == 'training':
            raise ValueError("training is a reserved name.")

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority', PRIORITY_READER)

        # add suffix to avoid naming conflict
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extension.name = modified_name

        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
                                                         priority)

    def get_extension(self, name):
        """get extension by name."""
        extensions = self.extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError(f'extension {name} not found')

    def run(self):
        """Run the training loop until the stop trigger fires.

        Extensions are initialized first, executed after each update when
        their trigger fires, notified through ``on_error`` when the loop
        raises, and finalized in any case.

        Raises:
            RuntimeError: If ``run`` has been called on this trainer before.
        """
        if self._done:
            raise RuntimeError("Training is already done!.")

        self.out.mkdir(parents=True, exist_ok=True)

        # sort extensions by priorities once
        extension_order = sorted(
            self.extensions.keys(),
            key=lambda name: self.extensions[name].priority,
            reverse=True)
        extensions = [(name, self.extensions[name]) for name in extension_order]

        # initializing all extensions
        for name, entry in extensions:
            initialize = self._get_hook(entry.extension, "initialize")
            if initialize is not None:
                initialize(self)

        update = self.updater.update  # training step
        stop_trigger = self.stop_trigger

        # display only one progress bar
        max_iteration = None
        if isinstance(stop_trigger, LimitTrigger):
            if stop_trigger.unit == 'epoch':
                max_epoch = self.stop_trigger.limit
                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
                                            None)
                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
            else:
                max_iteration = self.stop_trigger.limit

        p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
        self._start_at = time.perf_counter()

        try:
            while not stop_trigger(self):
                self.observation = {}
                # set observation as the report target
                # you can use report freely in Updater.update()
                with scope(self.observation):
                    # updating parameters and state
                    update()
                    p.update()

                    # execute extension when necessary; they run inside the
                    # scope so that what they report (e.g. evaluation
                    # metrics) lands in the trainer's observation
                    for name, entry in extensions:
                        if entry.trigger(self):
                            entry.extension(self)
        except Exception as e:
            f = sys.stderr
            f.write(f"Exception in main training loop: {e}\n")
            f.write("Traceback (most recent call last):\n")
            # capture the exception in the main training loop
            exc_info = sys.exc_info()
            traceback.print_tb(exc_info[2])
            f.write(
                "Trainer extensions will try to handle the exception. Then all extensions will finalize.\n"
            )

            # try to handle it
            for name, entry in extensions:
                handler = self._get_hook(entry.extension, "on_error")
                if handler is None:
                    continue
                try:
                    handler(self, e, exc_info[2])
                except Exception as ee:
                    f.write(f"Exception in error handler: {ee}\n")
                    f.write('Traceback (most recent call last):\n')
                    traceback.print_tb(sys.exc_info()[2])

            # raise exception in main training loop
            raise
        finally:
            try:
                for name, entry in extensions:
                    finalize = self._get_hook(entry.extension, "finalize")
                    if finalize is not None:
                        finalize(self)
            finally:
                # release the progress bar and freeze the bookkeeping even
                # when a finalizer raises
                p.close()
                self._final_elapsed_time = self.elapsed_time
                self._start_at = None
                self._done = True

@ -0,0 +1,82 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import paddle
from deepspeech.utils.log import Log
# Names exported by a star-import of this module.
__all__ = ["UpdaterBase", "UpdaterState"]
# Module-level logger obtained through the project's Log helper.
logger = Log(__name__).getlog()
@dataclass
class UpdaterState:
    """Progress counters of an updater; both start from 0."""
    # iteration counter
    iteration: int = 0
    # epoch counter
    epoch: int = 0
class UpdaterBase():
    """An updater is the abstraction of how a model is trained given the
    dataloader and the optimizer.

    A training step contains only the necessary operations (get a batch,
    forward and backward, update the parameters). Everything else is made
    an extension: visualization, saving, loading and periodical validation
    and evaluation are not considered here.

    Even in this simplest case, things are not that simple. One could try
    to standardize the process so that only the model and the dataset are
    required and everything else is automatic, but that may hurt
    flexibility. If a batch yielded from the dataloader is assumed to be
    just the input of the model, some models turn out to need more
    arguments, or keyword arguments. The batch may also include the target
    while the model's forward method only needs the input. Passing a dict
    or a super-long tuple to the model and letting it pick what it needs
    is an abuse of a lazy interface.

    After all, what matters here is how a model is trained, not just how
    it is used for inference, and that should be controllable without
    being mixed with auxiliary code. So the best practice is to define a
    model and define an updater for it.
    """

    def __init__(self, init_state=None):
        # start from a fresh state unless one is handed in
        self.state = UpdaterState() if init_state is None else init_state

    def update(self, batch):
        """Train for a step. Subclasses must implement it."""
        raise NotImplementedError(
            "Implement your own `update` method for training a step.")

    def state_dict(self):
        """Return the updater state as a plain dictionary."""
        return {
            "epoch": self.state.epoch,
            "iteration": self.state.iteration,
        }

    def set_state_dict(self, state_dict):
        """Restore the epoch and iteration counters from `state_dict`."""
        self.state.epoch = state_dict["epoch"]
        self.state.iteration = state_dict["iteration"]

    def save(self, path):
        """Save the state dict to `path` with `paddle.save`."""
        logger.debug(f"Saving to {path}.")
        paddle.save(self.state_dict(), str(path))

    def load(self, path):
        """Load a state dict from `path` with `paddle.load` and apply it."""
        logger.debug(f"Loading from {path}.")
        self.set_state_dict(paddle.load(str(path)))

@ -15,3 +15,4 @@ tensorboardX
textgrid
typeguard
yacs
jsonlines
Loading…
Cancel
Save