From afa6f12ba14d6f7abddbc6faaa93dbfcc9581033 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 26 Feb 2025 11:16:14 +0800 Subject: [PATCH] paddlespeech/audiotools/ml/basemodel.py (#3994) --- paddlespeech/audiotools/ml/basemodel.py | 37 ++++++++++--------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/paddlespeech/audiotools/ml/basemodel.py b/paddlespeech/audiotools/ml/basemodel.py index 97c31ff7a..2d5683266 100644 --- a/paddlespeech/audiotools/ml/basemodel.py +++ b/paddlespeech/audiotools/ml/basemodel.py @@ -110,7 +110,8 @@ class BaseModel(nn.Layer): state_dict = {"state_dict": self.state_dict(), "metadata": metadata} paddle.save(state_dict, str(path)) else: - self._save_package(path, intern=intern, extern=extern, mock=mock) + raise NotImplementedError( + "Currently Paddle does not support packaging") return path @@ -151,31 +152,21 @@ class BaseModel(nn.Layer): BaseModel A model that inherits from BaseModel. """ - try: - model = cls._load_package(location, package_name=package_name) - except: - model_dict = paddle.load(location) - metadata = model_dict["metadata"] - metadata["kwargs"].update(kwargs) - - sig = inspect.signature(cls) - class_keys = list(sig.parameters.keys()) - for k in list(metadata["kwargs"].keys()): - if k not in class_keys: - metadata["kwargs"].pop(k) - - model = cls(*args, **metadata["kwargs"]) - model.set_state_dict(model_dict["state_dict"]) - model.metadata = metadata + model_dict = paddle.load(location) + metadata = model_dict["metadata"] + metadata["kwargs"].update(kwargs) - return model + sig = inspect.signature(cls) + class_keys = list(sig.parameters.keys()) + for k in list(metadata["kwargs"].keys()): + if k not in class_keys: + metadata["kwargs"].pop(k) - def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs): - raise NotImplementedError("Currently Paddle does not support packaging") + model = cls(*args, **metadata["kwargs"]) + model.set_state_dict(model_dict["state_dict"]) + model.metadata = metadata - @classmethod - def _load_package(cls, path, package_name=None): - raise NotImplementedError("Currently Paddle does not support packaging") + return model def save_to_folder( self,