diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index a297649..710b0a3 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -84,14 +84,19 @@ func (pool *NodePool) GetNodeByID(id uint) Node { func (pool *NodePool) nodeStatusChange(isActive bool, id uint) { util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive) + var node Node pool.lock.Lock() - if isActive { - node := pool.inactive[id] + if n, ok := pool.inactive[id]; ok { + node = n delete(pool.inactive, id) - pool.active[id] = node } else { - node := pool.active[id] + node = pool.active[id] delete(pool.active, id) + } + + if isActive { + pool.active[id] = node + } else { pool.inactive[id] = node } pool.lock.Unlock() diff --git a/pkg/cluster/pool_test.go b/pkg/cluster/pool_test.go index 5af3fc4..dde3455 100644 --- a/pkg/cluster/pool_test.go +++ b/pkg/cluster/pool_test.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" @@ -60,5 +61,101 @@ func TestNodePool_GetNodeByID(t *testing.T) { } func TestNodePool_NodeStatusChange(t *testing.T) { + a := assert.New(t) + p := &NodePool{} + n := &MasterNode{Model: &model.Node{}} + p.Init() + p.inactive[1] = n + + p.nodeStatusChange(true, 1) + a.Len(p.inactive, 0) + a.Equal(n, p.active[1]) + + p.nodeStatusChange(false, 1) + a.Len(p.active, 0) + a.Equal(n, p.inactive[1]) + + p.nodeStatusChange(false, 1) + a.Len(p.active, 0) + a.Equal(n, p.inactive[1]) +} + +func TestNodePool_Add(t *testing.T) { + a := assert.New(t) + p := &NodePool{} + p.Init() + // new node + { + p.Add(&model.Node{}) + a.Len(p.active, 1) + } + + // old node + { + p.inactive[0] = p.active[0] + delete(p.active, 0) + p.Add(&model.Node{}) + a.Len(p.active, 0) + a.Len(p.inactive, 1) + } +} + +func TestNodePool_Delete(t *testing.T) { + a := assert.New(t) + p := &NodePool{} + p.Init() + + // active + { + mockNode := &nodeMock{} + mockNode.On("Kill") + p.active[0] = mockNode + p.Delete(0) + a.Len(p.active, 0) + a.Len(p.inactive, 0) + mockNode.AssertExpectations(t) + } + + p.Init() + + // inactive + { + mockNode := &nodeMock{} + mockNode.On("Kill") + p.inactive[0] = mockNode + p.Delete(0) + a.Len(p.active, 0) + a.Len(p.inactive, 0) + mockNode.AssertExpectations(t) + } +} + +func TestNodePool_BalanceNodeByFeature(t *testing.T) { + a := assert.New(t) + p := &NodePool{} + p.Init() + + // success + { + p.featureMap["test"] = []Node{&MasterNode{}} + err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin")) + a.NoError(err) + a.Equal(p.featureMap["test"][0], res) + } + + // NoNodes + { + p.featureMap["test"] = []Node{} + err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin")) + a.Error(err) + a.Nil(res) + } + + // No match feature + { + err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin")) + a.Error(err) + a.Nil(res) + } } diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go new file mode 100644 index 0000000..0b70caa --- /dev/null +++ b/pkg/cluster/slave_test.go @@ -0,0 +1,443 @@ +package cluster + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" +) + +func TestSlaveNode_InitAndKill(t *testing.T) { + a := assert.New(t) + n := &SlaveNode{ + callback: func(b bool, u uint) { + + }, + } + + a.NotPanics(func() { + n.Init(&model.Node{}) + time.Sleep(time.Millisecond * 500) + n.Init(&model.Node{}) + n.Kill() + }) +} + +func TestSlaveNode_DummyMethods(t *testing.T) { + a := assert.New(t) + m := &SlaveNode{ + Model: &model.Node{}, + } + + m.Model.ID = 5 + a.Equal(m.Model.ID, m.ID()) + a.Equal(m.Model.ID, m.DBModel().ID) + + a.False(m.IsActive()) + a.False(m.IsMater()) + + m.SubscribeStatusChange(func(isActive bool, id uint) {}) +} + +func TestSlaveNode_IsFeatureEnabled(t *testing.T) { + a := assert.New(t) + m := &SlaveNode{ + Model: &model.Node{}, + } + + a.False(m.IsFeatureEnabled("aria2")) + a.False(m.IsFeatureEnabled("random")) + m.Model.Aria2Enabled = true + a.True(m.IsFeatureEnabled("aria2")) +} + +func TestSlaveNode_Ping(t *testing.T) { + a := assert.New(t) + m := &SlaveNode{ + Model: &model.Node{}, + } + + // master return error code + { + mockRequest := &requestMock{} + mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.Ping(&serializer.NodePingReq{}) + a.Error(err) + a.Nil(res) + a.Equal(1, err.(serializer.AppError).Code) + } + + // return unexpected json + { + mockRequest := &requestMock{} + mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.Ping(&serializer.NodePingReq{}) + a.Error(err) + a.Nil(res) + } + + // return success + { + mockRequest := &requestMock{} + mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.Ping(&serializer.NodePingReq{}) + a.NoError(err) + a.NotNil(res) + } +} + +func TestSlaveNode_GetAria2Instance(t *testing.T) { + a := assert.New(t) + m := &SlaveNode{ + Model: &model.Node{}, + } + + a.NotNil(m.GetAria2Instance()) + m.Model.Aria2Enabled = true + a.NotNil(m.GetAria2Instance()) + a.NotNil(m.GetAria2Instance()) +} + +func TestSlaveNode_StartPingLoop(t *testing.T) { + callbackCount := 0 + finishedChan := make(chan struct{}) + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 404, + }, + }) + m := &SlaveNode{ + Active: true, + Model: &model.Node{}, + callback: func(b bool, u uint) { + callbackCount++ + if callbackCount == 2 { + close(finishedChan) + } + if callbackCount == 1 { + mockRequest.AssertExpectations(t) + mockRequest = requestMock{} + mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")), + }, + }) + } + }, + } + cache.Set("setting_slave_ping_interval", "0", 0) + cache.Set("setting_slave_recover_interval", "0", 0) + cache.Set("setting_slave_node_retry", "1", 0) + + m.caller.Client = &mockRequest + go func() { + select { + case <-finishedChan: + m.Kill() + } + }() + m.StartPingLoop() + mockRequest.AssertExpectations(t) +} + +func TestSlaveNode_AuthInstance(t *testing.T) { + a := assert.New(t) + m := &SlaveNode{ + Model: &model.Node{}, + } + + a.NotNil(m.MasterAuthInstance()) + a.NotNil(m.SlaveAuthInstance()) +} + +func TestSlaveNode_ChangeStatus(t *testing.T) { + a := assert.New(t) + isActive := false + m := &SlaveNode{ + Model: &model.Node{}, + callback: func(b bool, u uint) { + isActive = b + }, + } + + a.NotPanics(func() { + m.changeStatus(false) + }) + m.changeStatus(true) + a.True(isActive) +} + +func getTestRPCNodeSlave() *SlaveNode { + m := &SlaveNode{ + Model: &model.Node{}, + } + m.caller.parent = m + return m +} + +func TestSlaveCaller_CreateTask(t *testing.T) { + a := assert.New(t) + m := getTestRPCNodeSlave() + + // master return 404 + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 404, + }, + }) + m.caller.Client = mockRequest + res, err := m.caller.CreateTask(&model.Download{}, nil) + a.Empty(res) + a.Error(err) + } + + // master return error + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.caller.CreateTask(&model.Download{}, nil) + a.Empty(res) + a.Error(err) + } + + // master return success + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.caller.CreateTask(&model.Download{}, nil) + a.Equal("res", res) + a.NoError(err) + } +} + +func TestSlaveCaller_Status(t *testing.T) { + a := assert.New(t) + m := getTestRPCNodeSlave() + + // master return 404 + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 404, + }, + }) + m.caller.Client = mockRequest + res, err := m.caller.Status(&model.Download{}) + a.Empty(res.Status) + a.Error(err) + } + + // master return error + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.caller.Status(&model.Download{}) + a.Empty(res.Status) + a.Error(err) + } + + // master return success + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")), + }, + }) + m.caller.Client = mockRequest + res, err := m.caller.Status(&model.Download{}) + a.Empty(res.Status) + a.NoError(err) + } +} + +func TestSlaveCaller_Cancel(t *testing.T) { + a := assert.New(t) + m := getTestRPCNodeSlave() + + // master return 404 + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 404, + }, + }) + m.caller.Client = mockRequest + err := m.caller.Cancel(&model.Download{}) + a.Error(err) + } + + // master return error + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + m.caller.Client = mockRequest + err := m.caller.Cancel(&model.Download{}) + a.Error(err) + } + + // master return success + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), + }, + }) + m.caller.Client = mockRequest + err := m.caller.Cancel(&model.Download{}) + a.NoError(err) + } +} + +func TestSlaveCaller_Select(t *testing.T) { + a := assert.New(t) + m := getTestRPCNodeSlave() + m.caller.Init() + m.caller.GetConfig() + + // master return 404 + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 404, + }, + }) + m.caller.Client = mockRequest + err := m.caller.Select(&model.Download{}, nil) + a.Error(err) + } + + // master return error + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + m.caller.Client = mockRequest + err := m.caller.Select(&model.Download{}, nil) + a.Error(err) + } + + // master return success + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), + }, + }) + m.caller.Client = mockRequest + err := m.caller.Select(&model.Download{}, nil) + a.NoError(err) + } +} + +func TestSlaveCaller_DeleteTempFile(t *testing.T) { + a := assert.New(t) + m := getTestRPCNodeSlave() + m.caller.Init() + m.caller.GetConfig() + + // master return 404 + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 404, + }, + }) + m.caller.Client = mockRequest + err := m.caller.DeleteTempFile(&model.Download{}) + a.Error(err) + } + + // master return error + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + m.caller.Client = mockRequest + err := m.caller.DeleteTempFile(&model.Download{}) + a.Error(err) + } + + // master return success + { + mockRequest := requestMock{} + mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), + }, + }) + m.caller.Client = mockRequest + err := m.caller.DeleteTempFile(&model.Download{}) + a.NoError(err) + } +}