You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/audio/tests/audiotools/ml/test_model.py

90 lines
2.4 KiB

# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_model.py)
import sys
import tempfile
import paddle
from paddle import nn
from audio.audiotools import ml
from audio.audiotools import util
from paddlespeech.vector.training.seeding import seed_everything
SEED = 0
def seed_and_run(model, *args, **kwargs):
    """Reseed with ``SEED`` and then call ``model`` on the given arguments.

    Args:
        model: Callable to invoke (the tests pass model instances).
        *args: Positional arguments forwarded to ``model``.
        **kwargs: Keyword arguments forwarded to ``model``.

    Returns:
        Whatever ``model(*args, **kwargs)`` returns.
    """
    # Seed right before the call so every invocation starts from SEED.
    seed_everything(SEED)
    result = model(*args, **kwargs)
    return result
class Model(ml.BaseModel):
    """Small ``ml.BaseModel`` subclass whose constructor takes ``arg1``.

    It holds one ``nn.Linear(1, 1)`` layer and applies it in ``forward``.
    """

    def __init__(self, arg1: float = 1.0):
        """Store ``arg1`` on the instance and create the linear layer."""
        super().__init__()
        self.arg1 = arg1
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        projected = self.linear(x)
        return projected
class OtherModel(ml.BaseModel):
    """Small ``ml.BaseModel`` subclass whose constructor takes no arguments.

    It holds one ``nn.Linear(1, 1)`` layer and applies it in ``forward``.
    """

    def __init__(self):
        """Create the linear layer."""
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        projected = self.linear(x)
        return projected
def test_base_model():
    """Check that model output is unchanged after each save/load path.

    Paths exercised, in order: default save, re-save of a loaded model,
    ``package=False`` save, a ``package=False`` save of a model that was
    itself loaded from a ``package=False`` save, loading that checkpoint
    through ``OtherModel``, and finally folder-based save/load.
    """
    # Disabled in this version:
    # ml.BaseModel.EXTERN += ["test_model"]
    batch = paddle.randn([10, 1])
    source_model = Model()
    # Disabled in this version:
    # assert str(source_model.device) == 'Place(cpu)'
    expected = seed_and_run(source_model, batch)

    with tempfile.NamedTemporaryFile(suffix=".pdparams") as ckpt:
        # Default save, then load.
        source_model.save(ckpt.name)
        reloaded = Model.load(ckpt.name)
        reloaded_out = seed_and_run(reloaded, batch)
        assert paddle.allclose(expected, reloaded_out)

        # Re-export the loaded model and load it once more.
        reloaded.save(ckpt.name)
        re_exported = Model.load(ckpt.name)
        re_exported_out = seed_and_run(re_exported, batch)
        assert paddle.allclose(expected, re_exported_out)

        # Save with package=False, then load.
        source_model.save(ckpt.name, package=False)
        reloaded = Model.load(ckpt.name)
        reloaded_out = seed_and_run(reloaded, batch)
        assert paddle.allclose(expected, reloaded_out)

        # package=False save -> load -> package=False save -> load.
        # The resulting output is asserted at the end of this block.
        source_model.save(ckpt.name, package=False)
        reloaded = Model.load(ckpt.name)
        reloaded.save(ckpt.name, package=False)
        re_exported = Model.load(ckpt.name)
        re_exported_out = seed_and_run(re_exported, batch)

        # package=False save loaded through OtherModel, whose constructor
        # does not take the arg1 argument that Model's does.
        source_model.save(ckpt.name, package=False)
        reloaded = OtherModel.load(ckpt.name)
        reloaded_out = seed_and_run(reloaded, batch)
        assert paddle.allclose(expected, reloaded_out)

        assert paddle.allclose(expected, re_exported_out)

    # Folder-based save with a metadata dict, then load from that folder.
    with tempfile.TemporaryDirectory() as folder:
        source_model.save_to_folder(folder, {"data": 1.0})
        Model.load_from_folder(folder)