diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go index 52c555f..a93f5f8 100644 --- a/pkg/aria2/common/common_test.go +++ b/pkg/aria2/common/common_test.go @@ -26,7 +26,7 @@ func TestDummyAria2(t *testing.T) { a.Error(err) configRes := d.GetConfig() - a.NotEmpty(configRes) + a.NotNil(configRes) err = d.DeleteTempFile(&model.Download{}) a.Error(err) diff --git a/pkg/cluster/controller.go b/pkg/cluster/controller.go index d5352ee..c597d04 100644 --- a/pkg/cluster/controller.go +++ b/pkg/cluster/controller.go @@ -161,6 +161,7 @@ func (c *slaveController) SubmitTask(id string, job interface{}, hash string, su return nil } + node.jobTracker[hash] = true submitter(job) return nil } diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go index 0ee8651..305856a 100644 --- a/pkg/cluster/controller_test.go +++ b/pkg/cluster/controller_test.go @@ -252,3 +252,134 @@ func TestSlaveController_SendNotification(t *testing.T) { mockRequest.AssertExpectations(t) } } + +func TestSlaveController_SubmitTask(t *testing.T) { + a := assert.New(t) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": { + jobTracker: map[string]bool{}, + }, + }, + } + + // node not exit + { + a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil)) + } + + // success + { + submitted := false + a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) { + submitted = true + })) + a.True(submitted) + } + + // job already submitted + { + submitted := false + a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) { + submitted = true + })) + a.False(submitted) + } +} + +func TestSlaveController_GetMasterInfo(t *testing.T) { + a := assert.New(t) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {}, + }, + } + + // node not exit + { + res, err := c.GetMasterInfo("2") + a.Equal(ErrMasterNotFound, err) + a.Nil(res) + } + + // success + { + res, err := c.GetMasterInfo("1") + a.NoError(err) + a.NotNil(res) + } +} + +func TestSlaveController_GetOneDriveToken(t *testing.T) { + a := assert.New(t) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {}, + }, + } + + // node not exit + { + res, err := c.GetOneDriveToken("2", 1) + a.Equal(ErrMasterNotFound, err) + a.Empty(res) + } + + // return none 200 + { + mockRequest := &requestMock{} + mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{StatusCode: http.StatusConflict}, + }) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Client: mockRequest}, + }, + } + res, err := c.GetOneDriveToken("1", 1) + a.Error(err) + a.Empty(res) + mockRequest.AssertExpectations(t) + } + + // master return error + { + mockRequest := &requestMock{} + mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Client: mockRequest}, + }, + } + res, err := c.GetOneDriveToken("1", 1) + a.Equal(1, err.(serializer.AppError).Code) + a.Empty(res) + mockRequest.AssertExpectations(t) + } + + // success + { + mockRequest := &requestMock{} + mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")), + }, + }) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Client: mockRequest}, + }, + } + res, err := c.GetOneDriveToken("1", 1) + a.NoError(err) + a.Equal("expected", res) + mockRequest.AssertExpectations(t) + } + +} diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index e877920..885e99a 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -20,7 +20,10 @@ import ( "time" ) -const deleteTempFileDuration = 60 * time.Second +const ( + deleteTempFileDuration = 60 * time.Second + statusRetryDuration = 10 * time.Second +) type MasterNode struct { Model *model.Node @@ -33,8 +36,10 @@ type rpcService struct { Caller rpc.Client Initialized bool - parent *MasterNode - options *clientOptions + retryDuration time.Duration + deletePaddingDuration time.Duration + parent *MasterNode + options *clientOptions } type clientOptions struct { @@ -46,6 +51,8 @@ func (node *MasterNode) Init(nodeModel *model.Node) { node.lock.Lock() node.Model = nodeModel node.aria2RPC.parent = node + node.aria2RPC.retryDuration = statusRetryDuration + node.aria2RPC.deletePaddingDuration = deleteTempFileDuration node.lock.Unlock() node.lock.RLock() @@ -214,8 +221,8 @@ func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { res, err := r.Caller.TellStatus(task.GID) if err != nil { // 失败后重试 - util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err) - time.Sleep(time.Duration(10) * time.Second) + util.Log().Debug("无法获取离线下载状态,%s,稍后重试", err) + time.Sleep(r.retryDuration) res, err = r.Caller.TellStatus(task.GID) } @@ -253,13 +260,13 @@ func (s *rpcService) DeleteTempFile(task *model.Download) error { defer s.parent.lock.RUnlock() // 避免被aria2占用,异步执行删除 - go func(src string) { - time.Sleep(deleteTempFileDuration) + go func(d time.Duration, src string) { + time.Sleep(d) err := os.RemoveAll(src) if err != nil { util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err) } - }(task.Parent) + }(s.deletePaddingDuration, task.Parent) return nil } diff --git a/pkg/cluster/master_test.go b/pkg/cluster/master_test.go new file mode 100644 index 0000000..7ff07ac --- /dev/null +++ b/pkg/cluster/master_test.go @@ -0,0 +1,186 @@ +package cluster + +import ( + "context" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/stretchr/testify/assert" + "os" + "testing" + "time" +) + +func TestMasterNode_Init(t *testing.T) { + a := assert.New(t) + m := &MasterNode{} + m.Init(&model.Node{Status: model.NodeSuspend}) + a.Equal(model.NodeSuspend, m.DBModel().Status) + m.Init(&model.Node{Aria2Enabled: true}) +} + +func TestMasterNode_DummyMethods(t *testing.T) { + a := assert.New(t) + m := &MasterNode{ + Model: &model.Node{}, + } + + m.Model.ID = 5 + a.Equal(m.Model.ID, m.ID()) + + res, err := m.Ping(&serializer.NodePingReq{}) + a.NoError(err) + a.NotNil(res) + + a.True(m.IsActive()) + a.True(m.IsMater()) + + m.SubscribeStatusChange(func(isActive bool, id uint) {}) +} + +func TestMasterNode_IsFeatureEnabled(t *testing.T) { + a := assert.New(t) + m := &MasterNode{ + Model: &model.Node{}, + } + + a.False(m.IsFeatureEnabled("aria2")) + a.False(m.IsFeatureEnabled("random")) + m.Model.Aria2Enabled = true + a.True(m.IsFeatureEnabled("aria2")) +} + +func TestMasterNode_AuthInstance(t *testing.T) { + a := assert.New(t) + m := &MasterNode{ + Model: &model.Node{}, + } + + a.NotNil(m.MasterAuthInstance()) + a.NotNil(m.SlaveAuthInstance()) +} + +func TestMasterNode_Kill(t *testing.T) { + m := &MasterNode{ + Model: &model.Node{}, + } + + m.Kill() + + caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) + m.aria2RPC.Caller = caller + m.Kill() +} + +func TestMasterNode_GetAria2Instance(t *testing.T) { + a := assert.New(t) + m := &MasterNode{ + Model: &model.Node{}, + aria2RPC: rpcService{}, + } + + m.aria2RPC.parent = m + + a.NotNil(m.GetAria2Instance()) + m.Model.Aria2Enabled = true + a.NotNil(m.GetAria2Instance()) + m.aria2RPC.Initialized = true + a.NotNil(m.GetAria2Instance()) +} + +func TestRpcService_Init(t *testing.T) { + a := assert.New(t) + m := &MasterNode{ + Model: &model.Node{ + Aria2OptionsSerialized: model.Aria2Option{ + Options: "{", + }, + }, + aria2RPC: rpcService{}, + } + m.aria2RPC.parent = m + + // failed to decode address + { + m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f}) + a.Error(m.aria2RPC.Init()) + } + + // failed to decode options + { + m.Model.Aria2OptionsSerialized.Server = "" + a.Error(m.aria2RPC.Init()) + } + + // failed to initialized + { + m.Model.Aria2OptionsSerialized.Server = "" + m.Model.Aria2OptionsSerialized.Options = "{}" + caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) + m.aria2RPC.Caller = caller + a.Error(m.aria2RPC.Init()) + a.False(m.aria2RPC.Initialized) + } +} + +func getTestRPCNode() *MasterNode { + m := &MasterNode{ + Model: &model.Node{ + Aria2OptionsSerialized: model.Aria2Option{}, + }, + aria2RPC: rpcService{ + options: &clientOptions{ + Options: map[string]interface{}{"1": "1"}, + }, + }, + } + m.aria2RPC.parent = m + caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) + m.aria2RPC.Caller = caller + return m +} + +func TestRpcService_CreateTask(t *testing.T) { + a := assert.New(t) + m := getTestRPCNode() + + res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"}) + a.Error(err) + a.Empty(res) +} + +func TestRpcService_Status(t *testing.T) { + a := assert.New(t) + m := getTestRPCNode() + + res, err := m.aria2RPC.Status(&model.Download{}) + a.Error(err) + a.Empty(res) +} + +func TestRpcService_Cancel(t *testing.T) { + a := assert.New(t) + m := getTestRPCNode() + + a.Error(m.aria2RPC.Cancel(&model.Download{})) +} + +func TestRpcService_Select(t *testing.T) { + a := assert.New(t) + m := getTestRPCNode() + + a.NotNil(m.aria2RPC.GetConfig()) + a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3})) +} + +func TestRpcService_DeleteTempFile(t *testing.T) { + a := assert.New(t) + m := getTestRPCNode() + fdName := "TestRpcService_DeleteTempFile" + a.NoError(os.Mkdir(fdName, 0644)) + + a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName})) + time.Sleep(500 * time.Millisecond) + a.False(util.Exists(fdName)) +} diff --git a/pkg/cluster/node_test.go b/pkg/cluster/node_test.go new file mode 100644 index 0000000..d817425 --- /dev/null +++ b/pkg/cluster/node_test.go @@ -0,0 +1,17 @@ +package cluster + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewNodeFromDBModel(t *testing.T) { + a := assert.New(t) + a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{ + Type: model.SlaveNodeType, + })) + a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{ + Type: model.MasterNodeType, + })) +} diff --git a/pkg/cluster/pool_test.go b/pkg/cluster/pool_test.go new file mode 100644 index 0000000..5af3fc4 --- /dev/null +++ b/pkg/cluster/pool_test.go @@ -0,0 +1,64 @@ +package cluster + +import ( + "database/sql" + "errors" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "testing" +) + +var mock sqlmock.Sqlmock + +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + model.DB, _ = gorm.Open("mysql", db) + defer db.Close() + m.Run() +} + +func TestInitFailed(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) + Init() + a.NoError(mock.ExpectationsWereMet()) +} + +func TestInitSuccess(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType)) + Init() + a.NoError(mock.ExpectationsWereMet()) +} + +func TestNodePool_GetNodeByID(t *testing.T) { + a := assert.New(t) + p := &NodePool{} + p.Init() + mockNode := &nodeMock{} + + // inactive + { + p.inactive[1] = mockNode + a.Equal(mockNode, p.GetNodeByID(1)) + } + + // active + { + delete(p.inactive, 1) + p.active[1] = mockNode + a.Equal(mockNode, p.GetNodeByID(1)) + } +} + +func TestNodePool_NodeStatusChange(t *testing.T) { + +}