Test: balancer / auth / controller in pkg

pull/1056/head
HFO4 3 years ago
parent f0089045d7
commit 416f4c1dd2

@ -9,6 +9,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/gin-gonic/gin"
)
@ -53,7 +54,7 @@ func Init(path string) {
{
"master",
func() {
aria2.Init(false)
aria2.Init(false, cluster.Default, mq.GlobalMQ)
},
},
{

@ -33,7 +33,7 @@ func GetLoadBalancer() balancer.Balancer {
}
// Init 初始化
func Init(isReload bool) {
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
Lock.Lock()
LB = balancer.NewBalancer("RoundRobin")
Lock.Unlock()
@ -44,7 +44,7 @@ func Init(isReload bool) {
for i := 0; i < len(unfinished); i++ {
// 创建任务监控
monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ)
monitor.NewMonitor(&unfinished[i], pool, mqClient)
}
}
}

@ -2,14 +2,15 @@ package aria2
import (
"database/sql"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
var mock sqlmock.Sqlmock
@ -27,66 +28,39 @@ func TestMain(m *testing.M) {
m.Run()
}
func TestDummyAria2(t *testing.T) {
asserts := assert.New(t)
instance := DummyAria2{}
asserts.Error(instance.CreateTask(nil, nil))
_, err := instance.Status(nil)
asserts.Error(err)
asserts.Error(instance.Cancel(nil))
asserts.Error(instance.Select(nil, nil))
}
func TestInit(t *testing.T) {
monitor.MAX_RETRY = 0
asserts := assert.New(t)
cache.Set("setting_aria2_token", "1", 0)
cache.Set("setting_aria2_call_timeout", "5", 0)
cache.Set("setting_aria2_options", `[]`, 0)
a := assert.New(t)
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
mockQueue := mq.NewMQ()
// 未指定RPC地址跳过
{
cache.Set("setting_aria2_rpcurl", "", 0)
Init(false)
asserts.IsType(&DummyAria2{}, Instance)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
Init(false, mockPool, mockQueue)
a.NoError(mock.ExpectationsWereMet())
mockPool.AssertExpectations(t)
}
// 无法解析服务器地址
{
cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0)
Init(false)
asserts.IsType(&DummyAria2{}, Instance)
}
func TestTestRPCConnection(t *testing.T) {
a := assert.New(t)
// 无法解析全局配置
// url not legal
{
Instance = &RPCService{}
cache.Set("setting_aria2_options", "?", 0)
cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0)
Init(false)
asserts.IsType(&DummyAria2{}, Instance)
res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
a.Error(err)
a.Empty(res.Version)
}
// 连接失败
// rpc failed
{
cache.Set("setting_aria2_options", "{}", 0)
cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0)
cache.Set("setting_aria2_call_timeout", "1", 0)
cache.Set("setting_aria2_interval", "100", 0)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
Init(false)
asserts.NoError(mock.ExpectationsWereMet())
asserts.IsType(&RPCService{}, Instance)
res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
a.Error(err)
a.Empty(res.Version)
}
}
func TestGetStatus(t *testing.T) {
asserts := assert.New(t)
asserts.Equal(4, GetStatus("complete"))
asserts.Equal(1, GetStatus("active"))
asserts.Equal(0, GetStatus("waiting"))
asserts.Equal(2, GetStatus("paused"))
asserts.Equal(3, GetStatus("error"))
asserts.Equal(5, GetStatus("removed"))
asserts.Equal(6, GetStatus("?"))
func TestGetLoadBalancer(t *testing.T) {
a := assert.New(t)
a.NotPanics(func() {
GetLoadBalancer()
})
}

@ -1,114 +0,0 @@
package aria2
import (
"context"
"path/filepath"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// RPCService 通过RPC服务的Aria2任务管理器
type RPCService struct {
options *clientOptions
Caller rpc.Client
}
type clientOptions struct {
Options map[string]interface{} // 创建下载时额外添加的设置
}
// Init 初始化
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
// 客户端已存在,则关闭先前连接
if client.Caller != nil {
client.Caller.Close()
}
client.options = &clientOptions{
Options: options,
}
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
mq.GlobalMQ)
client.Caller = caller
return err
}
// Status 查询下载状态
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := client.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("无法获取离线下载状态,%s10秒钟后重试", err)
time.Sleep(time.Duration(10) * time.Second)
res, err = client.Caller.TellStatus(task.GID)
}
return res, err
}
// Cancel 取消下载
func (client *RPCService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := client.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
//// 删除临时文件
//util.Log().Debug("离线下载任务[%s]已取消1 分钟后删除临时文件", task.GID)
//go func(task *model.Download) {
// select {
// case <-time.After(time.Duration(60) * time.Second):
// err := os.RemoveAll(task.Parent)
// if err != nil {
// util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err)
// }
// }
//}(task)
return err
}
// Select 选取要下载的文件
func (client *RPCService) Select(task *model.Download, files []int) error {
var selected = make([]string, len(files))
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
// CreateTask 创建新任务
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
// 生成存储路径
path := filepath.Join(
model.GetSettingByName("aria2_temp_path"),
"aria2",
strconv.FormatInt(time.Now().UnixNano(), 10),
)
// 创建下载任务
options := map[string]interface{}{
"dir": path,
}
for k, v := range client.options.Options {
options[k] = v
}
for k, v := range groupOptions {
options[k] = v
}
gid, err := client.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return "", err
}
return gid, nil
}

