parent 1726e2fdfc
commit b9c7835eb9

@@ -0,0 +1,275 @@
import inspect
import shutil
import tempfile
import typing
from pathlib import Path

import paddle
from paddle import nn


class BaseModel(nn.Layer):
    """This is a class that adds useful save/load functionality to a
    ``paddle.nn.Layer`` object. ``BaseModel`` objects can be saved
    as a ``package`` easily, making them simple to port between
    machines without requiring a ton of dependencies. Files can also be
    saved as just weights, in the standard way.

    >>> class Model(ml.BaseModel):
    >>>     def __init__(self, arg1: float = 1.0):
    >>>         super().__init__()
    >>>         self.arg1 = arg1
    >>>         self.linear = nn.Linear(1, 1)
    >>>
    >>>     def forward(self, x):
    >>>         return self.linear(x)
    >>>
    >>> model1 = Model()
    >>> out1 = seed_and_run(model1, x)
    >>>
    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    >>>     model1.save(
    >>>         f.name,
    >>>     )
    >>>     model2 = Model.load(f.name)
    >>>     out2 = seed_and_run(model2, x)
    >>>     assert paddle.allclose(out1, out2)
    >>>
    >>>     model1.save(f.name, package=True)
    >>>     model2 = Model.load(f.name)
    >>>     model2.save(f.name, package=False)
    >>>     model3 = Model.load(f.name)
    >>>     out3 = seed_and_run(model3, x)
    >>>
    >>> with tempfile.TemporaryDirectory() as d:
    >>>     model1.save_to_folder(d, {"data": 1.0})
    >>>     Model.load_from_folder(d)

    """

    INTERN = []

    def save(
        self,
        path: str,
        metadata: dict = None,
        package: bool = False,
        intern: list = [],
        extern: list = [],
        mock: list = [],
    ):
        """Saves the model, either as a package or just as
        weights, alongside some specified metadata.

        Parameters
        ----------
        path : str
            Path to save model to.
        metadata : dict, optional
            Any metadata to save alongside the model,
            by default None
        package : bool, optional
            Whether to use ``package`` to save the model in
            a portable format, by default False
        intern : list, optional
            List of additional libraries that are internal
            to the model, used with package, by default []
        extern : list, optional
            List of additional libraries that are external to
            the model, used with package, by default []
        mock : list, optional
            List of libraries to mock, used with package,
            by default []

        Returns
        -------
        str
            Path to saved model.
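
        An illustrative call (assuming ``model`` is an instance of a
        ``BaseModel`` subclass):

        >>> model.save("checkpoint.pth", metadata={"sample_rate": 44100})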
"""
|
||||||
|
sig = inspect.signature(self.__class__)
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
for key, val in sig.parameters.items():
|
||||||
|
arg_val = val.default
|
||||||
|
if arg_val is not inspect.Parameter.empty:
|
||||||
|
args[key] = arg_val
|
||||||
|
|
||||||
|
# Look up attibutes in self, and if any of them are in args,
|
||||||
|
# overwrite them in args.
|
||||||
|
for attribute in dir(self):
|
||||||
|
if attribute in args:
|
||||||
|
args[attribute] = getattr(self, attribute)
|
||||||
|
|
||||||
|
metadata = {} if metadata is None else metadata
|
||||||
|
metadata["kwargs"] = args
|
||||||
|
if not hasattr(self, "metadata"):
|
||||||
|
self.metadata = {}
|
||||||
|
self.metadata.update(metadata)
|
||||||
|
|
||||||
|
if not package:
|
||||||
|
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
|
||||||
|
paddle.save(state_dict, path)
|
||||||
|
else:
|
||||||
|
self._save_package(path, intern=intern, extern=extern, mock=mock)
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
    @property
    def device(self):
        """Gets the device the model is on by looking at the device of
        the first parameter. May not be valid if model is split across
        multiple devices.
        """
        return list(self.parameters())[0].device

    @classmethod
    def load(
        cls,
        location: str,
        *args,
        package_name: str = None,
        strict: bool = False,
        **kwargs,
    ):
        """Load model from a path. Tries first to load as a package, and if
        that fails, tries to load as weights. The arguments to the class are
        specified inside the model weights file.

        Parameters
        ----------
        location : str
            Path to file.
        package_name : str, optional
            Name of package, by default ``cls.__name__``.
        strict : bool, optional
            Whether to raise an error on unmatched keys when loading the
            state dict, by default False (unmatched keys are ignored).
        kwargs : dict
            Additional keyword arguments to the model instantiation, if
            not loading from package.

        Returns
        -------
        BaseModel
            A model that inherits from BaseModel.
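
        An illustrative round trip (``Model`` is assumed to subclass
        ``BaseModel``):

        >>> model = Model()
        >>> model.save("checkpoint.pth")
        >>> model = Model.load("checkpoint.pth")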
"""
|
||||||
|
try:
|
||||||
|
model = cls._load_package(location, package_name=package_name)
|
||||||
|
except:
|
||||||
|
model_dict = paddle.load(location, "cpu")
|
||||||
|
metadata = model_dict["metadata"]
|
||||||
|
metadata["kwargs"].update(kwargs)
|
||||||
|
|
||||||
|
sig = inspect.signature(cls)
|
||||||
|
class_keys = list(sig.parameters.keys())
|
||||||
|
for k in list(metadata["kwargs"].keys()):
|
||||||
|
if k not in class_keys:
|
||||||
|
metadata["kwargs"].pop(k)
|
||||||
|
|
||||||
|
model = cls(*args, **metadata["kwargs"])
|
||||||
|
model.load_state_dict(model_dict["state_dict"], strict=strict)
|
||||||
|
model.metadata = metadata
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
    def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
        raise NotImplementedError("Currently Paddle does not support packaging")

    @classmethod
    def _load_package(cls, path, package_name=None):
        raise NotImplementedError("Currently Paddle does not support packaging")

    def save_to_folder(
        self,
        folder: typing.Union[str, Path],
        extra_data: dict = None,
        package: bool = False,
    ):
        """Dumps a model into a folder, as weights (and optionally as a
        package), as well as anything specified in
        ``extra_data``. ``extra_data`` is a dictionary of other
        pickleable files, with the keys being the paths
        to save them in. The model is saved under a subfolder
        specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
        if the model name was ``Generator``).

        >>> with tempfile.TemporaryDirectory() as d:
        >>>     extra_data = {
        >>>         "optimizer.pth": optimizer.state_dict()
        >>>     }
        >>>     model.save_to_folder(d, extra_data)
        >>>     Model.load_from_folder(d)

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder to dump the model (and any extra data) into.
        extra_data : dict, optional
            Any additional data to save alongside the model, keyed by the
            file name to save it under, by default None
        package : bool, optional
            Whether to additionally save the model with ``package``,
            by default False

        Returns
        -------
        Path
            Path to folder
        """
        extra_data = {} if extra_data is None else extra_data
        model_name = type(self).__name__.lower()
        target_base = Path(f"{folder}/{model_name}/")
        target_base.mkdir(exist_ok=True, parents=True)

        if package:
            package_path = target_base / "package.pth"
            self.save(str(package_path))

        weights_path = target_base / "weights.pth"
        self.save(str(weights_path), package=False)

        for path, obj in extra_data.items():
            paddle.save(obj, str(target_base / path))

        return target_base

    @classmethod
    def load_from_folder(
        cls,
        folder: typing.Union[str, Path],
        package: bool = False,
        strict: bool = False,
        **kwargs,
    ):
        """Loads the model from a folder generated by
        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
        Like that function, this one looks for a subfolder that has
        the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
        model name was ``Generator``).

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder to load the model (and any extra data) from.
        package : bool, optional
            Whether to use ``package`` to load the model,
            loading the model from ``package.pth``.
        strict : bool, optional
            Whether to raise an error on unmatched keys when loading the
            weights, by default False (unmatched keys are ignored).

        Returns
        -------
        tuple
            Tuple of model and extra data as saved by
            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
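
        An illustrative call (``Model`` is assumed to subclass ``BaseModel``
        and ``d`` to be a folder written by ``save_to_folder``):

        >>> model, extra_data = Model.load_from_folder(d)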
"""
|
||||||
|
folder = Path(folder) / cls.__name__.lower()
|
||||||
|
model_pth = "package.pth" if package else "weights.pth"
|
||||||
|
model_pth = folder / model_pth
|
||||||
|
|
||||||
|
model = cls.load(model_pth, strict=strict)
|
||||||
|
extra_data = {}
|
||||||
|
excluded = ["package.pth", "weights.pth"]
|
||||||
|
files = [
|
||||||
|
x
|
||||||
|
for x in folder.glob("*")
|
||||||
|
if x.is_file() and x.name not in excluded
|
||||||
|
]
|
||||||
|
for f in files:
|
||||||
|
extra_data[f.name] = paddle.load(f, **kwargs)
|
||||||
|
|
||||||
|
return model, extra_data
|
@@ -0,0 +1,447 @@
import math
import os
import time
from collections import defaultdict
from functools import wraps

import paddle
import paddle.distributed as dist
from rich import box
from rich.console import Console
from rich.console import Group
from rich.live import Live
from rich.markdown import Markdown
from rich.padding import Padding
from rich.panel import Panel
from rich.progress import BarColumn
from rich.progress import Progress
from rich.progress import SpinnerColumn
from rich.progress import TimeElapsedColumn
from rich.progress import TimeRemainingColumn
from rich.rule import Rule
from rich.table import Table
from visualdl import LogWriter


# This is here so that the history can be pickled.
def default_list():
    return []


class Mean:
    """Keeps track of a running mean over the finite values it is
    updated with.
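
    A short illustrative example:

    >>> m = Mean()
    >>> m.update(1.0)
    >>> m.update(3.0)
    >>> assert m() == 2.0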
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
mean = self.total / max(self.count, 1)
|
||||||
|
return mean
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.count = 0
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def update(self, val):
|
||||||
|
if math.isfinite(val):
|
||||||
|
self.count += 1
|
||||||
|
self.total += val
|
||||||
|
|
||||||
|
|
||||||
|
def when(condition):
    """Runs a function only when the condition is met. The condition is
    a function that is run.

    Parameters
    ----------
    condition : Callable
        Function to run to check whether or not to run the decorated
        function.

    Example
    -------
    Checkpoint only runs every 100 iterations, and only if the
    local rank is 0.

    >>> i = 0
    >>> rank = 0
    >>>
    >>> @when(lambda: i % 100 == 0 and rank == 0)
    >>> def checkpoint():
    >>>     print("Saving to /runs/exp1")
    >>>
    >>> for i in range(1000):
    >>>     checkpoint()

    """

    def decorator(fn):
        @wraps(fn)
        def decorated(*args, **kwargs):
            if condition():
                return fn(*args, **kwargs)

        return decorated

    return decorator


def timer(prefix: str = "time"):
    """Adds execution time to the output dictionary of the decorated
    function. The function decorated by this must output a dictionary.
    The key added will follow the form "[prefix]/[name_of_function]".

    Parameters
    ----------
    prefix : str, optional
        The key added will follow the form "[prefix]/[name_of_function]",
        by default "time".
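
    An illustrative use (the decorated function must return a dict):

    >>> @timer()
    >>> def step():
    >>>     return {"loss": 0.0}
    >>>
    >>> out = step()  # out now also contains a "time/step" entry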
"""
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
s = time.perf_counter()
|
||||||
|
output = fn(*args, **kwargs)
|
||||||
|
assert isinstance(output, dict)
|
||||||
|
e = time.perf_counter()
|
||||||
|
output[f"{prefix}/{fn.__name__}"] = e - s
|
||||||
|
return output
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Tracker:
    """
    A tracker class that helps to monitor the progress of training and log the metrics.

    Attributes
    ----------
    metrics : dict
        A dictionary containing the metrics for each label.
    history : dict
        A dictionary containing the history of metrics for each label.
    writer : LogWriter
        A LogWriter object for logging the metrics.
    rank : int
        The rank of the current process.
    step : int
        The current step of the training.
    tasks : dict
        A dictionary containing the progress bars and tables for each label.
    pbar : Progress
        A progress bar object for displaying the progress.
    consoles : list
        A list of console objects for logging.
    live : Live
        A Live object for updating the display live.

    Methods
    -------
    print(msg: str)
        Prints the given message to all consoles.
    update(label: str, fn_name: str)
        Updates the progress bar and table for the given label.
    done(label: str, title: str)
        Resets the progress bar and table for the given label and prints the final result.
    track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
        A decorator for tracking the progress and metrics of a function.
    log(label: str, value_type: str = "value", history: bool = True)
        A decorator for logging the metrics of a function.
    is_best(label: str, key: str) -> bool
        Checks if the latest value of the given key in the label is the best so far.
    state_dict() -> dict
        Returns a dictionary containing the state of the tracker.
    load_state_dict(state_dict: dict) -> Tracker
        Loads the state of the tracker from the given state dictionary.
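
    A rough usage sketch (names such as ``train_loader``, ``train_step``
    and ``loss`` are assumed, not part of this module):

    >>> tracker = Tracker(writer=LogWriter(logdir="logs"))
    >>>
    >>> @tracker.log("train", "value", history=False)
    >>> @tracker.track("train", length=len(train_loader))
    >>> def train_step(batch):
    >>>     ...
    >>>     return {"loss": loss}
    >>>
    >>> with tracker.live:
    >>>     for batch in train_loader:
    >>>         train_step(batch)
    >>>         tracker.step += 1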
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
writer: LogWriter = None,
|
||||||
|
log_file: str = None,
|
||||||
|
rank: int = 0,
|
||||||
|
console_width: int = 100,
|
||||||
|
step: int = 0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the Tracker object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
writer : LogWriter, optional
|
||||||
|
A LogWriter object for logging the metrics, by default None.
|
||||||
|
log_file : str, optional
|
||||||
|
The path to the log file, by default None.
|
||||||
|
rank : int, optional
|
||||||
|
The rank of the current process, by default 0.
|
||||||
|
console_width : int, optional
|
||||||
|
The width of the console, by default 100.
|
||||||
|
step : int, optional
|
||||||
|
The current step of the training, by default 0.
|
||||||
|
"""
|
||||||
|
        self.metrics = {}
        self.history = {}
        self.writer = writer
        self.rank = rank
        self.step = step

        # Create progress bars etc.
        self.tasks = {}
        self.pbar = Progress(
            SpinnerColumn(),
            "[progress.description]{task.description}",
            "{task.completed}/{task.total}",
            BarColumn(),
            TimeElapsedColumn(),
            "/",
            TimeRemainingColumn(),
        )
        self.consoles = [Console(width=console_width)]
        self.live = Live(console=self.consoles[0], refresh_per_second=10)
        if log_file is not None:
            self.consoles.append(
                Console(width=console_width, file=open(log_file, "a"))
            )

    def print(self, msg):
        """
        Prints the given message to all consoles.

        Parameters
        ----------
        msg : str
            The message to be printed.
        """
        if self.rank == 0:
            for c in self.consoles:
                c.log(msg)

    def update(self, label, fn_name):
        """
        Updates the progress bar and table for the given label.

        Parameters
        ----------
        label : str
            The label of the progress bar and table to be updated.
        fn_name : str
            The name of the function associated with the label.
        """
        if self.rank == 0:
            self.pbar.advance(self.tasks[label]["pbar"])

            # Create table
            table = Table(title=label, expand=True, box=box.MINIMAL)
            table.add_column("key", style="cyan")
            table.add_column("value", style="bright_blue")
            table.add_column("mean", style="bright_green")

            keys = self.metrics[label]["value"].keys()
            for k in keys:
                value = self.metrics[label]["value"][k]
                mean = self.metrics[label]["mean"][k]()
                table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")

            self.tasks[label]["table"] = table
            tables = [t["table"] for t in self.tasks.values()]
            group = Group(*tables, self.pbar)
            self.live.update(
                Group(
                    Padding("", (0, 0)),
                    Rule(f"[italic]{fn_name}()", style="white"),
                    Padding("", (0, 0)),
                    Panel.fit(
                        group,
                        padding=(0, 5),
                        title="[b]Progress",
                        border_style="blue",
                    ),
                )
            )

    def done(self, label: str, title: str):
        """
        Resets the progress bar and table for the given label and prints the final result.

        Parameters
        ----------
        label : str
            The label of the progress bar and table to be reset.
        title : str
            The title to be displayed when printing the final result.
        """
        # Reset the running means for every label, without clobbering the
        # ``label`` argument used below.
        for key in self.metrics:
            for v in self.metrics[key]["mean"].values():
                v.reset()

        if self.rank == 0:
            self.pbar.reset(self.tasks[label]["pbar"])
            tables = [t["table"] for t in self.tasks.values()]
            group = Group(Markdown(f"# {title}"), *tables, self.pbar)
            self.print(group)

    def track(
        self,
        label: str,
        length: int,
        completed: int = 0,
        op: dist.ReduceOp = dist.ReduceOp.AVG,
        ddp_active: bool = "LOCAL_RANK" in os.environ,
    ):
        """
        A decorator for tracking the progress and metrics of a function.

        Parameters
        ----------
        label : str
            The label to be associated with the progress and metrics.
        length : int
            The total number of iterations to be completed.
        completed : int, optional
            The number of iterations already completed, by default 0.
        op : dist.ReduceOp, optional
            The reduce operation to be used, by default dist.ReduceOp.AVG.
        ddp_active : bool, optional
            Whether DistributedDataParallel is active, by default
            ``"LOCAL_RANK" in os.environ``.
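
        An illustrative pattern (``tracker``, ``val_loader`` and ``loss_fn``
        are assumed to exist):

        >>> @tracker.track("val", length=len(val_loader))
        >>> def val_step(batch):
        >>>     return {"loss": loss_fn(batch)}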
"""
|
||||||
|
self.tasks[label] = {
|
||||||
|
"pbar": self.pbar.add_task(
|
||||||
|
f"[white]Iteration ({label})", total=length, completed=completed
|
||||||
|
),
|
||||||
|
"table": Table(),
|
||||||
|
}
|
||||||
|
self.metrics[label] = {
|
||||||
|
"value": defaultdict(),
|
||||||
|
"mean": defaultdict(lambda: Mean()),
|
||||||
|
}
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
output = fn(*args, **kwargs)
|
||||||
|
if not isinstance(output, dict):
|
||||||
|
self.update(label, fn.__name__)
|
||||||
|
return output
|
||||||
|
# Collect across all DDP processes
|
||||||
|
scalar_keys = []
|
||||||
|
for k, v in output.items():
|
||||||
|
if isinstance(v, (int, float)):
|
||||||
|
v = paddle.to_tensor([v])
|
||||||
|
if not paddle.is_tensor(v):
|
||||||
|
continue
|
||||||
|
if ddp_active and v.is_cuda: # pragma: no cover
|
||||||
|
dist.all_reduce(v, op=op)
|
||||||
|
output[k] = v.detach()
|
||||||
|
if paddle.numel(v) == 1:
|
||||||
|
scalar_keys.append(k)
|
||||||
|
output[k] = v.item()
|
||||||
|
|
||||||
|
# Save the outputs to tracker
|
||||||
|
for k, v in output.items():
|
||||||
|
if k not in scalar_keys:
|
||||||
|
continue
|
||||||
|
self.metrics[label]["value"][k] = v
|
||||||
|
# Update the running mean
|
||||||
|
self.metrics[label]["mean"][k].update(v)
|
||||||
|
|
||||||
|
self.update(label, fn.__name__)
|
||||||
|
return output
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
    def log(self, label: str, value_type: str = "value", history: bool = True):
        """
        A decorator for logging the metrics of a function.

        Parameters
        ----------
        label : str
            The label to be associated with the logging.
        value_type : str, optional
            The type of value to be logged, by default "value".
        history : bool, optional
            Whether to save the history of the metrics, by default True.
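
        Typically stacked on top of ``track`` so that the tracked metrics are
        written to the logger (illustrative; ``tracker`` is assumed):

        >>> @tracker.log("val", "mean")
        >>> @tracker.track("val", length=100)
        >>> def val_step(batch):
        >>>     ...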
"""
|
||||||
|
assert value_type in ["mean", "value"]
|
||||||
|
if history:
|
||||||
|
if label not in self.history:
|
||||||
|
self.history[label] = defaultdict(default_list)
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
output = fn(*args, **kwargs)
|
||||||
|
if self.rank == 0:
|
||||||
|
nonlocal value_type, label
|
||||||
|
metrics = self.metrics[label][value_type]
|
||||||
|
for k, v in metrics.items():
|
||||||
|
v = v() if isinstance(v, Mean) else v
|
||||||
|
if self.writer is not None:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
tag=f"{k}/{label}", value=v, step=self.step
|
||||||
|
)
|
||||||
|
if label in self.history:
|
||||||
|
self.history[label][k].append(v)
|
||||||
|
|
||||||
|
if label in self.history:
|
||||||
|
self.history[label]["step"].append(self.step)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
    def is_best(self, label, key):
        """
        Checks if the latest value of the given key in the label is the best
        (i.e. the lowest) so far.

        Parameters
        ----------
        label : str
            The label of the metrics to be checked.
        key : str
            The key of the metric to be checked.

        Returns
        -------
        bool
            True if the latest value is the best (lowest) so far, otherwise False.
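
        An illustrative check (assuming a "val" label with a tracked "loss"):

        >>> if tracker.is_best("val", "loss"):
        >>>     model.save("best.pth")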
"""
|
||||||
|
return self.history[label][key][-1] == min(self.history[label][key])
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""
|
||||||
|
Returns a dictionary containing the state of the tracker.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A dictionary containing the history and step of the tracker.
|
||||||
|
"""
|
||||||
|
return {"history": self.history, "step": self.step}
|
||||||
|
|
||||||
|
    def load_state_dict(self, state_dict):
        """
        Loads the state of the tracker from the given state dictionary.

        Parameters
        ----------
        state_dict : dict
            A dictionary containing the history and step of the tracker.

        Returns
        -------
        Tracker
            The tracker object with the loaded state.
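
        An illustrative checkpoint round trip:

        >>> ckpt = tracker.state_dict()
        >>> tracker = Tracker().load_state_dict(ckpt)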
"""
|
||||||
|
self.history = state_dict["history"]
|
||||||
|
self.step = state_dict["step"]
|
||||||
|
return self
|