Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into ds2_online_export
commit 2e77c3c378
@@ -0,0 +1,28 @@
from typing import Callable

from .extension import Extension


def make_extension(trigger: Callable=None,
                   default_name: str=None,
                   priority: int=None,
                   finalizer: Callable=None,
                   initializer: Callable=None,
                   on_error: Callable=None):
    """Make an Extension-like object by injecting the required attributes
    into it."""
    if trigger is None:
        trigger = Extension.trigger
    if priority is None:
        priority = Extension.priority

    def decorator(ext):
        ext.trigger = trigger
        ext.default_name = default_name or ext.__name__
        ext.priority = priority
        ext.finalize = finalizer
        ext.on_error = on_error
        ext.initialize = initializer
        return ext

    return decorator
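A hedged usage sketch of make_extension: any callable taking the trainer can be promoted to an Extension-like object. The import path and the body of the decorated function are illustrative assumptions, not part of this diff.

from deepspeech.training.extensions import make_extension  # assumed export path


@make_extension(trigger=(100, 'iteration'))
def print_progress(trainer):
    # the decorated function gains trigger/priority/default_name attributes
    print("iteration:", trainer.updater.state.iteration)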
@@ -0,0 +1,58 @@
from typing import Dict

import paddle
from paddle.io import DataLoader
from paddle.nn import Layer

from . import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope


class StandardEvaluator(extension.Extension):

    trigger = (1, 'epoch')
    default_name = 'validation'
    priority = extension.PRIORITY_WRITER

    name = None

    def __init__(self, model: Layer, dataloader: DataLoader):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # dataloaders
        self.dataloader = dataloader

    def evaluate_core(self, batch):
        # compute
        self.model(batch)  # you may report here

    def evaluate(self):
        # switch to eval mode
        for model in self.models.values():
            model.eval()

        # to average evaluation metrics
        summary = DictSummary()
        for batch in self.dataloader:
            observation = {}
            with scope(observation):
                # main evaluation computation here
                with paddle.no_grad():
                    self.evaluate_core(batch)
            summary.add(observation)
        summary = summary.compute_mean()
        return summary

    def __call__(self, trainer=None):
        # evaluate and report the averaged metrics to the current observation.
        # if it is used to extend a trainer, the metrics are reported to
        # the observation of the trainer; otherwise, you can use your own
        # observation
        summary = self.evaluate()
        for k, v in summary.items():
            report(k, v)
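A minimal sketch of customizing StandardEvaluator: override evaluate_core to unpack your batch and report metrics. The batch layout and the metric name are illustrative assumptions.

class MyEvaluator(StandardEvaluator):
    def evaluate_core(self, batch):
        audio, audio_len, text, text_len = batch  # assumed batch layout
        loss = self.model(audio, audio_len, text, text_len)
        report("eval/loss", float(loss))  # lands in the enclosing scope()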
@@ -0,0 +1,41 @@
from typing import Callable

PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100


class Extension():
    """Extension to customize the behavior of Trainer."""
    trigger = (1, 'iteration')
    priority = PRIORITY_READER
    name = None

    @property
    def default_name(self):
        """Default name of the extension, class name by default."""
        return type(self).__name__

    def __call__(self, trainer):
        """Main action of the extension. After each update, it is executed
        when the trigger fires."""
        raise NotImplementedError(
            'Extension implementation must override __call__.')

    def initialize(self, trainer):
        """Action that is executed once to get the correct trainer state.
        It is called before training normally, but if the trainer restores
        states with a Snapshot extension, this method should also be called.
        """
        pass

    def on_error(self, trainer, exc, tb):
        """Handles the error raised during training before finalization.
        """
        pass

    def finalize(self, trainer):
        """Action that is executed when training is done.
        For example, visualizers would need to be closed.
        """
        pass
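A hedged sketch of a concrete Extension subclass: it logs the current epoch every time its trigger fires. Only the base-class attributes defined above are assumed.

class EpochLogger(Extension):
    trigger = (1, 'epoch')
    priority = PRIORITY_READER

    def __call__(self, trainer):
        # read the epoch index from the updater's state, as Trainer exposes it
        print(f"epoch {trainer.updater.state.epoch} finished")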
@@ -0,0 +1,102 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines

from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only

logger = Log(__name__).getlog()


def load_records(records_fp):
    """Load record files (json lines)."""
    with jsonlines.open(records_fp, 'r') as reader:
        records = list(reader)
    return records


class Snapshot(extension.Extension):
    """An extension to make snapshots of the updater object inside
    the trainer. It is done by calling the updater's `save` method.
    An Updater saves its state_dict by default, which contains the
    updater state (i.e. epoch and iteration) and all the model
    parameters and optimizer states. If the updater inside the trainer
    subclasses StandardUpdater, everything is good to go.

    Parameters
    ----------
    max_size : int
        The maximum number of snapshots to keep. Pass -1 to keep all of them.
    snapshot_on_error : bool
        Whether to make a snapshot when an error occurs during training.
    """

    trigger = (1, 'epoch')
    priority = -100
    default_name = "snapshot"

    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
        self.records: List[Dict[str, Any]] = []
        self.max_size = max_size
        self._snapshot_on_error = snapshot_on_error
        self._save_all = (max_size == -1)
        self.checkpoint_dir = None

    def initialize(self, trainer: Trainer):
        """Setting up this extension."""
        self.checkpoint_dir = trainer.out / "checkpoints"

        # load existing records
        record_path: Path = self.checkpoint_dir / "records.jsonl"
        if record_path.exists():
            logger.debug("Loading from an existing checkpoint dir")
            self.records = load_records(record_path)
            trainer.updater.load(self.records[-1]['path'])

    def on_error(self, trainer, exc, tb):
        if self._snapshot_on_error:
            self.save_checkpoint_and_update(trainer)

    def __call__(self, trainer: Trainer):
        self.save_checkpoint_and_update(trainer)

    def full(self):
        """Whether the number of snapshots it keeps track of is greater
        than the max_size."""
        return (not self._save_all) and len(self.records) > self.max_size

    @rank_zero_only
    def save_checkpoint_and_update(self, trainer: Trainer):
        """Save a new snapshot and remove the oldest one if needed."""
        iteration = trainer.updater.state.iteration
        epoch = trainer.updater.state.epoch
        num = epoch if self.trigger[1] == 'epoch' else iteration
        path = self.checkpoint_dir / f"{num}.pdz"

        # add the new one
        trainer.updater.save(path)
        record = {
            "time": str(datetime.now()),
            'path': str(path.resolve()),  # use absolute path
            'iteration': iteration,
            'epoch': epoch,
        }
        self.records.append(record)

        # remove the earliest one
        if self.full():
            earliest_record = self.records[0]
            os.remove(earliest_record["path"])
            self.records.pop(0)

        # update the record file
        record_path = self.checkpoint_dir / "records.jsonl"
        with jsonlines.open(record_path, 'w') as writer:
            for record in self.records:
                # jsonlines.open may return a Writer or a Reader
                writer.write(record)  # pylint: disable=no-member
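A hedged usage sketch, grounded in the API above: keep the five most recent checkpoints, saving one per epoch, and also snapshot when training crashes. The trainer instance is assumed to exist.

snapshot = Snapshot(max_size=5, snapshot_on_error=True)
trainer.extend(snapshot, trigger=(1, 'epoch'))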
@@ -0,0 +1,24 @@
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer


class VisualDL(extension.Extension):
    """A wrapper of the visualdl log writer. It assumes that the metrics to
    be visualized are all scalars recorded into the `.observation` dictionary
    of the trainer object. The dictionary is created for each step, and the
    visualdl log writer uses the updater's `iteration` as the global step
    when adding records.
    """
    trigger = (1, 'iteration')
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, writer):
        self.writer = writer

    def __call__(self, trainer: Trainer):
        for k, v in trainer.observation.items():
            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)

    def finalize(self, trainer):
        self.writer.close()
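A hedged usage sketch: the writer is assumed to be visualdl's LogWriter, and the logdir is illustrative.

from visualdl import LogWriter

writer = LogWriter(logdir="runs/exp0")
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))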
@@ -0,0 +1,131 @@
import contextlib
import math
from collections import defaultdict

OBSERVATIONS = None


@contextlib.contextmanager
def scope(observations):
    # make `observations` the target to report to.
    # it is basically a dictionary that stores temporary observations
    global OBSERVATIONS
    old = OBSERVATIONS
    OBSERVATIONS = observations

    try:
        yield
    finally:
        OBSERVATIONS = old


def get_observations():
    global OBSERVATIONS
    return OBSERVATIONS


def report(name, value):
    # a simple function to report a named value
    # you can use it everywhere; it will get the default target and write
    # to it. you can think of it as stdout
    observations = get_observations()
    if observations is None:
        return
    else:
        observations[name] = value


class Summary():
    """Online summarization of a sequence of scalars.
    Summary computes the statistics of given scalars online.
    """

    def __init__(self):
        self._x = 0.0
        self._x2 = 0.0
        self._n = 0

    def add(self, value, weight=1):
        """Adds a scalar value.
        Args:
            value: Scalar value to accumulate. It is either a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
            weight: An optional weight for the value. It is a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
                Default is 1 (integer).
        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Computes the mean."""
        x, n = self._x, self._n
        return x / n

    def make_statistics(self):
        """Computes and returns the mean and standard deviation values.
        Returns:
            tuple: Mean and standard deviation values.
        """
        x, n = self._x, self._n
        mean = x / n
        var = self._x2 / n - mean * mean
        std = math.sqrt(var)
        return mean, std


class DictSummary():
    """Online summarization of a sequence of dictionaries.
    ``DictSummary`` computes the statistics of a given set of scalars online.
    It only computes the statistics for scalar values and variables of scalar
    values in the dictionaries.
    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Adds a dictionary of scalars.
        Args:
            d (dict): Dictionary of scalars to accumulate. Only elements of
                scalars, zero-dimensional arrays, and variables of
                zero-dimensional arrays are accumulated. When the value
                is a tuple, the second element is interpreted as a weight.
        """
        summaries = self._summaries
        for k, v in d.items():
            w = 1
            if isinstance(v, tuple):
                # read the weight before overwriting v with the bare value
                w = v[1]
                v = v[0]
            summaries[k].add(v, weight=w)

    def compute_mean(self):
        """Creates a dictionary of mean values.
        It returns a single dictionary that holds a mean value for each entry
        added to the summary.
        Returns:
            dict: Dictionary of mean values.
        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Creates a dictionary of statistics.
        It returns a single dictionary that holds mean and standard deviation
        values for every entry added to the summary. For an entry of name
        ``'key'``, these values are added to the dictionary by names ``'key'``
        and ``'key.std'``, respectively.
        Returns:
            dict: Dictionary of statistics of all entries.
        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std

        return stats
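A minimal sketch of how scope(), report() and DictSummary cooperate, using only the functions defined above: values reported inside a scope land in that scope's dictionary, and DictSummary averages them across steps.

summary = DictSummary()
for step in range(3):
    observation = {}
    with scope(observation):
        report("loss", 1.0 / (step + 1))
    summary.add(observation)

print(summary.compute_mean())  # {'loss': ...}, averaged over the 3 steps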
@@ -0,0 +1,13 @@
from .interval_trigger import IntervalTrigger


def never_fail_trigger(trainer):
    return False


def get_trigger(trigger):
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
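A hedged illustration of the three accepted forms that get_trigger normalizes:

assert get_trigger(None) is never_fail_trigger      # None -> never fires
custom = get_trigger(lambda trainer: True)          # callables pass through
interval = get_trigger((10, 'iteration'))           # tuple -> IntervalTrigger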
@@ -0,0 +1,24 @@


class IntervalTrigger():
    """A predicate to do something every N cycles."""

    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        self.last_index = None

    def __call__(self, trainer):
        if self.last_index is None:
            last_index = getattr(trainer.updater.state, self.unit)
            self.last_index = last_index

        last_index = self.last_index
        index = getattr(trainer.updater.state, self.unit)
        fire = index // self.period != last_index // self.period

        self.last_index = index
        return fire
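A hedged illustration of the integer "bucket" rule used above, with period = 10: the trigger fires exactly when the index crosses a multiple of the period, even if it advances by more than one per check.

period = 10
for prev, cur in [(8, 9), (9, 10), (10, 11)]:
    fires = cur // period != prev // period
    print(prev, "->", cur, "fires:", fires)  # only 9 -> 10 fires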
@@ -0,0 +1,17 @@


class LimitTrigger():
    """A predicate to decide whether to stop."""

    def __init__(self, limit: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if limit <= 0:
            raise ValueError("limit should be a positive integer.")
        self.limit = limit
        self.unit = unit

    def __call__(self, trainer):
        state = trainer.updater.state
        index = getattr(state, self.unit)
        fire = index >= self.limit
        return fire
@@ -0,0 +1,17 @@
class TimeTrigger():
    """Trigger based on a fixed time interval.
    This trigger accepts iterations with a given interval time.
    Args:
        period (float): Interval time. It is given in seconds.
    """

    def __init__(self, period):
        self._period = period
        self._next_time = self._period

    def __call__(self, trainer):
        if self._next_time < trainer.elapsed_time:
            self._next_time += self._period
            return True
        else:
            return False
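A hedged usage sketch: fire roughly every ten minutes of wall-clock time, assuming the trainer exposes `elapsed_time` in seconds as used above; `some_extension` is a placeholder. Since a TimeTrigger instance is callable, get_trigger passes it through unchanged.

trainer.extend(some_extension, trigger=TimeTrigger(600))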
@@ -0,0 +1,179 @@
from typing import Dict
from typing import Optional

from paddle import Tensor
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer

from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState

from deepspeech.utils.log import Log

__all__ = ["StandardUpdater"]

logger = Log(__name__).getlog()


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your need.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState]=None):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # it is designed to hold multiple optimizers
        optimizers = {"main": optimizer}
        self.optimizer = optimizer
        self.optimizers: Dict[str, Optimizer] = optimizers

        # dataloaders
        self.dataloader = dataloader

        # init state
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

        self.train_iterator = iter(dataloader)

    def update(self):
        # We increase the iteration index after updating and before any
        # extension is executed. Here are the reasons.

        # 1. Snapshotting (as well as other extensions, like the visualizer)
        #    is executed after a step of updating.
        # 2. We do not increase the iteration index after the extensions
        #    because we prefer a consistent resume behavior: when loading
        #    from `snapshot_iter_100.pdz`, the next step to train is `101`,
        #    naturally. If the iteration index were increased after the
        #    extensions (including snapshotting), then `snapshot_iter_99`
        #    would be loaded, and you would need an extra increase of the
        #    iteration index before training to avoid repeating iteration
        #    `99`, which was already done before snapshotting.
        # 3. Thus the iteration index represents "how many iterations have
        #    been done so far."

        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate,
        # so if you want to report it, do it in the updater.

        # Then here comes the next question: when is the proper time to
        # increase the epoch index? Since all extensions are executed after
        # updating, right after updating is also the proper time to increase
        # the epoch index.
        # 1. If we increased the epoch index before updating, an extension
        #    triggered by epoch would miss the correct timing; it could only
        #    be triggered after an extra update.
        # 2. Theoretically, when an epoch is done, the epoch index should be
        #    increased, so it is increased after updating.
        # 3. Thus the epoch index represents "how many epochs have been done
        #    so far," and it starts from 0.

        # switch to training mode
        for model in self.models.values():
            model.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        if self.updates_per_epoch is not None:
            if self.state.iteration % self.updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses;
        Parameters are updated at every batch, no gradient accumulation.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]
            loss_dict = loss
            if "main" not in loss_dict:
                main_loss = 0
                for loss_item in loss.values():
                    main_loss += loss_item
                loss_dict["main"] = main_loss

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, determined by the length of the
        dataloader."""
        length_of_dataloader = None
        try:
            length_of_dataloader = len(self.dataloader)
        except TypeError:
            logger.debug("This dataloader has no __len__.")
        return length_of_dataloader

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch samplers for distributed training should
        # subclass DistributedBatchSampler and implement the `set_epoch`
        # method
        if hasattr(self.dataloader, "batch_sampler"):
            batch_sampler = self.dataloader.batch_sampler
            if isinstance(batch_sampler, DistributedBatchSampler):
                batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader, auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                self.new_epoch()
                batch = next(self.train_iterator)
            logger.debug(f"Read a batch takes {t.elapse}s.")
        return batch

    def state_dict(self):
        """State dict of an Updater; model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, model in self.models.items():
            state_dict[f"{name}_params"] = model.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for an Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, model in self.models.items():
            model.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)
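A hedged wiring sketch for StandardUpdater; the model class and dataset are placeholders, and only the constructor signature above plus paddle's Adam/DataLoader APIs are assumed.

import paddle
from paddle.io import DataLoader

model = MyModel()  # hypothetical paddle.nn.Layer subclass
optimizer = paddle.optimizer.Adam(
    learning_rate=1e-3, parameters=model.parameters())
train_loader = DataLoader(train_dataset, batch_size=32)  # hypothetical dataset

updater = StandardUpdater(model, optimizer, dataloader=train_loader)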
@@ -0,0 +1,171 @@
import sys
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Callable
from typing import List
from typing import Union

import six
import tqdm

from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope
from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase


class _ExtensionEntry():
    def __init__(self, extension, trigger, priority):
        self.extension = extension
        self.trigger = trigger
        self.priority = priority


class Trainer():
    def __init__(self,
                 updater: UpdaterBase,
                 stop_trigger: Callable=None,
                 out: Union[str, Path]='result',
                 extensions: List[Extension]=None):
        self.updater = updater
        self.extensions = OrderedDict()
        self.stop_trigger = LimitTrigger(*stop_trigger)
        self.out = Path(out)
        self.observation = None

        self._done = False
        if extensions:
            for ext in extensions:
                self.extend(ext)

    @property
    def is_before_training(self):
        return self.updater.state.iteration == 0

    def extend(self, extension, name=None, trigger=None, priority=None):
        # resolve the name for the extension, in this order:
        # 1. the `name` argument;
        # 2. the extension's `name` attribute;
        # 3. its `default_name` (class name, when it is an object);
        # 4. its `__name__` (when it is a function);
        # otherwise, raise an error.
        if name is None:
            name = getattr(extension, 'name', None)
            if name is None:
                name = getattr(extension, 'default_name', None)
                if name is None:
                    name = getattr(extension, '__name__', None)
                    if name is None:
                        raise ValueError(
                            "Name is not given for the extension.")
        if name == 'training':
            raise ValueError("training is a reserved name.")

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority', PRIORITY_READER)

        # add a suffix to avoid naming conflicts
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extension.name = modified_name

        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
                                                         priority)

    def get_extension(self, name):
        """Get extension by name."""
        extensions = self.extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError(f'extension {name} not found')

    def run(self):
        if self._done:
            raise RuntimeError("Training is already done!")

        self.out.mkdir(parents=True, exist_ok=True)

        # sort extensions by priorities once
        extension_order = sorted(
            self.extensions.keys(),
            key=lambda name: self.extensions[name].priority,
            reverse=True)
        extensions = [(name, self.extensions[name])
                      for name in extension_order]

        # initialize all extensions
        for name, entry in extensions:
            if hasattr(entry.extension, "initialize"):
                entry.extension.initialize(self)

        update = self.updater.update  # training step
        stop_trigger = self.stop_trigger

        # display only one progress bar
        max_iteration = None
        if isinstance(stop_trigger, LimitTrigger):
            if stop_trigger.unit == 'epoch':
                max_epoch = self.stop_trigger.limit
                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
                                            None)
                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
            else:
                max_iteration = self.stop_trigger.limit

        p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)

        try:
            while not stop_trigger(self):
                self.observation = {}
                # set observation as the report target
                # you can use report freely in Updater.update()

                # update parameters and state
                with scope(self.observation):
                    update()
                    p.update()

                # execute extensions when necessary
                for name, entry in extensions:
                    if entry.trigger(self):
                        entry.extension(self)
        except Exception as e:
            f = sys.stderr
            f.write(f"Exception in main training loop: {e}\n")
            f.write("Traceback (most recent call last):\n")
            traceback.print_tb(sys.exc_info()[2])
            f.write(
                "Trainer extensions will try to handle the exception. Then all extensions will finalize."
            )

            # capture the exception in the main training loop
            exc_info = sys.exc_info()

            # try to handle it
            for name, entry in extensions:
                if hasattr(entry.extension, "on_error"):
                    try:
                        entry.extension.on_error(self, e, sys.exc_info()[2])
                    except Exception as ee:
                        f.write(f"Exception in error handler: {ee}\n")
                        f.write('Traceback (most recent call last):\n')
                        traceback.print_tb(sys.exc_info()[2])

            # re-raise the exception from the main training loop
            six.reraise(*exc_info)
        finally:
            for name, entry in extensions:
                if hasattr(entry.extension, "finalize"):
                    entry.extension.finalize(self)
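A hedged end-to-end sketch tying the pieces of this diff together; the output directory and loader names are illustrative. Note that stop_trigger is unpacked into LimitTrigger, so a (limit, unit) tuple is expected.

trainer = Trainer(
    updater,
    stop_trigger=(10, 'epoch'),  # becomes LimitTrigger(10, 'epoch')
    out='exp/result')
trainer.extend(StandardEvaluator(model, dev_loader), trigger=(1, 'epoch'))
trainer.extend(Snapshot(max_size=5), trigger=(1, 'epoch'))
trainer.run()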
@@ -0,0 +1,82 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

import paddle

from deepspeech.utils.log import Log

__all__ = ["UpdaterBase", "UpdaterState"]

logger = Log(__name__).getlog()


@dataclass
class UpdaterState:
    iteration: int = 0
    epoch: int = 0


class UpdaterBase():
    """An updater is the abstraction of how a model is trained given the
    dataloader and the optimizer.
    The `update_core` method is a step in the training loop with only the
    necessary operations (get a batch, forward and backward, update the
    parameters). Everything else is made an extension: visualization, saving,
    loading and periodical validation and evaluation are not considered here.
    But even in such a simple case, things are not that simple. There are
    attempts to standardize this process that require only the model and the
    dataset and do everything else automatically, but this hurts flexibility.
    If we assume that a batch yielded from the dataloader is just the input
    to the model, we will find that some models require more arguments, or
    just some keyword arguments, which prevents us from over-simplifying in
    this way.
    From another perspective, the batch may include not just the input, but
    also the target, while the model's forward method may need just the
    input. We could pass a dict or a super-long tuple to the model and let it
    pick what it really needs, but that is an abuse of a lazy interface.
    After all, we care about how a model is trained, not just how the model
    is used for inference. We want to control how a model is trained, and we
    just don't want to be messed up with other auxiliary code.
    So the best practice is to define a model and define an updater for it.
    """

    def __init__(self, init_state=None):
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

    def update(self, batch):
        raise NotImplementedError(
            "Implement your own `update` method for training a step.")

    def state_dict(self):
        state_dict = {
            "epoch": self.state.epoch,
            "iteration": self.state.iteration,
        }
        return state_dict

    def set_state_dict(self, state_dict):
        self.state.epoch = state_dict["epoch"]
        self.state.iteration = state_dict["iteration"]

    def save(self, path):
        logger.debug(f"Saving to {path}.")
        archive = self.state_dict()
        paddle.save(archive, str(path))

    def load(self, path):
        logger.debug(f"Loading from {path}.")
        archive = paddle.load(str(path))
        self.set_state_dict(archive)
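A hedged sketch of the snapshot round trip UpdaterBase provides: save() serializes state_dict() with paddle.save, and load() restores it. The checkpoint path is illustrative.

updater.save("exp/checkpoints/10.pdz")
updater.load("exp/checkpoints/10.pdz")
print(updater.state.iteration, updater.state.epoch)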