From 3064ed60f305dbb2f159537b00ef005e0cbae9f4 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Mon, 8 Nov 2021 20:49:07 +0800 Subject: [PATCH] Test: new database models and middlewares --- middleware/cluster_test.go | 76 ++++++++++++++++++++++++++++++++++++++ models/download_test.go | 11 ++++++ models/node_test.go | 64 ++++++++++++++++++++++++++++++++ models/share_test.go | 12 ------ models/user_test.go | 13 +------ 5 files changed, 153 insertions(+), 23 deletions(-) create mode 100644 models/node_test.go diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go index 2c25e29..e1e61e3 100644 --- a/middleware/cluster_test.go +++ b/middleware/cluster_test.go @@ -1,12 +1,17 @@ package middleware import ( + "errors" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" "net/http/httptest" "testing" ) @@ -74,7 +79,78 @@ func TestSlaveRPCSignRequired(t *testing.T) { } } +type SlaveControllerMock struct { + testMock.Mock +} + +func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) { + args := s.Called(pingReq) + return args.Get(0).(serializer.NodePingResp), args.Error(1) +} + +func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) { + args := s.Called(s2) + return args.Get(0).(common.Aria2), args.Error(1) +} + +func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error { + args := s.Called(s3, s2, message) + return args.Error(0) +} + +func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error { + args := s.Called(s3, i, s2, f) + return args.Error(0) +} + +func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) { + args := s.Called(s2) + return args.Get(0).(*cluster.MasterInfo), args.Error(1) +} + +func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) { + args := s.Called(s2, u) + return args.String(0), args.Error(1) +} + func TestUseSlaveAria2Instance(t *testing.T) { a := assert.New(t) + // MasterSiteID not set + { + testController := &SlaveControllerMock{} + useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + useSlaveAria2InstanceFunc(c) + a.True(c.IsAborted()) + } + + // Cannot get aria2 instances + { + testController := &SlaveControllerMock{} + useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("MasterSiteID", "expectedSiteID") + testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error")) + useSlaveAria2InstanceFunc(c) + a.True(c.IsAborted()) + testController.AssertExpectations(t) + } + + // Success + { + testController := &SlaveControllerMock{} + useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("MasterSiteID", "expectedSiteID") + testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil) + useSlaveAria2InstanceFunc(c) + a.False(c.IsAborted()) + res, _ := c.Get("MasterAria2Instance") + a.NotNil(res) + testController.AssertExpectations(t) + } } diff --git a/models/download_test.go b/models/download_test.go index 9d9cd34..367afb7 100644 --- a/models/download_test.go +++ b/models/download_test.go @@ -177,3 +177,14 @@ func TestDownload_Delete(t *testing.T) { } } + +func TestDownload_GetNodeID(t *testing.T) { + a := assert.New(t) + record := Download{} + + // compatible with 3.4 + a.EqualValues(1, record.GetNodeID()) + + record.NodeID = 5 + a.EqualValues(5, record.GetNodeID()) +} diff --git a/models/node_test.go b/models/node_test.go new file mode 100644 index 0000000..ddc1f95 --- /dev/null +++ b/models/node_test.go @@ -0,0 +1,64 @@ +package model + +import ( + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGetNodeByID(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetNodeByID(1) + a.NoError(err) + a.EqualValues(1, res.ID) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestGetNodesByStatus(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive)) + res, err := GetNodesByStatus(NodeActive) + a.NoError(err) + a.Len(res, 1) + a.EqualValues(NodeActive, res[0].Status) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestNode_AfterFind(t *testing.T) { + a := assert.New(t) + node := &Node{} + + // No aria2 options + { + a.NoError(node.AfterFind()) + } + + // with aria2 options + { + node.Aria2Options = `{"timeout":1}` + a.NoError(node.AfterFind()) + a.Equal(1, node.Aria2OptionsSerialized.Timeout) + } +} + +func TestNode_BeforeSave(t *testing.T) { + a := assert.New(t) + node := &Node{} + + node.Aria2OptionsSerialized.Timeout = 1 + a.NoError(node.BeforeSave()) + a.Contains("1", node.Aria2Options) +} + +func TestNode_SetStatus(t *testing.T) { + a := assert.New(t) + node := &Node{} + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + a.NoError(node.SetStatus(NodeActive)) + a.Equal(NodeActive, node.Status) + a.NoError(mock.ExpectationsWereMet()) +} diff --git a/models/share_test.go b/models/share_test.go index 52e2ee6..b3fdf0a 100644 --- a/models/share_test.go +++ b/models/share_test.go @@ -188,18 +188,6 @@ func TestShare_CanBeDownloadBy(t *testing.T) { asserts.Error(share.CanBeDownloadBy(user)) } - // 未登录,需要积分 - { - user := &User{ - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: true, - }, - }, - } - asserts.Error(share.CanBeDownloadBy(user)) - } - // 成功 { user := &User{ diff --git a/models/user_test.go b/models/user_test.go index 5b4d375..a85ddbd 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -177,10 +177,10 @@ func TestNewUser(t *testing.T) { func TestUser_AfterFind(t *testing.T) { asserts := assert.New(t) - cache.Deletes([]string{"1"}, "policy_") + cache.Deletes([]string{"0"}, "policy_") policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "默认存储策略") + AddRow(144, "默认存储策略") mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) newUser := NewUser() @@ -240,11 +240,6 @@ func TestUser_GetRemainingCapacity(t *testing.T) { newUser.Group.MaxStorage = 100 newUser.Storage = 200 asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) - - cache.Set("pack_size_0", uint64(10), 0) - newUser.Group.MaxStorage = 100 - newUser.Storage = 101 - asserts.Equal(uint64(9), newUser.GetRemainingCapacity()) } func TestUser_DeductionCapacity(t *testing.T) { @@ -280,10 +275,6 @@ func TestUser_DeductionCapacity(t *testing.T) { asserts.Equal(false, newUser.IncreaseStorage(1)) asserts.Equal(uint64(100), newUser.Storage) - cache.Set("pack_size_1", uint64(1), 0) - asserts.Equal(true, newUser.IncreaseStorage(1)) - asserts.Equal(uint64(101), newUser.Storage) - asserts.True(newUser.IncreaseStorage(0)) }