parent
1726e2fdfc
commit
b9c7835eb9
@ -0,0 +1,275 @@
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class BaseModel(nn.Layer):
|
||||
"""This is a class that adds useful save/load functionality to a
|
||||
``paddle.nn.Layer`` object. ``BaseModel`` objects can be saved
|
||||
as ``package`` easily, making them super easy to port between
|
||||
machines without requiring a ton of dependencies. Files can also be
|
||||
saved as just weights, in the standard way.
|
||||
|
||||
>>> class Model(ml.BaseModel):
|
||||
>>> def __init__(self, arg1: float = 1.0):
|
||||
>>> super().__init__()
|
||||
>>> self.arg1 = arg1
|
||||
>>> self.linear = nn.Linear(1, 1)
|
||||
>>>
|
||||
>>> def forward(self, x):
|
||||
>>> return self.linear(x)
|
||||
>>>
|
||||
>>> model1 = Model()
|
||||
>>>
|
||||
>>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
|
||||
>>> model1.save(
|
||||
>>> f.name,
|
||||
>>> )
|
||||
>>> model2 = Model.load(f.name)
|
||||
>>> out2 = seed_and_run(model2, x)
|
||||
>>> assert paddle.allclose(out1, out2)
|
||||
>>>
|
||||
>>> model1.save(f.name, package=True)
|
||||
>>> model2 = Model.load(f.name)
|
||||
>>> model2.save(f.name, package=False)
|
||||
>>> model3 = Model.load(f.name)
|
||||
>>> out3 = seed_and_run(model3, x)
|
||||
>>>
|
||||
>>> with tempfile.TemporaryDirectory() as d:
|
||||
>>> model1.save_to_folder(d, {"data": 1.0})
|
||||
>>> Model.load_from_folder(d)
|
||||
|
||||
"""
|
||||
|
||||
INTERN = []
|
||||
|
||||
def save(
|
||||
self,
|
||||
path: str,
|
||||
metadata: dict = None,
|
||||
package: bool = False,
|
||||
intern: list = [],
|
||||
extern: list = [],
|
||||
mock: list = [],
|
||||
):
|
||||
"""Saves the model, either as a package, or just as
|
||||
weights, alongside some specified metadata.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Path to save model to.
|
||||
metadata : dict, optional
|
||||
Any metadata to save alongside the model,
|
||||
by default None
|
||||
package : bool, optional
|
||||
Whether to use ``package`` to save the model in
|
||||
a format that is portable, by default True
|
||||
intern : list, optional
|
||||
List of additional libraries that are internal
|
||||
to the model, used with package, by default []
|
||||
extern : list, optional
|
||||
List of additional libraries that are external to
|
||||
the model, used with package, by default []
|
||||
mock : list, optional
|
||||
List of libraries to mock, used with package,
|
||||
by default []
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Path to saved model.
|
||||
"""
|
||||
sig = inspect.signature(self.__class__)
|
||||
args = {}
|
||||
|
||||
for key, val in sig.parameters.items():
|
||||
arg_val = val.default
|
||||
if arg_val is not inspect.Parameter.empty:
|
||||
args[key] = arg_val
|
||||
|
||||
# Look up attibutes in self, and if any of them are in args,
|
||||
# overwrite them in args.
|
||||
for attribute in dir(self):
|
||||
if attribute in args:
|
||||
args[attribute] = getattr(self, attribute)
|
||||
|
||||
metadata = {} if metadata is None else metadata
|
||||
metadata["kwargs"] = args
|
||||
if not hasattr(self, "metadata"):
|
||||
self.metadata = {}
|
||||
self.metadata.update(metadata)
|
||||
|
||||
if not package:
|
||||
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
|
||||
paddle.save(state_dict, path)
|
||||
else:
|
||||
self._save_package(path, intern=intern, extern=extern, mock=mock)
|
||||
|
||||
return path
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Gets the device the model is on by looking at the device of
|
||||
the first parameter. May not be valid if model is split across
|
||||
multiple devices.
|
||||
"""
|
||||
return list(self.parameters())[0].device
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
location: str,
|
||||
*args,
|
||||
package_name: str = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Load model from a path. Tries first to load as a package, and if
|
||||
that fails, tries to load as weights. The arguments to the class are
|
||||
specified inside the model weights file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
location : str
|
||||
Path to file.
|
||||
package_name : str, optional
|
||||
Name of package, by default ``cls.__name__``.
|
||||
strict : bool, optional
|
||||
Ignore unmatched keys, by default False
|
||||
kwargs : dict
|
||||
Additional keyword arguments to the model instantiation, if
|
||||
not loading from package.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseModel
|
||||
A model that inherits from BaseModel.
|
||||
"""
|
||||
try:
|
||||
model = cls._load_package(location, package_name=package_name)
|
||||
except:
|
||||
model_dict = paddle.load(location, "cpu")
|
||||
metadata = model_dict["metadata"]
|
||||
metadata["kwargs"].update(kwargs)
|
||||
|
||||
sig = inspect.signature(cls)
|
||||
class_keys = list(sig.parameters.keys())
|
||||
for k in list(metadata["kwargs"].keys()):
|
||||
if k not in class_keys:
|
||||
metadata["kwargs"].pop(k)
|
||||
|
||||
model = cls(*args, **metadata["kwargs"])
|
||||
model.load_state_dict(model_dict["state_dict"], strict=strict)
|
||||
model.metadata = metadata
|
||||
|
||||
return model
|
||||
|
||||
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
|
||||
raise NotImplementedError("Currently Paddle does not support packaging")
|
||||
|
||||
@classmethod
|
||||
def _load_package(cls, path, package_name=None):
|
||||
raise NotImplementedError("Currently Paddle does not support packaging")
|
||||
|
||||
def save_to_folder(
|
||||
self,
|
||||
folder: typing.Union[str, Path],
|
||||
extra_data: dict = None,
|
||||
package: bool = False,
|
||||
):
|
||||
"""Dumps a model into a folder, as both a package
|
||||
and as weights, as well as anything specified in
|
||||
``extra_data``. ``extra_data`` is a dictionary of other
|
||||
pickleable files, with the keys being the paths
|
||||
to save them in. The model is saved under a subfolder
|
||||
specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
|
||||
if the model name was ``Generator``).
|
||||
|
||||
>>> with tempfile.TemporaryDirectory() as d:
|
||||
>>> extra_data = {
|
||||
>>> "optimizer.pth": optimizer.state_dict()
|
||||
>>> }
|
||||
>>> model.save_to_folder(d, extra_data)
|
||||
>>> Model.load_from_folder(d)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
folder : typing.Union[str, Path]
|
||||
_description_
|
||||
extra_data : dict, optional
|
||||
_description_, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Path to folder
|
||||
"""
|
||||
extra_data = {} if extra_data is None else extra_data
|
||||
model_name = type(self).__name__.lower()
|
||||
target_base = Path(f"{folder}/{model_name}/")
|
||||
target_base.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
if package:
|
||||
package_path = target_base / f"package.pth"
|
||||
self.save(package_path)
|
||||
|
||||
weights_path = target_base / f"weights.pth"
|
||||
self.save(weights_path, package=False)
|
||||
|
||||
for path, obj in extra_data.items():
|
||||
paddle.save(obj, target_base / path)
|
||||
|
||||
return target_base
|
||||
|
||||
@classmethod
|
||||
def load_from_folder(
|
||||
cls,
|
||||
folder: typing.Union[str, Path],
|
||||
package: bool = False,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Loads the model from a folder generated by
|
||||
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
|
||||
Like that function, this one looks for a subfolder that has
|
||||
the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
|
||||
model name was ``Generator``).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
folder : typing.Union[str, Path]
|
||||
_description_
|
||||
package : bool, optional
|
||||
Whether to use ``package`` to load the model,
|
||||
loading the model from ``package.pth``.
|
||||
strict : bool, optional
|
||||
Ignore unmatched keys, by default False
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple
|
||||
tuple of model and extra data as saved by
|
||||
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
|
||||
"""
|
||||
folder = Path(folder) / cls.__name__.lower()
|
||||
model_pth = "package.pth" if package else "weights.pth"
|
||||
model_pth = folder / model_pth
|
||||
|
||||
model = cls.load(model_pth, strict=strict)
|
||||
extra_data = {}
|
||||
excluded = ["package.pth", "weights.pth"]
|
||||
files = [
|
||||
x
|
||||
for x in folder.glob("*")
|
||||
if x.is_file() and x.name not in excluded
|
||||
]
|
||||
for f in files:
|
||||
extra_data[f.name] = paddle.load(f, **kwargs)
|
||||
|
||||
return model, extra_data
|
@ -0,0 +1,447 @@
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from rich import box
|
||||
from rich.console import Console
|
||||
from rich.console import Group
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from rich.padding import Padding
|
||||
from rich.panel import Panel
|
||||
from rich.progress import BarColumn
|
||||
from rich.progress import Progress
|
||||
from rich.progress import SpinnerColumn
|
||||
from rich.progress import TimeElapsedColumn
|
||||
from rich.progress import TimeRemainingColumn
|
||||
from rich.rule import Rule
|
||||
from rich.table import Table
|
||||
from visualdl import LogWriter
|
||||
|
||||
|
||||
# This is here so that the history can be pickled.
|
||||
def default_list():
|
||||
return []
|
||||
|
||||
|
||||
class Mean:
|
||||
"""✅Keeps track of the running mean, along with the latest
|
||||
value.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def __call__(self):
|
||||
mean = self.total / max(self.count, 1)
|
||||
return mean
|
||||
|
||||
def reset(self):
|
||||
self.count = 0
|
||||
self.total = 0
|
||||
|
||||
def update(self, val):
|
||||
if math.isfinite(val):
|
||||
self.count += 1
|
||||
self.total += val
|
||||
|
||||
|
||||
def when(condition):
|
||||
"""✅Runs a function only when the condition is met. The condition is
|
||||
a function that is run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
condition : Callable
|
||||
Function to run to check whether or not to run the decorated
|
||||
function.
|
||||
|
||||
Example
|
||||
-------
|
||||
Checkpoint only runs every 100 iterations, and only if the
|
||||
local rank is 0.
|
||||
|
||||
>>> i = 0
|
||||
>>> rank = 0
|
||||
>>>
|
||||
>>> @when(lambda: i % 100 == 0 and rank == 0)
|
||||
>>> def checkpoint():
|
||||
>>> print("Saving to /runs/exp1")
|
||||
>>>
|
||||
>>> for i in range(1000):
|
||||
>>> checkpoint()
|
||||
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
@wraps(fn)
|
||||
def decorated(*args, **kwargs):
|
||||
if condition():
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def timer(prefix: str = "time"):
|
||||
"""✅Adds execution time to the output dictionary of the decorated
|
||||
function. The function decorated by this must output a dictionary.
|
||||
The key added will follow the form "[prefix]/[name_of_function]"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prefix : str, optional
|
||||
The key added will follow the form "[prefix]/[name_of_function]",
|
||||
by default "time".
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
@wraps(fn)
|
||||
def decorated(*args, **kwargs):
|
||||
s = time.perf_counter()
|
||||
output = fn(*args, **kwargs)
|
||||
assert isinstance(output, dict)
|
||||
e = time.perf_counter()
|
||||
output[f"{prefix}/{fn.__name__}"] = e - s
|
||||
return output
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""✅
|
||||
A tracker class that helps to monitor the progress of training and logging the metrics.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
metrics : dict
|
||||
A dictionary containing the metrics for each label.
|
||||
history : dict
|
||||
A dictionary containing the history of metrics for each label.
|
||||
writer : LogWriter
|
||||
A LogWriter object for logging the metrics.
|
||||
rank : int
|
||||
The rank of the current process.
|
||||
step : int
|
||||
The current step of the training.
|
||||
tasks : dict
|
||||
A dictionary containing the progress bars and tables for each label.
|
||||
pbar : Progress
|
||||
A progress bar object for displaying the progress.
|
||||
consoles : list
|
||||
A list of console objects for logging.
|
||||
live : Live
|
||||
A Live object for updating the display live.
|
||||
|
||||
Methods
|
||||
-------
|
||||
print(msg: str)
|
||||
Prints the given message to all consoles.
|
||||
update(label: str, fn_name: str)
|
||||
Updates the progress bar and table for the given label.
|
||||
done(label: str, title: str)
|
||||
Resets the progress bar and table for the given label and prints the final result.
|
||||
track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
|
||||
A decorator for tracking the progress and metrics of a function.
|
||||
log(label: str, value_type: str = "value", history: bool = True)
|
||||
A decorator for logging the metrics of a function.
|
||||
is_best(label: str, key: str) -> bool
|
||||
Checks if the latest value of the given key in the label is the best so far.
|
||||
state_dict() -> dict
|
||||
Returns a dictionary containing the state of the tracker.
|
||||
load_state_dict(state_dict: dict) -> Tracker
|
||||
Loads the state of the tracker from the given state dictionary.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
writer: LogWriter = None,
|
||||
log_file: str = None,
|
||||
rank: int = 0,
|
||||
console_width: int = 100,
|
||||
step: int = 0,
|
||||
):
|
||||
"""
|
||||
Initializes the Tracker object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
writer : LogWriter, optional
|
||||
A LogWriter object for logging the metrics, by default None.
|
||||
log_file : str, optional
|
||||
The path to the log file, by default None.
|
||||
rank : int, optional
|
||||
The rank of the current process, by default 0.
|
||||
console_width : int, optional
|
||||
The width of the console, by default 100.
|
||||
step : int, optional
|
||||
The current step of the training, by default 0.
|
||||
"""
|
||||
self.metrics = {}
|
||||
self.history = {}
|
||||
self.writer = writer
|
||||
self.rank = rank
|
||||
self.step = step
|
||||
|
||||
# Create progress bars etc.
|
||||
self.tasks = {}
|
||||
self.pbar = Progress(
|
||||
SpinnerColumn(),
|
||||
"[progress.description]{task.description}",
|
||||
"{task.completed}/{task.total}",
|
||||
BarColumn(),
|
||||
TimeElapsedColumn(),
|
||||
"/",
|
||||
TimeRemainingColumn(),
|
||||
)
|
||||
self.consoles = [Console(width=console_width)]
|
||||
self.live = Live(console=self.consoles[0], refresh_per_second=10)
|
||||
if log_file is not None:
|
||||
self.consoles.append(
|
||||
Console(width=console_width, file=open(log_file, "a"))
|
||||
)
|
||||
|
||||
def print(self, msg):
|
||||
"""
|
||||
Prints the given message to all consoles.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : str
|
||||
The message to be printed.
|
||||
"""
|
||||
if self.rank == 0:
|
||||
for c in self.consoles:
|
||||
c.log(msg)
|
||||
|
||||
def update(self, label, fn_name):
|
||||
"""
|
||||
Updates the progress bar and table for the given label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : str
|
||||
The label of the progress bar and table to be updated.
|
||||
fn_name : str
|
||||
The name of the function associated with the label.
|
||||
"""
|
||||
if self.rank == 0:
|
||||
self.pbar.advance(self.tasks[label]["pbar"])
|
||||
|
||||
# Create table
|
||||
table = Table(title=label, expand=True, box=box.MINIMAL)
|
||||
table.add_column("key", style="cyan")
|
||||
table.add_column("value", style="bright_blue")
|
||||
table.add_column("mean", style="bright_green")
|
||||
|
||||
keys = self.metrics[label]["value"].keys()
|
||||
for k in keys:
|
||||
value = self.metrics[label]["value"][k]
|
||||
mean = self.metrics[label]["mean"][k]()
|
||||
table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
|
||||
|
||||
self.tasks[label]["table"] = table
|
||||
tables = [t["table"] for t in self.tasks.values()]
|
||||
group = Group(*tables, self.pbar)
|
||||
self.live.update(
|
||||
Group(
|
||||
Padding("", (0, 0)),
|
||||
Rule(f"[italic]{fn_name}()", style="white"),
|
||||
Padding("", (0, 0)),
|
||||
Panel.fit(
|
||||
group,
|
||||
padding=(0, 5),
|
||||
title="[b]Progress",
|
||||
border_style="blue",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def done(self, label: str, title: str):
|
||||
"""
|
||||
Resets the progress bar and table for the given label and prints the final result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : str
|
||||
The label of the progress bar and table to be reset.
|
||||
title : str
|
||||
The title to be displayed when printing the final result.
|
||||
"""
|
||||
for label in self.metrics:
|
||||
for v in self.metrics[label]["mean"].values():
|
||||
v.reset()
|
||||
|
||||
if self.rank == 0:
|
||||
self.pbar.reset(self.tasks[label]["pbar"])
|
||||
tables = [t["table"] for t in self.tasks.values()]
|
||||
group = Group(Markdown(f"# {title}"), *tables, self.pbar)
|
||||
self.print(group)
|
||||
|
||||
def track(
|
||||
self,
|
||||
label: str,
|
||||
length: int,
|
||||
completed: int = 0,
|
||||
op: dist.ReduceOp = dist.ReduceOp.AVG,
|
||||
ddp_active: bool = "LOCAL_RANK" in os.environ,
|
||||
):
|
||||
"""
|
||||
A decorator for tracking the progress and metrics of a function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : str
|
||||
The label to be associated with the progress and metrics.
|
||||
length : int
|
||||
The total number of iterations to be completed.
|
||||
completed : int, optional
|
||||
The number of iterations already completed, by default 0.
|
||||
op : dist.ReduceOp, optional
|
||||
The reduce operation to be used, by default dist.ReduceOp.AVG.
|
||||
ddp_active : bool, optional
|
||||
Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
|
||||
"""
|
||||
self.tasks[label] = {
|
||||
"pbar": self.pbar.add_task(
|
||||
f"[white]Iteration ({label})", total=length, completed=completed
|
||||
),
|
||||
"table": Table(),
|
||||
}
|
||||
self.metrics[label] = {
|
||||
"value": defaultdict(),
|
||||
"mean": defaultdict(lambda: Mean()),
|
||||
}
|
||||
|
||||
def decorator(fn):
|
||||
@wraps(fn)
|
||||
def decorated(*args, **kwargs):
|
||||
output = fn(*args, **kwargs)
|
||||
if not isinstance(output, dict):
|
||||
self.update(label, fn.__name__)
|
||||
return output
|
||||
# Collect across all DDP processes
|
||||
scalar_keys = []
|
||||
for k, v in output.items():
|
||||
if isinstance(v, (int, float)):
|
||||
v = paddle.to_tensor([v])
|
||||
if not paddle.is_tensor(v):
|
||||
continue
|
||||
if ddp_active and v.is_cuda: # pragma: no cover
|
||||
dist.all_reduce(v, op=op)
|
||||
output[k] = v.detach()
|
||||
if paddle.numel(v) == 1:
|
||||
scalar_keys.append(k)
|
||||
output[k] = v.item()
|
||||
|
||||
# Save the outputs to tracker
|
||||
for k, v in output.items():
|
||||
if k not in scalar_keys:
|
||||
continue
|
||||
self.metrics[label]["value"][k] = v
|
||||
# Update the running mean
|
||||
self.metrics[label]["mean"][k].update(v)
|
||||
|
||||
self.update(label, fn.__name__)
|
||||
return output
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def log(self, label: str, value_type: str = "value", history: bool = True):
|
||||
"""
|
||||
A decorator for logging the metrics of a function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : str
|
||||
The label to be associated with the logging.
|
||||
value_type : str, optional
|
||||
The type of value to be logged, by default "value".
|
||||
history : bool, optional
|
||||
Whether to save the history of the metrics, by default True.
|
||||
"""
|
||||
assert value_type in ["mean", "value"]
|
||||
if history:
|
||||
if label not in self.history:
|
||||
self.history[label] = defaultdict(default_list)
|
||||
|
||||
def decorator(fn):
|
||||
@wraps(fn)
|
||||
def decorated(*args, **kwargs):
|
||||
output = fn(*args, **kwargs)
|
||||
if self.rank == 0:
|
||||
nonlocal value_type, label
|
||||
metrics = self.metrics[label][value_type]
|
||||
for k, v in metrics.items():
|
||||
v = v() if isinstance(v, Mean) else v
|
||||
if self.writer is not None:
|
||||
self.writer.add_scalar(
|
||||
tag=f"{k}/{label}", value=v, step=self.step
|
||||
)
|
||||
if label in self.history:
|
||||
self.history[label][k].append(v)
|
||||
|
||||
if label in self.history:
|
||||
self.history[label]["step"].append(self.step)
|
||||
|
||||
return output
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def is_best(self, label, key):
|
||||
"""
|
||||
Checks if the latest value of the given key in the label is the best so far.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : str
|
||||
The label of the metrics to be checked.
|
||||
key : str
|
||||
The key of the metric to be checked.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the latest value is the best so far, otherwise False.
|
||||
"""
|
||||
return self.history[label][key][-1] == min(self.history[label][key])
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
Returns a dictionary containing the state of the tracker.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary containing the history and step of the tracker.
|
||||
"""
|
||||
return {"history": self.history, "step": self.step}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Loads the state of the tracker from the given state dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state_dict : dict
|
||||
A dictionary containing the history and step of the tracker.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tracker
|
||||
The tracker object with the loaded state.
|
||||
"""
|
||||
self.history = state_dict["history"]
|
||||
self.step = state_dict["step"]
|
||||
return self
|
Loading…
Reference in new issue