You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
2.4 KiB
90 lines
2.4 KiB
# MIT License, Copyright (c) 2023-Present, Descript.
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_model.py)
|
|
import sys
|
|
import tempfile
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
|
|
from audio.audiotools import ml
|
|
from audio.audiotools import util
|
|
from paddlespeech.vector.training.seeding import seed_everything
|
|
# Seed passed to seed_everything() by seed_and_run() before each model call,
# so that outputs of different model instances can be compared directly.
SEED: int = 0
|
|
|
|
|
|
def seed_and_run(model, *args, **kwargs):
    """Reseed every RNG with ``SEED``, then call ``model``.

    Reseeding right before the call means two models with identical
    weights produce identical outputs, which the save/load tests rely on.

    Args:
        model: Callable to invoke.
        *args: Positional arguments forwarded unchanged to ``model``.
        **kwargs: Keyword arguments forwarded unchanged to ``model``.

    Returns:
        The value returned by ``model(*args, **kwargs)``.
    """
    seed_everything(SEED)
    result = model(*args, **kwargs)
    return result
|
|
|
|
|
|
class Model(ml.BaseModel):
    """Tiny one-layer model used to exercise ``ml.BaseModel`` save/load.

    Args:
        arg1: Constructor argument kept on the instance as ``self.arg1``;
            it is not used by ``forward``.
    """

    def __init__(self, arg1: float = 1.0):
        super().__init__()
        self.arg1 = arg1
        # Single 1-in / 1-out linear layer: the only trainable parameters.
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        output = self.linear(x)
        return output
|
|
|
|
|
|
class OtherModel(ml.BaseModel):
    """Same layer layout as ``Model`` but with an argument-free constructor.

    Used to check that weights saved from ``Model`` can be loaded into a
    class whose ``__init__`` does not accept ``arg1``.
    """

    def __init__(self):
        super().__init__()
        # Mirrors Model.linear so the saved parameters line up by name.
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        output = self.linear(x)
        return output
|
|
|
|
|
|
def test_base_model():
    """Round-trip a model through each ``ml.BaseModel`` save/load path.

    Every reloaded model is run with ``seed_and_run`` on the same input as
    the original and must reproduce the original output.
    """
    import os  # only needed here, to build the weights path portably

    # ml.BaseModel.EXTERN += ["test_model"]

    x = paddle.randn([10, 1])
    model1 = Model()

    # assert str(model1.device) == 'Place(cpu)'

    out1 = seed_and_run(model1, x)

    # A plain path inside a temporary directory is used instead of
    # tempfile.NamedTemporaryFile: save()/load() reopen the file by name,
    # and reopening a still-open NamedTemporaryFile fails on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "model.pdparams")

        # Default save, then load.
        model1.save(path)
        model2 = Model.load(path)
        out2 = seed_and_run(model2, x)
        assert paddle.allclose(out1, out2)

        # Test re-export: save a loaded model and load it again.
        model2.save(path)
        model3 = Model.load(path)
        out3 = seed_and_run(model3, x)
        assert paddle.allclose(out1, out3)

        # Make sure legacy (package=False) save/load works.
        model1.save(path, package=False)
        model2 = Model.load(path)
        out2 = seed_and_run(model2, x)
        assert paddle.allclose(out1, out2)

        # Legacy save -> load -> legacy save again -> load must be lossless.
        model1.save(path, package=False)
        model2 = Model.load(path)
        model2.save(path, package=False)
        model3 = Model.load(path)
        out3 = seed_and_run(model3, x)
        assert paddle.allclose(out1, out3)

        # Save/load without package, loading into a model class whose
        # constructor takes no arg1 argument.
        model1.save(path, package=False)
        model2 = OtherModel.load(path)
        out2 = seed_and_run(model2, x)
        assert paddle.allclose(out1, out2)

    # Folder-based save with extra data, then load from the folder.
    with tempfile.TemporaryDirectory() as d:
        model1.save_to_folder(d, {"data": 1.0})
        Model.load_from_folder(d)
|