From 416f4c1dd22ccdf3e96ee848261f3f78976403be Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 11 Nov 2021 20:56:16 +0800 Subject: [PATCH] Test: balancer / auth / controller in pkg --- bootstrap/init.go | 3 +- pkg/aria2/aria2.go | 4 +- pkg/aria2/aria2_test.go | 82 ++++------- pkg/aria2/caller.go | 114 -------------- pkg/aria2/caller_test.go | 52 ------- pkg/aria2/notification_test.go | 52 ------- pkg/auth/auth.go | 8 +- pkg/auth/auth_test.go | 13 ++ pkg/auth/hmac.go | 2 +- pkg/balancer/balancer_test.go | 12 ++ pkg/balancer/roundrobin_test.go | 42 ++++++ pkg/cluster/controller_test.go | 254 ++++++++++++++++++++++++++++++++ pkg/mocks/mocks.go | 10 ++ routers/controllers/admin.go | 4 +- service/aria2/add.go | 4 +- 15 files changed, 373 insertions(+), 283 deletions(-) delete mode 100644 pkg/aria2/caller.go delete mode 100644 pkg/aria2/caller_test.go delete mode 100644 pkg/aria2/notification_test.go create mode 100644 pkg/balancer/balancer_test.go create mode 100644 pkg/balancer/roundrobin_test.go create mode 100644 pkg/cluster/controller_test.go diff --git a/bootstrap/init.go b/bootstrap/init.go index 60ebcb8..0a51835 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -9,6 +9,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/crontab" "github.com/cloudreve/Cloudreve/v3/pkg/email" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/gin-gonic/gin" ) @@ -53,7 +54,7 @@ func Init(path string) { { "master", func() { - aria2.Init(false) + aria2.Init(false, cluster.Default, mq.GlobalMQ) }, }, { diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index ef2f0df..60d254e 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -33,7 +33,7 @@ func GetLoadBalancer() balancer.Balancer { } // Init 初始化 -func Init(isReload bool) { +func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) { Lock.Lock() LB = balancer.NewBalancer("RoundRobin") Lock.Unlock() @@ -44,7 +44,7 @@ func Init(isReload bool) { for i := 0; i < len(unfinished); i++ { // 创建任务监控 - monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ) + monitor.NewMonitor(&unfinished[i], pool, mqClient) } } } diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go index dfd71a3..b6e7092 100644 --- a/pkg/aria2/aria2_test.go +++ b/pkg/aria2/aria2_test.go @@ -2,14 +2,15 @@ package aria2 import ( "database/sql" + "github.com/cloudreve/Cloudreve/v3/pkg/mocks" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" "testing" "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" ) var mock sqlmock.Sqlmock @@ -27,66 +28,39 @@ func TestMain(m *testing.M) { m.Run() } -func TestDummyAria2(t *testing.T) { - asserts := assert.New(t) - instance := DummyAria2{} - asserts.Error(instance.CreateTask(nil, nil)) - _, err := instance.Status(nil) - asserts.Error(err) - asserts.Error(instance.Cancel(nil)) - asserts.Error(instance.Select(nil, nil)) -} - func TestInit(t *testing.T) { - monitor.MAX_RETRY = 0 - asserts := assert.New(t) - cache.Set("setting_aria2_token", "1", 0) - cache.Set("setting_aria2_call_timeout", "5", 0) - cache.Set("setting_aria2_options", `[]`, 0) + a := assert.New(t) + mockPool := &mocks.NodePoolMock{} + mockPool.On("GetNodeByID", testMock.Anything).Return(nil) + mockQueue := mq.NewMQ() - // 未指定RPC地址,跳过 - { - cache.Set("setting_aria2_rpcurl", "", 0) - Init(false) - asserts.IsType(&DummyAria2{}, Instance) - } + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + Init(false, mockPool, mockQueue) + a.NoError(mock.ExpectationsWereMet()) + mockPool.AssertExpectations(t) +} - // 无法解析服务器地址 - { - cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0) - Init(false) - asserts.IsType(&DummyAria2{}, Instance) - } +func TestTestRPCConnection(t *testing.T) { + a := assert.New(t) - // 无法解析全局配置 + // url not legal { - Instance = &RPCService{} - cache.Set("setting_aria2_options", "?", 0) - cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0) - Init(false) - asserts.IsType(&DummyAria2{}, Instance) + res, err := TestRPCConnection(string([]byte{0x7f}), "", 10) + a.Error(err) + a.Empty(res.Version) } - // 连接失败 + // rpc failed { - cache.Set("setting_aria2_options", "{}", 0) - cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0) - cache.Set("setting_aria2_call_timeout", "1", 0) - cache.Set("setting_aria2_interval", "100", 0) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1")) - Init(false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.IsType(&RPCService{}, Instance) + res, err := TestRPCConnection("ws://0.0.0.0", "", 0) + a.Error(err) + a.Empty(res.Version) } } -func TestGetStatus(t *testing.T) { - asserts := assert.New(t) - asserts.Equal(4, GetStatus("complete")) - asserts.Equal(1, GetStatus("active")) - asserts.Equal(0, GetStatus("waiting")) - asserts.Equal(2, GetStatus("paused")) - asserts.Equal(3, GetStatus("error")) - asserts.Equal(5, GetStatus("removed")) - asserts.Equal(6, GetStatus("?")) +func TestGetLoadBalancer(t *testing.T) { + a := assert.New(t) + a.NotPanics(func() { + GetLoadBalancer() + }) } diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go deleted file mode 100644 index 70e0bea..0000000 --- a/pkg/aria2/caller.go +++ /dev/null @@ -1,114 +0,0 @@ -package aria2 - -import ( - "context" - "path/filepath" - "strconv" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// RPCService 通过RPC服务的Aria2任务管理器 -type RPCService struct { - options *clientOptions - Caller rpc.Client -} - -type clientOptions struct { - Options map[string]interface{} // 创建下载时额外添加的设置 -} - -// Init 初始化 -func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error { - // 客户端已存在,则关闭先前连接 - if client.Caller != nil { - client.Caller.Close() - } - - client.options = &clientOptions{ - Options: options, - } - caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, - mq.GlobalMQ) - client.Caller = caller - return err -} - -// Status 查询下载状态 -func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) { - res, err := client.Caller.TellStatus(task.GID) - if err != nil { - // 失败后重试 - util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err) - time.Sleep(time.Duration(10) * time.Second) - res, err = client.Caller.TellStatus(task.GID) - } - - return res, err -} - -// Cancel 取消下载 -func (client *RPCService) Cancel(task *model.Download) error { - // 取消下载任务 - _, err := client.Caller.Remove(task.GID) - if err != nil { - util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err) - } - - //// 删除临时文件 - //util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", task.GID) - //go func(task *model.Download) { - // select { - // case <-time.After(time.Duration(60) * time.Second): - // err := os.RemoveAll(task.Parent) - // if err != nil { - // util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err) - // } - // } - //}(task) - - return err -} - -// Select 选取要下载的文件 -func (client *RPCService) Select(task *model.Download, files []int) error { - var selected = make([]string, len(files)) - for i := 0; i < len(files); i++ { - selected[i] = strconv.Itoa(files[i]) - } - _, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) - return err -} - -// CreateTask 创建新任务 -func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) { - // 生成存储路径 - path := filepath.Join( - model.GetSettingByName("aria2_temp_path"), - "aria2", - strconv.FormatInt(time.Now().UnixNano(), 10), - ) - - // 创建下载任务 - options := map[string]interface{}{ - "dir": path, - } - for k, v := range client.options.Options { - options[k] = v - } - for k, v := range groupOptions { - options[k] = v - } - - gid, err := client.Caller.AddURI(task.Source, options) - if err != nil || gid == "" { - return "", err - } - - return gid, nil -} diff --git a/pkg/aria2/caller_test.go b/pkg/aria2/caller_test.go deleted file mode 100644 index f215689..0000000 --- a/pkg/aria2/caller_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package aria2 - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/stretchr/testify/assert" -) - -func TestRPCService_Init(t *testing.T) { - asserts := assert.New(t) - caller := &RPCService{} - asserts.Error(caller.Init("ws://", "", 1, nil)) - asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) -} - -func TestRPCService_Status(t *testing.T) { - asserts := assert.New(t) - caller := &RPCService{} - asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) - - _, err := caller.Status(&model.Download{}) - asserts.Error(err) -} - -func TestRPCService_Cancel(t *testing.T) { - asserts := assert.New(t) - caller := &RPCService{} - asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) - - err := caller.Cancel(&model.Download{Parent: "test"}) - asserts.Error(err) -} - -func TestRPCService_Select(t *testing.T) { - asserts := assert.New(t) - caller := &RPCService{} - asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) - - err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3}) - asserts.Error(err) -} - -func TestRPCService_CreateTask(t *testing.T) { - asserts := assert.New(t) - caller := &RPCService{} - asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) - cache.Set("setting_aria2_temp_path", "test", 0) - err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"}) - asserts.Error(err) -} diff --git a/pkg/aria2/notification_test.go b/pkg/aria2/notification_test.go deleted file mode 100644 index 21a7ac1..0000000 --- a/pkg/aria2/notification_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package aria2 - -import ( - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/stretchr/testify/assert" -) - -func TestNotifier_Notify(t *testing.T) { - asserts := assert.New(t) - notifier2 := &Notifier{} - notifyChan := make(chan StatusEvent, 10) - notifier2.Subscribe(notifyChan, "1") - - // 未订阅 - { - notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1) - asserts.Len(notifyChan, 0) - } - - // 订阅 - { - notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1) - asserts.Len(notifyChan, 1) - <-notifyChan - - notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}}) - asserts.Len(notifyChan, 1) - <-notifyChan - - notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}}) - asserts.Len(notifyChan, 1) - <-notifyChan - - notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}}) - asserts.Len(notifyChan, 1) - <-notifyChan - - notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}}) - asserts.Len(notifyChan, 1) - <-notifyChan - - notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}}) - asserts.Len(notifyChan, 1) - <-notifyChan - - notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}}) - asserts.Len(notifyChan, 1) - <-notifyChan - } -} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index d8250e8..20b9e10 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -17,8 +17,10 @@ import ( ) var ( - ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil) - ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil) + ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil) + ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil) + ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil) + ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil) ) // General 通用的认证接口 @@ -55,7 +57,7 @@ func CheckRequest(instance Auth, r *http.Request) error { ok bool ) if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 { - return ErrAuthFailed + return ErrAuthHeaderMissing } sign[0] = strings.TrimPrefix(sign[0], "Bearer ") diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 1092cb5..46533fb 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -80,6 +80,19 @@ func TestCheckRequest(t *testing.T) { asserts := assert.New(t) General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + // 缺少请求头 + { + req, err := http.NewRequest( + "POST", + "http://127.0.0.1/api/v3/upload", + strings.NewReader("I am body."), + ) + asserts.NoError(err) + err = CheckRequest(General, req) + asserts.Error(err) + asserts.Equal(ErrAuthHeaderMissing, err) + } + // 非上传请求 验证成功 { req, err := http.NewRequest( diff --git a/pkg/auth/hmac.go b/pkg/auth/hmac.go index e0a9573..50849cc 100644 --- a/pkg/auth/hmac.go +++ b/pkg/auth/hmac.go @@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error { signSlice := strings.Split(sign, ":") // 如果未携带expires字段 if signSlice[len(signSlice)-1] == "" { - return ErrAuthFailed + return ErrExpiresMissing } // 验证是否过期 diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go new file mode 100644 index 0000000..4493bbb --- /dev/null +++ b/pkg/balancer/balancer_test.go @@ -0,0 +1,12 @@ +package balancer + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewBalancer(t *testing.T) { + a := assert.New(t) + a.NotNil(NewBalancer("")) + a.IsType(&RoundRobin{}, NewBalancer("RoundRobin")) +} diff --git a/pkg/balancer/roundrobin_test.go b/pkg/balancer/roundrobin_test.go new file mode 100644 index 0000000..9cdcc00 --- /dev/null +++ b/pkg/balancer/roundrobin_test.go @@ -0,0 +1,42 @@ +package balancer + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRoundRobin_NextIndex(t *testing.T) { + a := assert.New(t) + r := &RoundRobin{} + total := 5 + for i := 1; i < total; i++ { + a.Equal(i, r.NextIndex(total)) + } + for i := 0; i < total; i++ { + a.Equal(i, r.NextIndex(total)) + } +} + +func TestRoundRobin_NextPeer(t *testing.T) { + a := assert.New(t) + r := &RoundRobin{} + + // not slice + { + err, _ := r.NextPeer("s") + a.Equal(ErrInputNotSlice, err) + } + + // no nodes + { + err, _ := r.NextPeer([]string{}) + a.Equal(ErrNoAvaliableNode, err) + } + + // pass + { + err, res := r.NextPeer([]string{"a"}) + a.NoError(err) + a.Equal("a", res.(string)) + } +} diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go new file mode 100644 index 0000000..0ee8651 --- /dev/null +++ b/pkg/cluster/controller_test.go @@ -0,0 +1,254 @@ +package cluster + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" +) + +func TestInitController(t *testing.T) { + assert.NotPanics(t, func() { + InitController() + }) +} + +func TestSlaveController_HandleHeartBeat(t *testing.T) { + a := assert.New(t) + c := &slaveController{ + masters: make(map[string]MasterInfo), + } + + // first heart beat + { + _, err := c.HandleHeartBeat(&serializer.NodePingReq{ + SiteID: "1", + Node: &model.Node{}, + }) + a.NoError(err) + + _, err = c.HandleHeartBeat(&serializer.NodePingReq{ + SiteID: "2", + Node: &model.Node{}, + }) + a.NoError(err) + + a.Len(c.masters, 2) + } + + // second heart beat, no fresh + { + _, err := c.HandleHeartBeat(&serializer.NodePingReq{ + SiteID: "1", + SiteURL: "http://127.0.0.1", + Node: &model.Node{}, + }) + a.NoError(err) + a.Len(c.masters, 2) + a.Empty(c.masters["1"].URL) + } + + // second heart beat, fresh + { + _, err := c.HandleHeartBeat(&serializer.NodePingReq{ + SiteID: "1", + IsUpdate: true, + SiteURL: "http://127.0.0.1", + Node: &model.Node{}, + }) + a.NoError(err) + a.Len(c.masters, 2) + a.Equal("http://127.0.0.1", c.masters["1"].URL.String()) + } + + // second heart beat, fresh, url illegal + { + _, err := c.HandleHeartBeat(&serializer.NodePingReq{ + SiteID: "1", + IsUpdate: true, + SiteURL: string([]byte{0x7f}), + Node: &model.Node{}, + }) + a.Error(err) + a.Len(c.masters, 2) + a.Equal("http://127.0.0.1", c.masters["1"].URL.String()) + } +} + +type nodeMock struct { + testMock.Mock +} + +func (n nodeMock) Init(node *model.Node) { + n.Called(node) +} + +func (n nodeMock) IsFeatureEnabled(feature string) bool { + args := n.Called(feature) + return args.Bool(0) +} + +func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) { + n.Called(callback) +} + +func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { + args := n.Called(req) + return args.Get(0).(*serializer.NodePingResp), args.Error(1) +} + +func (n nodeMock) IsActive() bool { + args := n.Called() + return args.Bool(0) +} + +func (n nodeMock) GetAria2Instance() common.Aria2 { + args := n.Called() + return args.Get(0).(common.Aria2) +} + +func (n nodeMock) ID() uint { + args := n.Called() + return args.Get(0).(uint) +} + +func (n nodeMock) Kill() { + n.Called() +} + +func (n nodeMock) IsMater() bool { + args := n.Called() + return args.Bool(0) +} + +func (n nodeMock) MasterAuthInstance() auth.Auth { + args := n.Called() + return args.Get(0).(auth.Auth) +} + +func (n nodeMock) SlaveAuthInstance() auth.Auth { + args := n.Called() + return args.Get(0).(auth.Auth) +} + +func (n nodeMock) DBModel() *model.Node { + args := n.Called() + return args.Get(0).(*model.Node) +} + +func TestSlaveController_GetAria2Instance(t *testing.T) { + a := assert.New(t) + mockNode := &nodeMock{} + mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Instance: mockNode}, + }, + } + + // node node found + { + res, err := c.GetAria2Instance("2") + a.Nil(res) + a.Equal(ErrMasterNotFound, err) + } + + // node found + { + res, err := c.GetAria2Instance("1") + a.NotNil(res) + a.NoError(err) + mockNode.AssertExpectations(t) + } + +} + +type requestMock struct { + testMock.Mock +} + +func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { + return r.Called(method, target, body, opts).Get(0).(*request.Response) +} + +func TestSlaveController_SendNotification(t *testing.T) { + a := assert.New(t) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {}, + }, + } + + // node not exit + { + a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{})) + } + + // gob encode error + { + type randomType struct{} + a.Error(c.SendNotification("1", "", mq.Message{ + Content: randomType{}, + })) + } + + // return none 200 + { + mockRequest := &requestMock{} + mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{StatusCode: http.StatusConflict}, + }) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Client: mockRequest}, + }, + } + a.Error(c.SendNotification("1", "s1", mq.Message{})) + mockRequest.AssertExpectations(t) + } + + // master return error + { + mockRequest := &requestMock{} + mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), + }, + }) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Client: mockRequest}, + }, + } + a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code) + mockRequest.AssertExpectations(t) + } + + // success + { + mockRequest := &requestMock{} + mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{ + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")), + }, + }) + c := &slaveController{ + masters: map[string]MasterInfo{ + "1": {Client: mockRequest}, + }, + } + a.NoError(c.SendNotification("1", "s3", mq.Message{})) + mockRequest.AssertExpectations(t) + } +} diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go index 2134e86..6b7e674 100644 --- a/pkg/mocks/mocks.go +++ b/pkg/mocks/mocks.go @@ -8,9 +8,11 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" testMock "github.com/stretchr/testify/mock" + "io" ) type SlaveControllerMock struct { @@ -184,3 +186,11 @@ func (t TaskPoolMock) Add(num int) { func (t TaskPoolMock) Submit(job task.Job) { t.Called(job) } + +type RequestMock struct { + testMock.Mock +} + +func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { + return r.Called(method, target, body, opts).Get(0).(*request.Response) +} diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index a3ebfa5..fb0d6d6 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -1,6 +1,8 @@ package controllers import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "io" model "github.com/cloudreve/Cloudreve/v3/models" @@ -72,7 +74,7 @@ func AdminReloadService(c *gin.Context) { case "email": email.Init() case "aria2": - aria2.Init(true) + aria2.Init(true, cluster.Default, mq.GlobalMQ) } c.JSON(200, serializer.Response{}) diff --git a/service/aria2/add.go b/service/aria2/add.go index 8443c14..2c72c8b 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -48,9 +48,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo } // 获取 Aria2 负载均衡器 - aria2.Lock.RLock() - lb := aria2.LB - aria2.Lock.RUnlock() + lb := aria2.GetLoadBalancer() // 获取 Aria2 实例 err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)