Refactor: move slave pkg inside of cluster

Test: middleware for node communication
pull/1048/head
HFO4 3 years ago
parent eaa0f6be91
commit e41ec9defa

@ -9,7 +9,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/gin-gonic/gin"
)
@ -78,7 +77,7 @@ func Init(path string) {
{
"slave",
func() {
slave.Init()
cluster.InitController()
},
},
{

@ -37,6 +37,7 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
c.Abort()
return
}
c.Next()
}
}

@ -90,15 +90,27 @@ func TestSignRequired(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
SignRequiredFunc := SignRequired(authInstance)
// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
// Sign verify success
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
c.Request = auth.SignRequest(authInstance, c.Request, 0)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.False(c.IsAborted())
}
func TestWebDAVAuth(t *testing.T) {
@ -780,8 +792,6 @@ func TestS3CallbackAuth(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
@ -789,5 +799,6 @@ func TestS3CallbackAuth(t *testing.T) {
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
}

@ -3,7 +3,6 @@ package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
"strconv"
)
@ -19,11 +18,11 @@ func MasterMetadata() gin.HandlerFunc {
}
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
func UseSlaveAria2Instance() gin.HandlerFunc {
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
return func(c *gin.Context) {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
caller, err := clusterController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
c.Abort()
@ -40,7 +39,7 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
}
}
func SlaveRPCSignRequired() gin.HandlerFunc {
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
if err != nil {
@ -49,7 +48,7 @@ func SlaveRPCSignRequired() gin.HandlerFunc {
return
}
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
slaveNode := nodePool.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()

@ -0,0 +1,80 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestMasterMetadata(t *testing.T) {
a := assert.New(t)
masterMetaDataFunc := MasterMetadata()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header = map[string][]string{
"X-Site-Id": {"expectedSiteID"},
"X-Site-Url": {"expectedSiteURL"},
"X-Cloudreve-Version": {"expectedMasterVersion"},
}
masterMetaDataFunc(c)
siteID, _ := c.Get("MasterSiteID")
siteURL, _ := c.Get("MasterSiteURL")
siteVersion, _ := c.Get("MasterVersion")
a.Equal("expectedSiteID", siteID.(string))
a.Equal("expectedSiteURL", siteURL.(string))
a.Equal("expectedMasterVersion", siteVersion.(string))
}
func TestSlaveRPCSignRequired(t *testing.T) {
a := assert.New(t)
np := &cluster.NodePool{}
np.Init()
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
rec := httptest.NewRecorder()
// id parse failed
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Node-Id", "unknown")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// node id not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Node-Id", "38")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// success
{
authInstance := auth.HMACAuth{SecretKey: []byte("")}
np.Add(&model.Node{Model: gorm.Model{
ID: 38,
}})
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("X-Node-Id", "38")
c.Request = auth.SignRequest(authInstance, c.Request, 0)
slaveRPCSignRequiredFunc(c)
a.False(c.IsAborted())
}
}
func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)
}

@ -35,7 +35,7 @@ type User struct {
Storage uint64
TwoFactor string
Avatar string
Options string `json:"-",gorm:"type:text"`
Options string `json:"-" gorm:"type:text"`
Authn string `gorm:"type:text"`
// 关联模型

@ -1,4 +1,4 @@
package slave
package cluster
import (
"bytes"
@ -8,7 +8,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -51,13 +50,13 @@ type MasterInfo struct {
TTL int
URL *url.URL
// used to invoke aria2 rpc calls
Instance cluster.Node
Instance Node
Client request.Client
jobTracker map[string]bool
}
func Init() {
func InitController() {
DefaultController = &slaveController{
masters: make(map[string]MasterInfo),
}
@ -95,7 +94,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
}, int64(req.CredentialTTL)),
),
jobTracker: make(map[string]bool),
Instance: cluster.NewNodeFromDBModel(&model.Node{
Instance: NewNodeFromDBModel(&model.Node{
Model: gorm.Model{ID: req.Node.ID},
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,

@ -1,8 +1,12 @@
package cluster
import "errors"
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
)

@ -39,14 +39,22 @@ type NodePool struct {
// Init 初始化从机节点池
func Init() {
Default = &NodePool{
featureMap: make(map[string][]Node),
}
Default = &NodePool{}
Default.Init()
if err := Default.initFromDB(); err != nil {
util.Log().Warning("节点池初始化失败, %s", err)
}
}
func (pool *NodePool) Init() {
pool.lock.Lock()
defer pool.lock.Unlock()
pool.featureMap = make(map[string][]Node)
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
}
func (pool *NodePool) buildIndexMap() {
pool.lock.Lock()
for _, feature := range featureGroup {
@ -98,8 +106,6 @@ func (pool *NodePool) initFromDB() error {
}
pool.lock.Lock()
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
for i := 0; i < len(nodes); i++ {
pool.add(&nodes[i])
}

@ -3,6 +3,7 @@ package onedrive
import (
"context"
"encoding/json"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"io/ioutil"
"net/http"
"net/url"
@ -12,7 +13,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
@ -179,7 +179,7 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
res, err := cluster.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}

@ -1,7 +0,0 @@
package slave
import "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
var (
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
)

@ -3,11 +3,11 @@ package slavetask
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"os"
@ -68,7 +68,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) {
},
}
if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
util.Log().Warning("无法发送转存失败通知到从机, ", err)
}
}
@ -94,7 +94,7 @@ func (job *TransferTask) Do() {
return
}
master, err := slave.DefaultController.GetMasterInfo(job.MasterID)
master, err := cluster.DefaultController.GetMasterInfo(job.MasterID)
if err != nil {
job.SetErrorMsg("找不到主机节点", err)
return
@ -131,7 +131,7 @@ func (job *TransferTask) Do() {
Content: serializer.SlaveTransferResult{},
}
if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
util.Log().Warning("无法发送转存成功通知到从机, ", err)
}
}

@ -3,6 +3,7 @@ package routers
import (
"github.com/cloudreve/Cloudreve/v3/middleware"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
@ -59,7 +60,7 @@ func InitSlaveRouter() *gin.Engine {
// 离线下载
aria2 := v3.Group("aria2")
aria2.Use(middleware.UseSlaveAria2Instance())
aria2.Use(middleware.UseSlaveAria2Instance(cluster.DefaultController))
{
// 创建离线下载任务
aria2.POST("task", controllers.SlaveAria2Create)
@ -205,7 +206,7 @@ func InitMasterRouter() *gin.Engine {
// 从机的 RPC 通信
slave := v3.Group("slave")
slave.Use(middleware.SlaveRPCSignRequired())
slave.Use(middleware.SlaveRPCSignRequired(cluster.Default))
{
// 事件通知
slave.PUT("notification/:subject", controllers.SlaveNotificationPush)

@ -9,7 +9,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
@ -91,7 +90,7 @@ func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response
// 创建事件通知回调
siteID, _ := c.Get("MasterSiteID")
mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) {
if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil {
if err := cluster.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil {
util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err)
}
})

@ -6,10 +6,10 @@ import (
"encoding/json"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask"
"github.com/gin-gonic/gin"
@ -153,7 +153,7 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial
MasterID: id.(string),
}
if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) {
if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) {
task.TaskPoll.Submit(job.(task.Job))
}); err != nil {
return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err)

@ -3,10 +3,10 @@ package node
import (
"encoding/gob"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
)
@ -19,7 +19,7 @@ type OneDriveCredentialService struct {
}
func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response {
res, err := slave.DefaultController.HandleHeartBeat(req)
res, err := cluster.DefaultController.HandleHeartBeat(req)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err)
}

Loading…
Cancel
Save