paddlespeech/audiotools/ml/basemodel.py (#3994)

pull/3995/head
zxcd 6 months ago committed by GitHub
parent 793a89d53c
commit afa6f12ba1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -110,7 +110,8 @@ class BaseModel(nn.Layer):
state_dict = {"state_dict": self.state_dict(), "metadata": metadata} state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
paddle.save(state_dict, str(path)) paddle.save(state_dict, str(path))
else: else:
self._save_package(path, intern=intern, extern=extern, mock=mock) raise NotImplementedError(
"Currently Paddle does not support packaging")
return path return path
@ -151,31 +152,21 @@ class BaseModel(nn.Layer):
BaseModel BaseModel
A model that inherits from BaseModel. A model that inherits from BaseModel.
""" """
try: model_dict = paddle.load(location)
model = cls._load_package(location, package_name=package_name) metadata = model_dict["metadata"]
except: metadata["kwargs"].update(kwargs)
model_dict = paddle.load(location)
metadata = model_dict["metadata"]
metadata["kwargs"].update(kwargs)
sig = inspect.signature(cls)
class_keys = list(sig.parameters.keys())
for k in list(metadata["kwargs"].keys()):
if k not in class_keys:
metadata["kwargs"].pop(k)
model = cls(*args, **metadata["kwargs"])
model.set_state_dict(model_dict["state_dict"])
model.metadata = metadata
return model sig = inspect.signature(cls)
class_keys = list(sig.parameters.keys())
for k in list(metadata["kwargs"].keys()):
if k not in class_keys:
metadata["kwargs"].pop(k)
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs): model = cls(*args, **metadata["kwargs"])
raise NotImplementedError("Currently Paddle does not support packaging") model.set_state_dict(model_dict["state_dict"])
model.metadata = metadata
@classmethod return model
def _load_package(cls, path, package_name=None):
raise NotImplementedError("Currently Paddle does not support packaging")
def save_to_folder( def save_to_folder(
self, self,

Loading…
Cancel
Save