Feat: slave aria2 status event callback / salve RPC auth

pull/1040/head
HFO4 4 years ago
parent cf2960a092
commit 870df708bf

@ -22,16 +22,14 @@ import (
) )
// SignRequired 验证请求签名 // SignRequired 验证请求签名
func SignRequired() gin.HandlerFunc { func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
var err error var err error
switch c.Request.Method { switch c.Request.Method {
case "PUT", "POST": case "PUT", "POST", "PATCH":
err = auth.CheckRequest(auth.General, c.Request) err = auth.CheckRequest(authInstance, c.Request)
// TODO 生产环境去掉下一行
//err = nil
default: default:
err = auth.CheckURI(auth.General, c.Request.URL) err = auth.CheckURI(authInstance, c.Request.URL)
} }
if err != nil { if err != nil {

@ -87,11 +87,10 @@ func TestAuthRequired(t *testing.T) {
func TestSignRequired(t *testing.T) { func TestSignRequired(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec) c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil) c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired() SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
// 鉴权失败 // 鉴权失败
SignRequiredFunc(c) SignRequiredFunc(c)

@ -1,16 +1,18 @@
package middleware package middleware
import ( import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"strconv"
) )
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据 // MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
func MasterMetadata() gin.HandlerFunc { func MasterMetadata() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
c.Set("MasterSiteID", c.GetHeader("X-Site-ID")) c.Set("MasterSiteID", c.GetHeader("X-Site-Id"))
c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur")) c.Set("MasterSiteURL", c.GetHeader("X-Site-Url"))
c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version")) c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version"))
c.Next() c.Next()
} }
@ -37,3 +39,24 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
c.Abort() c.Abort()
} }
} }
func SlaveRPCSignRequired() gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
if err != nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
SignRequired(slaveNode.GetAuthInstance())(c)
}
}

@ -3,5 +3,6 @@ package balancer
import "errors" import "errors"
var ( var (
ErrInputNotSlice = errors.New("Input value is not silice") ErrInputNotSlice = errors.New("Input value is not silice")
ErrNoAvaliableNode = errors.New("No nodes avaliable")
) )

@ -16,6 +16,10 @@ func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
return ErrInputNotSlice, nil return ErrInputNotSlice, nil
} }
if v.Len() == 0 {
return ErrNoAvaliableNode, nil
}
next := r.NextIndex(v.Len()) next := r.NextIndex(v.Len())
return nil, v.Index(next).Interface() return nil, v.Index(next).Interface()
} }

@ -6,6 +6,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/url" "net/url"
@ -75,6 +76,13 @@ func (node *MasterNode) IsFeatureEnabled(feature string) bool {
} }
} }
func (node *MasterNode) GetAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
// SubscribeStatusChange 订阅节点状态更改 // SubscribeStatusChange 订阅节点状态更改
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) { func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
} }

@ -3,6 +3,7 @@ package cluster
import ( import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
) )
@ -33,6 +34,9 @@ type Node interface {
// Returns if current node is master node // Returns if current node is master node
IsMater() bool IsMater() bool
// Get auth instance used to check RPC call from slave to master
GetAuthInstance() auth.Auth
} }
// Create new node from DB model // Create new node from DB model

@ -119,7 +119,11 @@ func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer)
defer pool.lock.RUnlock() defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok { if nodes, ok := pool.featureMap[feature]; ok {
err, res := lb.NextPeer(nodes) err, res := lb.NextPeer(nodes)
return err, res.(Node) if err == nil {
return nil, res.(Node)
}
return err, nil
} }
return ErrFeatureNotExist, nil return ErrFeatureNotExist, nil

@ -187,6 +187,7 @@ func (node *SlaveNode) StartPingLoop() {
util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name) util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name)
retry := 0 retry := 0
recoverMode := false recoverMode := false
isFirstLoop := true
loop: loop:
for { for {
@ -197,7 +198,9 @@ loop:
} }
util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name) util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name)
res, err := node.Ping(node.getHeartbeatContent(false)) res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
isFirstLoop = false
if err != nil { if err != nil {
util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err) util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err)
retry++ retry++
@ -217,6 +220,7 @@ loop:
util.Log().Debug("从机节点 [%s] 复活", node.Model.Name) util.Log().Debug("从机节点 [%s] 复活", node.Model.Name)
pingTicker = tickDuration pingTicker = tickDuration
recoverMode = false recoverMode = false
isFirstLoop = true
} }
util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res) util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res)
@ -234,6 +238,7 @@ loop:
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave // getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq { func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{ return &serializer.NodePingReq{
SiteURL: model.GetSiteURL().String(),
IsUpdate: isUpdate, IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"), SiteID: model.GetSettingByName("siteID"),
Node: node.Model, Node: node.Model,
@ -245,6 +250,13 @@ func (node *SlaveNode) IsMater() bool {
return false return false
} }
func (node *SlaveNode) GetAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (s *slaveCaller) Init() error { func (s *slaveCaller) Init() error {
return nil return nil
} }

@ -154,7 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
if options.masterMeta { if options.masterMeta {
req.Header.Add("X-Site-Url", model.GetSiteURL().String()) req.Header.Add("X-Site-Url", model.GetSiteURL().String())
req.Header.Add("X-Site-ID", model.GetSettingByName("siteID")) req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
} }