@ -1,52 +0,0 @@
package aria2
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/stretchr/testify/assert"
)
func TestRPCService_Init(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.Error(caller.Init("ws://", "", 1, nil))
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
}
func TestRPCService_Status(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
_, err := caller.Status(&model.Download{})
asserts.Error(err)
}
func TestRPCService_Cancel(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
err := caller.Cancel(&model.Download{Parent: "test"})
asserts.Error(err)
}
func TestRPCService_Select(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3})
asserts.Error(err)
}
func TestRPCService_CreateTask(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
cache.Set("setting_aria2_temp_path", "test", 0)
err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"})
asserts.Error(err)
}

@ -1,52 +0,0 @@
package aria2
import (
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert"
)
func TestNotifier_Notify(t *testing.T) {
asserts := assert.New(t)
notifier2 := &Notifier{}
notifyChan := make(chan StatusEvent, 10)
notifier2.Subscribe(notifyChan, "1")
// 未订阅
{
notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1)
asserts.Len(notifyChan, 0)
}
// 订阅
{
notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1)
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
}
}

@ -18,6 +18,8 @@ import (
var (
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
)
@ -55,7 +57,7 @@ func CheckRequest(instance Auth, r *http.Request) error {
ok bool
)
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
return ErrAuthFailed
return ErrAuthHeaderMissing
}
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")

@ -80,6 +80,19 @@ func TestCheckRequest(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 缺少请求头
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
err = CheckRequest(General, req)
asserts.Error(err)
asserts.Equal(ErrAuthHeaderMissing, err)
}
// 非上传请求 验证成功
{
req, err := http.NewRequest(

@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
signSlice := strings.Split(sign, ":")
// 如果未携带expires字段
if signSlice[len(signSlice)-1] == "" {
return ErrAuthFailed
return ErrExpiresMissing
}
// 验证是否过期

@ -0,0 +1,12 @@
package balancer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewBalancer(t *testing.T) {
a := assert.New(t)
a.NotNil(NewBalancer(""))
a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
}

@ -0,0 +1,42 @@
package balancer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestRoundRobin_NextIndex(t *testing.T) {
a := assert.New(t)
r := &RoundRobin{}
total := 5
for i := 1; i < total; i++ {
a.Equal(i, r.NextIndex(total))
}
for i := 0; i < total; i++ {
a.Equal(i, r.NextIndex(total))
}
}
func TestRoundRobin_NextPeer(t *testing.T) {
a := assert.New(t)
r := &RoundRobin{}
// not slice
{
err, _ := r.NextPeer("s")
a.Equal(ErrInputNotSlice, err)
}
// no nodes
{
err, _ := r.NextPeer([]string{})
a.Equal(ErrNoAvaliableNode, err)
}
// pass
{
err, res := r.NextPeer([]string{"a"})
a.NoError(err)
a.Equal("a", res.(string))
}
}

@ -0,0 +1,254 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
)
func TestInitController(t *testing.T) {
assert.NotPanics(t, func() {
InitController()
})
}
func TestSlaveController_HandleHeartBeat(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: make(map[string]MasterInfo),
}
// first heart beat
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
Node: &model.Node{},
})
a.NoError(err)
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "2",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
}
// second heart beat, no fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Empty(c.masters["1"].URL)
}
// second heart beat, fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
// second heart beat, fresh, url illegal
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: string([]byte{0x7f}),
Node: &model.Node{},
})
a.Error(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
}
type nodeMock struct {
testMock.Mock
}
func (n nodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n nodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n nodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n nodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n nodeMock) Kill() {
n.Called()
}
func (n nodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
func TestSlaveController_GetAria2Instance(t *testing.T) {
a := assert.New(t)
mockNode := &nodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Instance: mockNode},
},
}
// node node found
{
res, err := c.GetAria2Instance("2")
a.Nil(res)
a.Equal(ErrMasterNotFound, err)
}
// node found
{
res, err := c.GetAria2Instance("1")
a.NotNil(res)
a.NoError(err)
mockNode.AssertExpectations(t)
}
}
type requestMock struct {
testMock.Mock
}
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
return r.Called(method, target, body, opts).Get(0).(*request.Response)
}
func TestSlaveController_SendNotification(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
}
// gob encode error
{
type randomType struct{}
a.Error(c.SendNotification("1", "", mq.Message{
Content: randomType{},
}))
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Error(c.SendNotification("1", "s1", mq.Message{}))
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
mockRequest.AssertExpectations(t)
}
}

@ -8,9 +8,11 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
testMock "github.com/stretchr/testify/mock"
"io"
)
type SlaveControllerMock struct {
@ -184,3 +186,11 @@ func (t TaskPoolMock) Add(num int) {
func (t TaskPoolMock) Submit(job task.Job) {
t.Called(job)
}
type RequestMock struct {
testMock.Mock
}
func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
return r.Called(method, target, body, opts).Get(0).(*request.Response)
}

@ -1,6 +1,8 @@
package controllers
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"io"
model "github.com/cloudreve/Cloudreve/v3/models"
@ -72,7 +74,7 @@ func AdminReloadService(c *gin.Context) {
case "email":
email.Init()
case "aria2":
aria2.Init(true)
aria2.Init(true, cluster.Default, mq.GlobalMQ)
}
c.JSON(200, serializer.Response{})

@ -48,9 +48,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
}
// 获取 Aria2 负载均衡器
aria2.Lock.RLock()
lb := aria2.LB
aria2.Lock.RUnlock()
lb := aria2.GetLoadBalancer()
// 获取 Aria2 实例
err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)

Loading…
Cancel
Save