# Source: PaddleSpeech/audio/audiotools/ml/basemodel.py

# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/layers/base.py)
import inspect
import shutil
import tempfile
import typing
from pathlib import Path
import paddle
from paddle import nn
class BaseModel(nn.Layer):
    """Add save/load helpers to a ``paddle.nn.Layer``.

    A ``BaseModel`` can be written to disk as plain weights together with
    the constructor keyword arguments needed to rebuild it, and read back
    with :py:meth:`load`. The ``package`` code paths are kept for interface
    compatibility with audiotools, but ``_save_package`` and
    ``_load_package`` raise ``NotImplementedError`` in this class, so only
    the weights format works unless a subclass overrides those hooks. The
    ``package=True`` call in the example below therefore requires such a
    subclass.

    >>> class Model(ml.BaseModel):
    >>>     def __init__(self, arg1: float = 1.0):
    >>>         super().__init__()
    >>>         self.arg1 = arg1
    >>>         self.linear = nn.Linear(1, 1)
    >>>
    >>>     def forward(self, x):
    >>>         return self.linear(x)
    >>>
    >>> model1 = Model()
    >>>
    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    >>>     model1.save(
    >>>         f.name,
    >>>     )
    >>>     model2 = Model.load(f.name)
    >>>     out2 = seed_and_run(model2, x)
    >>>     assert paddle.allclose(out1, out2)
    >>>
    >>>     model1.save(f.name, package=True)
    >>>     model2 = Model.load(f.name)
    >>>     model2.save(f.name, package=False)
    >>>     model3 = Model.load(f.name)
    >>>     out3 = seed_and_run(model3, x)
    >>>
    >>> with tempfile.TemporaryDirectory() as d:
    >>>     model1.save_to_folder(d, {"data": 1.0})
    >>>     Model.load_from_folder(d)
    """

    def save(
            self,
            path: str,
            metadata: typing.Optional[dict] = None,
            package: bool = False,
            intern: typing.Optional[list] = None,
            extern: typing.Optional[list] = None,
            mock: typing.Optional[list] = None,
    ):
        """Save the model, either as a package or as plain weights,
        alongside some metadata.

        The constructor keyword arguments (every parameter of the class
        signature that has a default, overridden by the same-named
        attribute on ``self`` when one exists) are stored under
        ``metadata["kwargs"]`` so that :py:meth:`load` can rebuild the
        model.

        Parameters
        ----------
        path : str
            Path to save model to.
        metadata : dict, optional
            Any metadata to save alongside the model, by default None.
            The dictionary is copied; the caller's object is not modified.
        package : bool, optional
            Whether to save through ``_save_package`` instead of writing
            plain weights, by default False.
        intern : list, optional
            List of additional libraries that are internal
            to the model, used with package, by default None
            (treated as an empty list).
        extern : list, optional
            List of additional libraries that are external to
            the model, used with package, by default None
            (treated as an empty list).
        mock : list, optional
            List of libraries to mock, used with package,
            by default None (treated as an empty list).

        Returns
        -------
        str
            ``path``, exactly as it was passed in.
        """
        # Start from the constructor defaults declared on the class.
        sig = inspect.signature(self.__class__)
        args = {
            key: param.default
            for key, param in sig.parameters.items()
            if param.default is not inspect.Parameter.empty
        }

        # Look up attributes in self, and if any of them are in args,
        # overwrite them in args.
        for attribute in dir(self):
            if attribute in args:
                args[attribute] = getattr(self, attribute)

        # Shallow-copy so that adding "kwargs" does not leak into the
        # dictionary owned by the caller.
        metadata = {} if metadata is None else dict(metadata)
        metadata["kwargs"] = args

        if not hasattr(self, "metadata"):
            self.metadata = {}
        self.metadata.update(metadata)

        if package:
            self._save_package(
                path,
                intern=[] if intern is None else intern,
                extern=[] if extern is None else extern,
                mock=[] if mock is None else mock,
            )
        else:
            state_dict = {
                "state_dict": self.state_dict(),
                "metadata": metadata,
            }
            paddle.save(state_dict, str(path))
        return path

    @property
    def device(self):
        """Gets the device the model is on by looking at the device of
        the first parameter. May not be valid if model is split across
        multiple devices.
        """
        return list(self.parameters())[0].place

    @classmethod
    def load(
            cls,
            location: str,
            *args,
            package_name: typing.Optional[str] = None,
            strict: bool = False,
            **kwargs,
    ):
        """Load model from a path. Tries first to load as a package, and if
        that fails, tries to load as weights. The arguments to the class are
        specified inside the model weights file.

        Parameters
        ----------
        location : str
            Path to file.
        args
            Positional arguments forwarded to the class constructor when
            loading from weights.
        package_name : str, optional
            Name of package, forwarded to ``_load_package``,
            by default None.
        strict : bool, optional
            Accepted for interface compatibility; this method does not
            currently read it. By default False.
        kwargs : dict
            Additional keyword arguments to the model instantiation, if
            not loading from package. They override the saved values;
            names that are not parameters of the class are dropped.

        Returns
        -------
        BaseModel
            A model that inherits from BaseModel.
        """
        try:
            model = cls._load_package(location, package_name=package_name)
        except Exception:
            # Any ordinary failure of the package loader falls back to the
            # weights format. KeyboardInterrupt/SystemExit still propagate.
            if isinstance(location, Path):
                # ``save`` stringifies its path; do the same on the way in.
                location = str(location)
            model_dict = paddle.load(location)
            metadata = model_dict["metadata"]
            metadata["kwargs"].update(kwargs)

            # Keep only keyword arguments the class constructor declares.
            class_keys = set(inspect.signature(cls).parameters)
            metadata["kwargs"] = {
                key: value
                for key, value in metadata["kwargs"].items()
                if key in class_keys
            }

            model = cls(*args, **metadata["kwargs"])
            model.set_state_dict(model_dict["state_dict"])
            model.metadata = metadata
        return model

    def _save_package(self, path, intern=None, extern=None, mock=None,
                      **kwargs):
        """Packaging hook; not implemented for Paddle.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Currently Paddle does not support packaging")

    @classmethod
    def _load_package(cls, path, package_name=None):
        """Package-loading hook; not implemented for Paddle.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Currently Paddle does not support packaging")

    def save_to_folder(
            self,
            folder: typing.Union[str, Path],
            extra_data: typing.Optional[dict] = None,
            package: bool = False,
    ):
        """Dumps a model into a folder as weights (and optionally as
        ``package.pth``), as well as anything specified in
        ``extra_data``. ``extra_data`` is a dictionary of other
        pickleable files, with the keys being the paths
        to save them in. The model is saved under a subfolder
        specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
        if the model name was ``Generator``).

        >>> with tempfile.TemporaryDirectory() as d:
        >>>     extra_data = {
        >>>         "optimizer.pth": optimizer.state_dict()
        >>>     }
        >>>     model.save_to_folder(d, extra_data)
        >>>     Model.load_from_folder(d)

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Parent folder; the model subfolder is created inside it
            (including missing parents).
        extra_data : dict, optional
            Mapping of file name (relative to the model subfolder) to the
            object to store with ``paddle.save``, by default None.
        package : bool, optional
            Whether to additionally write ``package.pth``,
            by default False.

        Returns
        -------
        pathlib.Path
            Path to the model subfolder, i.e. ``folder/<lowercased class name>``.
        """
        extra_data = {} if extra_data is None else extra_data

        model_name = type(self).__name__.lower()
        target_base = Path(f"{folder}/{model_name}/")
        target_base.mkdir(exist_ok=True, parents=True)

        if package:
            # NOTE: this goes through ``save`` with its default
            # ``package=False``, so ``package.pth`` holds plain weights.
            self.save(target_base / "package.pth")

        self.save(target_base / "weights.pth", package=False)

        for path, obj in extra_data.items():
            paddle.save(obj, str(target_base / path))

        return target_base

    @classmethod
    def load_from_folder(
            cls,
            folder: typing.Union[str, Path],
            package: bool = False,
            strict: bool = False,
            **kwargs,
    ):
        """Loads the model from a folder generated by
        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
        Like that function, this one looks for a subfolder that has
        the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
        model name was ``Generator``).

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Parent folder that contains the model subfolder.
        package : bool, optional
            Whether to load the model from ``package.pth`` instead of
            ``weights.pth``, by default False.
        strict : bool, optional
            Accepted for interface compatibility; this method does not
            currently read it. By default False.
        kwargs : dict
            Forwarded to ``paddle.load`` for every extra-data file (not to
            the model constructor).

        Returns
        -------
        tuple
            tuple of model and extra data as saved by
            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
            The extra data is a dict keyed by file name.
        """
        folder = Path(folder) / cls.__name__.lower()
        model_pth = folder / ("package.pth" if package else "weights.pth")
        model = cls.load(str(model_pth))

        # Every other regular file in the subfolder is treated as extra data.
        excluded = {"package.pth", "weights.pth"}
        extra_data = {
            entry.name: paddle.load(str(entry), **kwargs)
            for entry in folder.glob("*")
            if entry.is_file() and entry.name not in excluded
        }
        return model, extra_data