@ -6,11 +6,11 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/http" "net/http"
"net/url"
"sync" "sync"
) )
@ -36,10 +36,10 @@ type slaveController struct {
// info of master node // info of master node
type masterInfo struct { type masterInfo struct {
slaveID uint slaveID uint
id string id string
authClient auth.Auth ttl int
ttl int url *url.URL
// used to invoke aria2 rpc calls // used to invoke aria2 rpc calls
instance cluster.Node instance cluster.Node
} }
@ -66,14 +66,18 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
origin.instance.Kill() origin.instance.Kill()
} }
masterUrl, err := url.Parse(req.SiteURL)
if err != nil {
return serializer.NodePingResp{}, err
}
c.masters[req.SiteID] = masterInfo{ c.masters[req.SiteID] = masterInfo{
slaveID: req.Node.ID, slaveID: req.Node.ID,
id: req.SiteID, id: req.SiteID,
authClient: auth.HMACAuth{ url: masterUrl,
SecretKey: []byte(req.Node.MasterKey), ttl: req.CredentialTTL,
},
ttl: req.CredentialTTL,
instance: cluster.NewNodeFromDBModel(&model.Node{ instance: cluster.NewNodeFromDBModel(&model.Node{
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType, Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled, Aria2Enabled: req.Node.Aria2Enabled,
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized, Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
@ -101,12 +105,14 @@ func (c *slaveController) SendAria2Notification(id string, msg common.StatusEven
if node, ok := c.masters[id]; ok { if node, ok := c.masters[id]; ok {
c.lock.RUnlock() c.lock.RUnlock()
apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status))
res, err := c.client.Request( res, err := c.client.Request(
"PATCH", "PATCH",
fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status), node.url.ResolveReference(apiPath).String(),
nil, nil,
request.WithHeader(http.Header{"X-Node-ID": []string{fmt.Sprintf("%d", node.slaveID)}}), request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.slaveID)}}),
request.WithCredential(node.authClient, int64(node.ttl)), request.WithCredential(node.instance.GetAuthInstance(), int64(node.ttl)),
).CheckHTTPResponse(200).DecodeResponse() ).CheckHTTPResponse(200).DecodeResponse()
if err != nil { if err != nil {
return err return err

@ -99,7 +99,7 @@ func ListFinished(c *gin.Context) {
// TaskUpdate 被动更新任务状态 // TaskUpdate 被动更新任务状态
func TaskUpdate(c *gin.Context) { func TaskUpdate(c *gin.Context) {
var service aria2.DownloadTaskService var service aria2.DownloadTaskService
if err := c.ShouldBindQuery(&service); err == nil { if err := c.ShouldBindUri(&service); err == nil {
res := service.Notify() res := service.Notify()
c.JSON(200, res) c.JSON(200, res)
} else { } else {

@ -2,6 +2,7 @@ package routers
import ( import (
"github.com/cloudreve/Cloudreve/v3/middleware" "github.com/cloudreve/Cloudreve/v3/middleware"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
@ -29,7 +30,7 @@ func InitSlaveRouter() *gin.Engine {
InitCORS(r) InitCORS(r)
v3 := r.Group("/api/v3/slave") v3 := r.Group("/api/v3/slave")
// 鉴权中间件 // 鉴权中间件
v3.Use(middleware.SignRequired()) v3.Use(middleware.SignRequired(auth.General))
// 主机信息解析 // 主机信息解析
v3.Use(middleware.MasterMetadata()) v3.Use(middleware.MasterMetadata())
@ -149,7 +150,7 @@ func InitMasterRouter() *gin.Engine {
user.PATCH("reset", controllers.UserReset) user.PATCH("reset", controllers.UserReset)
// 邮件激活 // 邮件激活
user.GET("activate/:id", user.GET("activate/:id",
middleware.SignRequired(), middleware.SignRequired(auth.General),
middleware.HashID(hashid.UserID), middleware.HashID(hashid.UserID),
controllers.UserActivate, controllers.UserActivate,
) )
@ -177,7 +178,7 @@ func InitMasterRouter() *gin.Engine {
// 需要携带签名验证的 // 需要携带签名验证的
sign := v3.Group("") sign := v3.Group("")
sign.Use(middleware.SignRequired()) sign.Use(middleware.SignRequired(auth.General))
{ {
file := sign.Group("file") file := sign.Group("file")
{ {
@ -194,6 +195,7 @@ func InitMasterRouter() *gin.Engine {
// 从机的 RPC 通信 // 从机的 RPC 通信
slave := v3.Group("slave") slave := v3.Group("slave")
slave.Use(middleware.SlaveRPCSignRequired())
{ {
slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate) slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate)
} }

@ -17,7 +17,7 @@ type SelectFileService struct {
// DownloadTaskService 下载任务管理服务 // DownloadTaskService 下载任务管理服务
type DownloadTaskService struct { type DownloadTaskService struct {
GID string `uri:"gid" binding:"required"` GID string `uri:"gid" binding:"required"`
Status int `uri:"gid"` Status int `uri:"status"`
} }
// DownloadListService 下载列表服务 // DownloadListService 下载列表服务

Loading…
Cancel
Save