|
|
|
@ -110,7 +110,8 @@ class BaseModel(nn.Layer):
|
|
|
|
|
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
|
|
|
|
|
paddle.save(state_dict, str(path))
|
|
|
|
|
else:
|
|
|
|
|
self._save_package(path, intern=intern, extern=extern, mock=mock)
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Currently Paddle does not support packaging")
|
|
|
|
|
|
|
|
|
|
return path
|
|
|
|
|
|
|
|
|
@ -151,31 +152,21 @@ class BaseModel(nn.Layer):
|
|
|
|
|
BaseModel
|
|
|
|
|
A model that inherits from BaseModel.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
model = cls._load_package(location, package_name=package_name)
|
|
|
|
|
except:
|
|
|
|
|
model_dict = paddle.load(location)
|
|
|
|
|
metadata = model_dict["metadata"]
|
|
|
|
|
metadata["kwargs"].update(kwargs)
|
|
|
|
|
|
|
|
|
|
sig = inspect.signature(cls)
|
|
|
|
|
class_keys = list(sig.parameters.keys())
|
|
|
|
|
for k in list(metadata["kwargs"].keys()):
|
|
|
|
|
if k not in class_keys:
|
|
|
|
|
metadata["kwargs"].pop(k)
|
|
|
|
|
|
|
|
|
|
model = cls(*args, **metadata["kwargs"])
|
|
|
|
|
model.set_state_dict(model_dict["state_dict"])
|
|
|
|
|
model.metadata = metadata
|
|
|
|
|
model_dict = paddle.load(location)
|
|
|
|
|
metadata = model_dict["metadata"]
|
|
|
|
|
metadata["kwargs"].update(kwargs)
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
sig = inspect.signature(cls)
|
|
|
|
|
class_keys = list(sig.parameters.keys())
|
|
|
|
|
for k in list(metadata["kwargs"].keys()):
|
|
|
|
|
if k not in class_keys:
|
|
|
|
|
metadata["kwargs"].pop(k)
|
|
|
|
|
|
|
|
|
|
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
|
|
|
|
|
raise NotImplementedError("Currently Paddle does not support packaging")
|
|
|
|
|
model = cls(*args, **metadata["kwargs"])
|
|
|
|
|
model.set_state_dict(model_dict["state_dict"])
|
|
|
|
|
model.metadata = metadata
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _load_package(cls, path, package_name=None):
|
|
|
|
|
raise NotImplementedError("Currently Paddle does not support packaging")
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
def save_to_folder(
|
|
|
|
|
self,
|
|
|
|
|