diff --git a/audio/audiotools/basemodel.py b/audio/audiotools/basemodel.py new file mode 100644 index 000000000..2b9f916f9 --- /dev/null +++ b/audio/audiotools/basemodel.py @@ -0,0 +1,275 @@ +import inspect +import shutil +import tempfile +import typing +from pathlib import Path + +import paddle +from paddle import nn + + +class BaseModel(nn.Layer): + """This is a class that adds useful save/load functionality to a + ``paddle.nn.Layer`` object. ``BaseModel`` objects can be saved + as ``package`` easily, making them super easy to port between + machines without requiring a ton of dependencies. Files can also be + saved as just weights, in the standard way. + + >>> class Model(ml.BaseModel): + >>> def __init__(self, arg1: float = 1.0): + >>> super().__init__() + >>> self.arg1 = arg1 + >>> self.linear = nn.Linear(1, 1) + >>> + >>> def forward(self, x): + >>> return self.linear(x) + >>> + >>> model1 = Model() + >>> + >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f: + >>> model1.save( + >>> f.name, + >>> ) + >>> model2 = Model.load(f.name) + >>> out2 = seed_and_run(model2, x) + >>> assert paddle.allclose(out1, out2) + >>> + >>> model1.save(f.name, package=True) + >>> model2 = Model.load(f.name) + >>> model2.save(f.name, package=False) + >>> model3 = Model.load(f.name) + >>> out3 = seed_and_run(model3, x) + >>> + >>> with tempfile.TemporaryDirectory() as d: + >>> model1.save_to_folder(d, {"data": 1.0}) + >>> Model.load_from_folder(d) + + """ + + INTERN = [] + + def save( + self, + path: str, + metadata: dict = None, + package: bool = False, + intern: list = [], + extern: list = [], + mock: list = [], + ): + """Saves the model, either as a package, or just as + weights, alongside some specified metadata. + + Parameters + ---------- + path : str + Path to save model to. + metadata : dict, optional + Any metadata to save alongside the model, + by default None + package : bool, optional + Whether to use ``package`` to save the model in + a format that is portable, by default True + intern : list, optional + List of additional libraries that are internal + to the model, used with package, by default [] + extern : list, optional + List of additional libraries that are external to + the model, used with package, by default [] + mock : list, optional + List of libraries to mock, used with package, + by default [] + + Returns + ------- + str + Path to saved model. + """ + sig = inspect.signature(self.__class__) + args = {} + + for key, val in sig.parameters.items(): + arg_val = val.default + if arg_val is not inspect.Parameter.empty: + args[key] = arg_val + + # Look up attibutes in self, and if any of them are in args, + # overwrite them in args. + for attribute in dir(self): + if attribute in args: + args[attribute] = getattr(self, attribute) + + metadata = {} if metadata is None else metadata + metadata["kwargs"] = args + if not hasattr(self, "metadata"): + self.metadata = {} + self.metadata.update(metadata) + + if not package: + state_dict = {"state_dict": self.state_dict(), "metadata": metadata} + paddle.save(state_dict, path) + else: + self._save_package(path, intern=intern, extern=extern, mock=mock) + + return path + + @property + def device(self): + """Gets the device the model is on by looking at the device of + the first parameter. May not be valid if model is split across + multiple devices. + """ + return list(self.parameters())[0].device + + @classmethod + def load( + cls, + location: str, + *args, + package_name: str = None, + strict: bool = False, + **kwargs, + ): + """Load model from a path. Tries first to load as a package, and if + that fails, tries to load as weights. The arguments to the class are + specified inside the model weights file. + + Parameters + ---------- + location : str + Path to file. + package_name : str, optional + Name of package, by default ``cls.__name__``. + strict : bool, optional + Ignore unmatched keys, by default False + kwargs : dict + Additional keyword arguments to the model instantiation, if + not loading from package. + + Returns + ------- + BaseModel + A model that inherits from BaseModel. + """ + try: + model = cls._load_package(location, package_name=package_name) + except: + model_dict = paddle.load(location, "cpu") + metadata = model_dict["metadata"] + metadata["kwargs"].update(kwargs) + + sig = inspect.signature(cls) + class_keys = list(sig.parameters.keys()) + for k in list(metadata["kwargs"].keys()): + if k not in class_keys: + metadata["kwargs"].pop(k) + + model = cls(*args, **metadata["kwargs"]) + model.load_state_dict(model_dict["state_dict"], strict=strict) + model.metadata = metadata + + return model + + def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs): + raise NotImplementedError("Currently Paddle does not support packaging") + + @classmethod + def _load_package(cls, path, package_name=None): + raise NotImplementedError("Currently Paddle does not support packaging") + + def save_to_folder( + self, + folder: typing.Union[str, Path], + extra_data: dict = None, + package: bool = False, + ): + """Dumps a model into a folder, as both a package + and as weights, as well as anything specified in + ``extra_data``. ``extra_data`` is a dictionary of other + pickleable files, with the keys being the paths + to save them in. The model is saved under a subfolder + specified by the name of the class (e.g. ``folder/generator/[package, weights].pth`` + if the model name was ``Generator``). + + >>> with tempfile.TemporaryDirectory() as d: + >>> extra_data = { + >>> "optimizer.pth": optimizer.state_dict() + >>> } + >>> model.save_to_folder(d, extra_data) + >>> Model.load_from_folder(d) + + Parameters + ---------- + folder : typing.Union[str, Path] + _description_ + extra_data : dict, optional + _description_, by default None + + Returns + ------- + str + Path to folder + """ + extra_data = {} if extra_data is None else extra_data + model_name = type(self).__name__.lower() + target_base = Path(f"{folder}/{model_name}/") + target_base.mkdir(exist_ok=True, parents=True) + + if package: + package_path = target_base / f"package.pth" + self.save(package_path) + + weights_path = target_base / f"weights.pth" + self.save(weights_path, package=False) + + for path, obj in extra_data.items(): + paddle.save(obj, target_base / path) + + return target_base + + @classmethod + def load_from_folder( + cls, + folder: typing.Union[str, Path], + package: bool = False, + strict: bool = False, + **kwargs, + ): + """Loads the model from a folder generated by + :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. + Like that function, this one looks for a subfolder that has + the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the + model name was ``Generator``). + + Parameters + ---------- + folder : typing.Union[str, Path] + _description_ + package : bool, optional + Whether to use ``package`` to load the model, + loading the model from ``package.pth``. + strict : bool, optional + Ignore unmatched keys, by default False + + Returns + ------- + tuple + tuple of model and extra data as saved by + :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. + """ + folder = Path(folder) / cls.__name__.lower() + model_pth = "package.pth" if package else "weights.pth" + model_pth = folder / model_pth + + model = cls.load(model_pth, strict=strict) + extra_data = {} + excluded = ["package.pth", "weights.pth"] + files = [ + x + for x in folder.glob("*") + if x.is_file() and x.name not in excluded + ] + for f in files: + extra_data[f.name] = paddle.load(f, **kwargs) + + return model, extra_data diff --git a/audio/audiotools/decorators.py b/audio/audiotools/decorators.py new file mode 100644 index 000000000..27982758a --- /dev/null +++ b/audio/audiotools/decorators.py @@ -0,0 +1,447 @@ +import math +import os +import time +from collections import defaultdict +from functools import wraps + +import paddle +import paddle.distributed as dist +from rich import box +from rich.console import Console +from rich.console import Group +from rich.live import Live +from rich.markdown import Markdown +from rich.padding import Padding +from rich.panel import Panel +from rich.progress import BarColumn +from rich.progress import Progress +from rich.progress import SpinnerColumn +from rich.progress import TimeElapsedColumn +from rich.progress import TimeRemainingColumn +from rich.rule import Rule +from rich.table import Table +from visualdl import LogWriter + + +# This is here so that the history can be pickled. +def default_list(): + return [] + + +class Mean: + """✅Keeps track of the running mean, along with the latest + value. + """ + + def __init__(self): + self.reset() + + def __call__(self): + mean = self.total / max(self.count, 1) + return mean + + def reset(self): + self.count = 0 + self.total = 0 + + def update(self, val): + if math.isfinite(val): + self.count += 1 + self.total += val + + +def when(condition): + """✅Runs a function only when the condition is met. The condition is + a function that is run. + + Parameters + ---------- + condition : Callable + Function to run to check whether or not to run the decorated + function. + + Example + ------- + Checkpoint only runs every 100 iterations, and only if the + local rank is 0. + + >>> i = 0 + >>> rank = 0 + >>> + >>> @when(lambda: i % 100 == 0 and rank == 0) + >>> def checkpoint(): + >>> print("Saving to /runs/exp1") + >>> + >>> for i in range(1000): + >>> checkpoint() + + """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + if condition(): + return fn(*args, **kwargs) + + return decorated + + return decorator + + +def timer(prefix: str = "time"): + """✅Adds execution time to the output dictionary of the decorated + function. The function decorated by this must output a dictionary. + The key added will follow the form "[prefix]/[name_of_function]" + + Parameters + ---------- + prefix : str, optional + The key added will follow the form "[prefix]/[name_of_function]", + by default "time". + """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + s = time.perf_counter() + output = fn(*args, **kwargs) + assert isinstance(output, dict) + e = time.perf_counter() + output[f"{prefix}/{fn.__name__}"] = e - s + return output + + return decorated + + return decorator + + +class Tracker: + """✅ + A tracker class that helps to monitor the progress of training and logging the metrics. + + Attributes + ---------- + metrics : dict + A dictionary containing the metrics for each label. + history : dict + A dictionary containing the history of metrics for each label. + writer : LogWriter + A LogWriter object for logging the metrics. + rank : int + The rank of the current process. + step : int + The current step of the training. + tasks : dict + A dictionary containing the progress bars and tables for each label. + pbar : Progress + A progress bar object for displaying the progress. + consoles : list + A list of console objects for logging. + live : Live + A Live object for updating the display live. + + Methods + ------- + print(msg: str) + Prints the given message to all consoles. + update(label: str, fn_name: str) + Updates the progress bar and table for the given label. + done(label: str, title: str) + Resets the progress bar and table for the given label and prints the final result. + track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ) + A decorator for tracking the progress and metrics of a function. + log(label: str, value_type: str = "value", history: bool = True) + A decorator for logging the metrics of a function. + is_best(label: str, key: str) -> bool + Checks if the latest value of the given key in the label is the best so far. + state_dict() -> dict + Returns a dictionary containing the state of the tracker. + load_state_dict(state_dict: dict) -> Tracker + Loads the state of the tracker from the given state dictionary. + """ + + def __init__( + self, + writer: LogWriter = None, + log_file: str = None, + rank: int = 0, + console_width: int = 100, + step: int = 0, + ): + """ + Initializes the Tracker object. + + Parameters + ---------- + writer : LogWriter, optional + A LogWriter object for logging the metrics, by default None. + log_file : str, optional + The path to the log file, by default None. + rank : int, optional + The rank of the current process, by default 0. + console_width : int, optional + The width of the console, by default 100. + step : int, optional + The current step of the training, by default 0. + """ + self.metrics = {} + self.history = {} + self.writer = writer + self.rank = rank + self.step = step + + # Create progress bars etc. + self.tasks = {} + self.pbar = Progress( + SpinnerColumn(), + "[progress.description]{task.description}", + "{task.completed}/{task.total}", + BarColumn(), + TimeElapsedColumn(), + "/", + TimeRemainingColumn(), + ) + self.consoles = [Console(width=console_width)] + self.live = Live(console=self.consoles[0], refresh_per_second=10) + if log_file is not None: + self.consoles.append( + Console(width=console_width, file=open(log_file, "a")) + ) + + def print(self, msg): + """ + Prints the given message to all consoles. + + Parameters + ---------- + msg : str + The message to be printed. + """ + if self.rank == 0: + for c in self.consoles: + c.log(msg) + + def update(self, label, fn_name): + """ + Updates the progress bar and table for the given label. + + Parameters + ---------- + label : str + The label of the progress bar and table to be updated. + fn_name : str + The name of the function associated with the label. + """ + if self.rank == 0: + self.pbar.advance(self.tasks[label]["pbar"]) + + # Create table + table = Table(title=label, expand=True, box=box.MINIMAL) + table.add_column("key", style="cyan") + table.add_column("value", style="bright_blue") + table.add_column("mean", style="bright_green") + + keys = self.metrics[label]["value"].keys() + for k in keys: + value = self.metrics[label]["value"][k] + mean = self.metrics[label]["mean"][k]() + table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}") + + self.tasks[label]["table"] = table + tables = [t["table"] for t in self.tasks.values()] + group = Group(*tables, self.pbar) + self.live.update( + Group( + Padding("", (0, 0)), + Rule(f"[italic]{fn_name}()", style="white"), + Padding("", (0, 0)), + Panel.fit( + group, + padding=(0, 5), + title="[b]Progress", + border_style="blue", + ), + ) + ) + + def done(self, label: str, title: str): + """ + Resets the progress bar and table for the given label and prints the final result. + + Parameters + ---------- + label : str + The label of the progress bar and table to be reset. + title : str + The title to be displayed when printing the final result. + """ + for label in self.metrics: + for v in self.metrics[label]["mean"].values(): + v.reset() + + if self.rank == 0: + self.pbar.reset(self.tasks[label]["pbar"]) + tables = [t["table"] for t in self.tasks.values()] + group = Group(Markdown(f"# {title}"), *tables, self.pbar) + self.print(group) + + def track( + self, + label: str, + length: int, + completed: int = 0, + op: dist.ReduceOp = dist.ReduceOp.AVG, + ddp_active: bool = "LOCAL_RANK" in os.environ, + ): + """ + A decorator for tracking the progress and metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the progress and metrics. + length : int + The total number of iterations to be completed. + completed : int, optional + The number of iterations already completed, by default 0. + op : dist.ReduceOp, optional + The reduce operation to be used, by default dist.ReduceOp.AVG. + ddp_active : bool, optional + Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ. + """ + self.tasks[label] = { + "pbar": self.pbar.add_task( + f"[white]Iteration ({label})", total=length, completed=completed + ), + "table": Table(), + } + self.metrics[label] = { + "value": defaultdict(), + "mean": defaultdict(lambda: Mean()), + } + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if not isinstance(output, dict): + self.update(label, fn.__name__) + return output + # Collect across all DDP processes + scalar_keys = [] + for k, v in output.items(): + if isinstance(v, (int, float)): + v = paddle.to_tensor([v]) + if not paddle.is_tensor(v): + continue + if ddp_active and v.is_cuda: # pragma: no cover + dist.all_reduce(v, op=op) + output[k] = v.detach() + if paddle.numel(v) == 1: + scalar_keys.append(k) + output[k] = v.item() + + # Save the outputs to tracker + for k, v in output.items(): + if k not in scalar_keys: + continue + self.metrics[label]["value"][k] = v + # Update the running mean + self.metrics[label]["mean"][k].update(v) + + self.update(label, fn.__name__) + return output + + return decorated + + return decorator + + def log(self, label: str, value_type: str = "value", history: bool = True): + """ + A decorator for logging the metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the logging. + value_type : str, optional + The type of value to be logged, by default "value". + history : bool, optional + Whether to save the history of the metrics, by default True. + """ + assert value_type in ["mean", "value"] + if history: + if label not in self.history: + self.history[label] = defaultdict(default_list) + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if self.rank == 0: + nonlocal value_type, label + metrics = self.metrics[label][value_type] + for k, v in metrics.items(): + v = v() if isinstance(v, Mean) else v + if self.writer is not None: + self.writer.add_scalar( + tag=f"{k}/{label}", value=v, step=self.step + ) + if label in self.history: + self.history[label][k].append(v) + + if label in self.history: + self.history[label]["step"].append(self.step) + + return output + + return decorated + + return decorator + + def is_best(self, label, key): + """ + Checks if the latest value of the given key in the label is the best so far. + + Parameters + ---------- + label : str + The label of the metrics to be checked. + key : str + The key of the metric to be checked. + + Returns + ------- + bool + True if the latest value is the best so far, otherwise False. + """ + return self.history[label][key][-1] == min(self.history[label][key]) + + def state_dict(self): + """ + Returns a dictionary containing the state of the tracker. + + Returns + ------- + dict + A dictionary containing the history and step of the tracker. + """ + return {"history": self.history, "step": self.step} + + def load_state_dict(self, state_dict): + """ + Loads the state of the tracker from the given state dictionary. + + Parameters + ---------- + state_dict : dict + A dictionary containing the history and step of the tracker. + + Returns + ------- + Tracker + The tracker object with the loaded state. + """ + self.history = state_dict["history"] + self.step = state_dict["step"] + return self