From e41ec9defa3998571cddd8f8f0214a544c52f397 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Mon, 8 Nov 2021 19:54:26 +0800 Subject: [PATCH 1/2] Refactor: move slave pkg inside of cluster Test: middleware for node communication --- bootstrap/init.go | 3 +- middleware/auth.go | 1 + middleware/auth_test.go | 17 +++- middleware/cluster.go | 9 +-- middleware/cluster_test.go | 80 +++++++++++++++++++ models/user.go | 2 +- pkg/{slave/slave.go => cluster/controller.go} | 9 +-- pkg/cluster/errors.go | 6 +- pkg/cluster/pool.go | 16 ++-- pkg/filesystem/driver/onedrive/oauth.go | 4 +- pkg/slave/errors.go | 7 -- pkg/task/slavetask/transfer.go | 8 +- routers/router.go | 5 +- service/aria2/add.go | 3 +- service/explorer/slave.go | 4 +- service/node/fabric.go | 4 +- 16 files changed, 135 insertions(+), 43 deletions(-) create mode 100644 middleware/cluster_test.go rename pkg/{slave/slave.go => cluster/controller.go} (96%) delete mode 100644 pkg/slave/errors.go diff --git a/bootstrap/init.go b/bootstrap/init.go index 6b43adbb..60ebcb80 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -9,7 +9,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/crontab" "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/gin-gonic/gin" ) @@ -78,7 +77,7 @@ func Init(path string) { { "slave", func() { - slave.Init() + cluster.InitController() }, }, { diff --git a/middleware/auth.go b/middleware/auth.go index 135dd8c5..3e7cbe73 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -37,6 +37,7 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc { c.Abort() return } + c.Next() } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index a17602d8..84d229e2 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -90,15 +90,27 @@ func TestSignRequired(t *testing.T) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request, _ = http.NewRequest("GET", "/test", nil) - SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}) + authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + SignRequiredFunc := SignRequired(authInstance) // 鉴权失败 SignRequiredFunc(c) asserts.NotNil(c) + asserts.True(c.IsAborted()) + c, _ = gin.CreateTestContext(rec) c.Request, _ = http.NewRequest("PUT", "/test", nil) SignRequiredFunc(c) asserts.NotNil(c) + asserts.True(c.IsAborted()) + + // Sign verify success + c, _ = gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("PUT", "/test", nil) + c.Request = auth.SignRequest(authInstance, c.Request, 0) + SignRequiredFunc(c) + asserts.NotNil(c) + asserts.False(c.IsAborted()) } func TestWebDAVAuth(t *testing.T) { @@ -780,8 +792,6 @@ func TestS3CallbackAuth(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]")) - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) c.Params = []gin.Param{ {"key", "testCallBackUpyun"}, @@ -789,5 +799,6 @@ func TestS3CallbackAuth(t *testing.T) { c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) AuthFunc(c) asserts.False(c.IsAborted()) + asserts.NoError(mock.ExpectationsWereMet()) } } diff --git a/middleware/cluster.go b/middleware/cluster.go index 079a4f42..d8bf979a 100644 --- a/middleware/cluster.go +++ b/middleware/cluster.go @@ -3,7 +3,6 @@ package middleware import ( "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" "strconv" ) @@ -19,11 +18,11 @@ func MasterMetadata() gin.HandlerFunc { } // UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例 -func UseSlaveAria2Instance() gin.HandlerFunc { +func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc { return func(c *gin.Context) { if siteID, exist := c.Get("MasterSiteID"); exist { // 获取对应主机节点的从机Aria2实例 - caller, err := slave.DefaultController.GetAria2Instance(siteID.(string)) + caller, err := clusterController.GetAria2Instance(siteID.(string)) if err != nil { c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)) c.Abort() @@ -40,7 +39,7 @@ func UseSlaveAria2Instance() gin.HandlerFunc { } } -func SlaveRPCSignRequired() gin.HandlerFunc { +func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc { return func(c *gin.Context) { nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64) if err != nil { @@ -49,7 +48,7 @@ func SlaveRPCSignRequired() gin.HandlerFunc { return } - slaveNode := cluster.Default.GetNodeByID(uint(nodeID)) + slaveNode := nodePool.GetNodeByID(uint(nodeID)) if slaveNode == nil { c.JSON(200, serializer.ParamErr("未知的主机节点ID", err)) c.Abort() diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go new file mode 100644 index 00000000..2c25e292 --- /dev/null +++ b/middleware/cluster_test.go @@ -0,0 +1,80 @@ +package middleware + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "net/http/httptest" + "testing" +) + +func TestMasterMetadata(t *testing.T) { + a := assert.New(t) + masterMetaDataFunc := MasterMetadata() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("GET", "/", nil) + + c.Request.Header = map[string][]string{ + "X-Site-Id": {"expectedSiteID"}, + "X-Site-Url": {"expectedSiteURL"}, + "X-Cloudreve-Version": {"expectedMasterVersion"}, + } + masterMetaDataFunc(c) + siteID, _ := c.Get("MasterSiteID") + siteURL, _ := c.Get("MasterSiteURL") + siteVersion, _ := c.Get("MasterVersion") + + a.Equal("expectedSiteID", siteID.(string)) + a.Equal("expectedSiteURL", siteURL.(string)) + a.Equal("expectedMasterVersion", siteVersion.(string)) +} + +func TestSlaveRPCSignRequired(t *testing.T) { + a := assert.New(t) + np := &cluster.NodePool{} + np.Init() + slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np) + rec := httptest.NewRecorder() + + // id parse failed + { + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Request.Header.Set("X-Node-Id", "unknown") + slaveRPCSignRequiredFunc(c) + a.True(c.IsAborted()) + } + + // node id not exist + { + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Request.Header.Set("X-Node-Id", "38") + slaveRPCSignRequiredFunc(c) + a.True(c.IsAborted()) + } + + // success + { + authInstance := auth.HMACAuth{SecretKey: []byte("")} + np.Add(&model.Node{Model: gorm.Model{ + ID: 38, + }}) + + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("POST", "/", nil) + c.Request.Header.Set("X-Node-Id", "38") + c.Request = auth.SignRequest(authInstance, c.Request, 0) + slaveRPCSignRequiredFunc(c) + a.False(c.IsAborted()) + } +} + +func TestUseSlaveAria2Instance(t *testing.T) { + a := assert.New(t) + +} diff --git a/models/user.go b/models/user.go index ecd091b4..8045d27f 100644 --- a/models/user.go +++ b/models/user.go @@ -35,7 +35,7 @@ type User struct { Storage uint64 TwoFactor string Avatar string - Options string `json:"-",gorm:"type:text"` + Options string `json:"-" gorm:"type:text"` Authn string `gorm:"type:text"` // 关联模型 diff --git a/pkg/slave/slave.go b/pkg/cluster/controller.go similarity index 96% rename from pkg/slave/slave.go rename to pkg/cluster/controller.go index aa457b5b..d5352ee9 100644 --- a/pkg/slave/slave.go +++ b/pkg/cluster/controller.go @@ -1,4 +1,4 @@ -package slave +package cluster import ( "bytes" @@ -8,7 +8,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -51,13 +50,13 @@ type MasterInfo struct { TTL int URL *url.URL // used to invoke aria2 rpc calls - Instance cluster.Node + Instance Node Client request.Client jobTracker map[string]bool } -func Init() { +func InitController() { DefaultController = &slaveController{ masters: make(map[string]MasterInfo), } @@ -95,7 +94,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ }, int64(req.CredentialTTL)), ), jobTracker: make(map[string]bool), - Instance: cluster.NewNodeFromDBModel(&model.Node{ + Instance: NewNodeFromDBModel(&model.Node{ Model: gorm.Model{ID: req.Node.ID}, MasterKey: req.Node.MasterKey, Type: model.MasterNodeType, diff --git a/pkg/cluster/errors.go b/pkg/cluster/errors.go index 9afdbef5..84b2ad82 100644 --- a/pkg/cluster/errors.go +++ b/pkg/cluster/errors.go @@ -1,8 +1,12 @@ package cluster -import "errors" +import ( + "errors" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) var ( ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed") ErrIlegalPath = errors.New("path out of boundary of setting temp folder") + ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil) ) diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index 2181df42..a2976493 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -39,14 +39,22 @@ type NodePool struct { // Init 初始化从机节点池 func Init() { - Default = &NodePool{ - featureMap: make(map[string][]Node), - } + Default = &NodePool{} + Default.Init() if err := Default.initFromDB(); err != nil { util.Log().Warning("节点池初始化失败, %s", err) } } +func (pool *NodePool) Init() { + pool.lock.Lock() + defer pool.lock.Unlock() + + pool.featureMap = make(map[string][]Node) + pool.active = make(map[uint]Node) + pool.inactive = make(map[uint]Node) +} + func (pool *NodePool) buildIndexMap() { pool.lock.Lock() for _, feature := range featureGroup { @@ -98,8 +106,6 @@ func (pool *NodePool) initFromDB() error { } pool.lock.Lock() - pool.active = make(map[uint]Node) - pool.inactive = make(map[uint]Node) for i := 0; i < len(nodes); i++ { pool.add(&nodes[i]) } diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index 49170fe7..0d5865b7 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -3,6 +3,7 @@ package onedrive import ( "context" "encoding/json" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "io/ioutil" "net/http" "net/url" @@ -12,7 +13,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -179,7 +179,7 @@ func (client *Client) UpdateCredential(ctx context.Context) error { // UpdateCredential 更新凭证,并检查有效期 func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { - res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) + res, err := cluster.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) if err != nil { return err } diff --git a/pkg/slave/errors.go b/pkg/slave/errors.go deleted file mode 100644 index 2af6e13f..00000000 --- a/pkg/slave/errors.go +++ /dev/null @@ -1,7 +0,0 @@ -package slave - -import "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - -var ( - ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil) -) diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index c3127425..92310929 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -3,11 +3,11 @@ package slavetask import ( "context" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" "os" @@ -68,7 +68,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) { }, } - if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { + if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { util.Log().Warning("无法发送转存失败通知到从机, ", err) } } @@ -94,7 +94,7 @@ func (job *TransferTask) Do() { return } - master, err := slave.DefaultController.GetMasterInfo(job.MasterID) + master, err := cluster.DefaultController.GetMasterInfo(job.MasterID) if err != nil { job.SetErrorMsg("找不到主机节点", err) return @@ -131,7 +131,7 @@ func (job *TransferTask) Do() { Content: serializer.SlaveTransferResult{}, } - if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { + if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { util.Log().Warning("无法发送转存成功通知到从机, ", err) } } diff --git a/routers/router.go b/routers/router.go index a7204c47..8f335c33 100644 --- a/routers/router.go +++ b/routers/router.go @@ -3,6 +3,7 @@ package routers import ( "github.com/cloudreve/Cloudreve/v3/middleware" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -59,7 +60,7 @@ func InitSlaveRouter() *gin.Engine { // 离线下载 aria2 := v3.Group("aria2") - aria2.Use(middleware.UseSlaveAria2Instance()) + aria2.Use(middleware.UseSlaveAria2Instance(cluster.DefaultController)) { // 创建离线下载任务 aria2.POST("task", controllers.SlaveAria2Create) @@ -205,7 +206,7 @@ func InitMasterRouter() *gin.Engine { // 从机的 RPC 通信 slave := v3.Group("slave") - slave.Use(middleware.SlaveRPCSignRequired()) + slave.Use(middleware.SlaveRPCSignRequired(cluster.Default)) { // 事件通知 slave.PUT("notification/:subject", controllers.SlaveNotificationPush) diff --git a/service/aria2/add.go b/service/aria2/add.go index 26b6baaa..73446b4d 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -9,7 +9,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" ) @@ -91,7 +90,7 @@ func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response // 创建事件通知回调 siteID, _ := c.Get("MasterSiteID") mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) { - if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { + if err := cluster.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err) } }) diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 8beb15b8..54638eea 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -6,10 +6,10 @@ import ( "encoding/json" "fmt" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" "github.com/gin-gonic/gin" @@ -153,7 +153,7 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial MasterID: id.(string), } - if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { + if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { task.TaskPoll.Submit(job.(task.Job)) }); err != nil { return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err) diff --git a/service/node/fabric.go b/service/node/fabric.go index 79dfb29d..63b5ecf4 100644 --- a/service/node/fabric.go +++ b/service/node/fabric.go @@ -3,10 +3,10 @@ package node import ( "encoding/gob" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" ) @@ -19,7 +19,7 @@ type OneDriveCredentialService struct { } func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { - res, err := slave.DefaultController.HandleHeartBeat(req) + res, err := cluster.DefaultController.HandleHeartBeat(req) if err != nil { return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err) } From 3064ed60f305dbb2f159537b00ef005e0cbae9f4 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Mon, 8 Nov 2021 20:49:07 +0800 Subject: [PATCH 2/2] Test: new database models and middlewares --- middleware/cluster_test.go | 76 ++++++++++++++++++++++++++++++++++++++ models/download_test.go | 11 ++++++ models/node_test.go | 64 ++++++++++++++++++++++++++++++++ models/share_test.go | 12 ------ models/user_test.go | 13 +------ 5 files changed, 153 insertions(+), 23 deletions(-) create mode 100644 models/node_test.go diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go index 2c25e292..e1e61e36 100644 --- a/middleware/cluster_test.go +++ b/middleware/cluster_test.go @@ -1,12 +1,17 @@ package middleware import ( + "errors" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" "net/http/httptest" "testing" ) @@ -74,7 +79,78 @@ func TestSlaveRPCSignRequired(t *testing.T) { } } +type SlaveControllerMock struct { + testMock.Mock +} + +func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) { + args := s.Called(pingReq) + return args.Get(0).(serializer.NodePingResp), args.Error(1) +} + +func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) { + args := s.Called(s2) + return args.Get(0).(common.Aria2), args.Error(1) +} + +func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error { + args := s.Called(s3, s2, message) + return args.Error(0) +} + +func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error { + args := s.Called(s3, i, s2, f) + return args.Error(0) +} + +func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) { + args := s.Called(s2) + return args.Get(0).(*cluster.MasterInfo), args.Error(1) +} + +func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) { + args := s.Called(s2, u) + return args.String(0), args.Error(1) +} + func TestUseSlaveAria2Instance(t *testing.T) { a := assert.New(t) + // MasterSiteID not set + { + testController := &SlaveControllerMock{} + useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + useSlaveAria2InstanceFunc(c) + a.True(c.IsAborted()) + } + + // Cannot get aria2 instances + { + testController := &SlaveControllerMock{} + useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("MasterSiteID", "expectedSiteID") + testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error")) + useSlaveAria2InstanceFunc(c) + a.True(c.IsAborted()) + testController.AssertExpectations(t) + } + + // Success + { + testController := &SlaveControllerMock{} + useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("MasterSiteID", "expectedSiteID") + testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil) + useSlaveAria2InstanceFunc(c) + a.False(c.IsAborted()) + res, _ := c.Get("MasterAria2Instance") + a.NotNil(res) + testController.AssertExpectations(t) + } } diff --git a/models/download_test.go b/models/download_test.go index 9d9cd34d..367afb78 100644 --- a/models/download_test.go +++ b/models/download_test.go @@ -177,3 +177,14 @@ func TestDownload_Delete(t *testing.T) { } } + +func TestDownload_GetNodeID(t *testing.T) { + a := assert.New(t) + record := Download{} + + // compatible with 3.4 + a.EqualValues(1, record.GetNodeID()) + + record.NodeID = 5 + a.EqualValues(5, record.GetNodeID()) +} diff --git a/models/node_test.go b/models/node_test.go new file mode 100644 index 00000000..ddc1f95b --- /dev/null +++ b/models/node_test.go @@ -0,0 +1,64 @@ +package model + +import ( + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGetNodeByID(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetNodeByID(1) + a.NoError(err) + a.EqualValues(1, res.ID) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestGetNodesByStatus(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive)) + res, err := GetNodesByStatus(NodeActive) + a.NoError(err) + a.Len(res, 1) + a.EqualValues(NodeActive, res[0].Status) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestNode_AfterFind(t *testing.T) { + a := assert.New(t) + node := &Node{} + + // No aria2 options + { + a.NoError(node.AfterFind()) + } + + // with aria2 options + { + node.Aria2Options = `{"timeout":1}` + a.NoError(node.AfterFind()) + a.Equal(1, node.Aria2OptionsSerialized.Timeout) + } +} + +func TestNode_BeforeSave(t *testing.T) { + a := assert.New(t) + node := &Node{} + + node.Aria2OptionsSerialized.Timeout = 1 + a.NoError(node.BeforeSave()) + a.Contains("1", node.Aria2Options) +} + +func TestNode_SetStatus(t *testing.T) { + a := assert.New(t) + node := &Node{} + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + a.NoError(node.SetStatus(NodeActive)) + a.Equal(NodeActive, node.Status) + a.NoError(mock.ExpectationsWereMet()) +} diff --git a/models/share_test.go b/models/share_test.go index 52e2ee66..b3fdf0a1 100644 --- a/models/share_test.go +++ b/models/share_test.go @@ -188,18 +188,6 @@ func TestShare_CanBeDownloadBy(t *testing.T) { asserts.Error(share.CanBeDownloadBy(user)) } - // 未登录,需要积分 - { - user := &User{ - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: true, - }, - }, - } - asserts.Error(share.CanBeDownloadBy(user)) - } - // 成功 { user := &User{ diff --git a/models/user_test.go b/models/user_test.go index 5b4d375c..a85ddbd3 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -177,10 +177,10 @@ func TestNewUser(t *testing.T) { func TestUser_AfterFind(t *testing.T) { asserts := assert.New(t) - cache.Deletes([]string{"1"}, "policy_") + cache.Deletes([]string{"0"}, "policy_") policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "默认存储策略") + AddRow(144, "默认存储策略") mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) newUser := NewUser() @@ -240,11 +240,6 @@ func TestUser_GetRemainingCapacity(t *testing.T) { newUser.Group.MaxStorage = 100 newUser.Storage = 200 asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) - - cache.Set("pack_size_0", uint64(10), 0) - newUser.Group.MaxStorage = 100 - newUser.Storage = 101 - asserts.Equal(uint64(9), newUser.GetRemainingCapacity()) } func TestUser_DeductionCapacity(t *testing.T) { @@ -280,10 +275,6 @@ func TestUser_DeductionCapacity(t *testing.T) { asserts.Equal(false, newUser.IncreaseStorage(1)) asserts.Equal(uint64(100), newUser.Storage) - cache.Set("pack_size_1", uint64(1), 0) - asserts.Equal(true, newUser.IncreaseStorage(1)) - asserts.Equal(uint64(101), newUser.Storage) - asserts.True(newUser.IncreaseStorage(0)) }