From 32b88e989d328c6a2b7afeb5cf6e7bd25a6355c8 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sat, 21 Aug 2021 11:06:53 +0800 Subject: [PATCH] Feat: call slave aria2 rpc method from master --- middleware/cluster.go | 15 +++++ models/migration.go | 4 ++ pkg/auth/auth.go | 2 + pkg/cluster/master.go | 5 +- pkg/cluster/slave.go | 113 ++++++++++++++++++++++++++++++++--- pkg/request/request.go | 1 + pkg/serializer/error.go | 18 +++++- pkg/serializer/slave.go | 14 ++++- pkg/slave/errors.go | 7 +++ pkg/slave/slave.go | 23 +++++-- routers/controllers/slave.go | 4 +- routers/router.go | 4 +- service/aria2/add.go | 29 ++++++--- 13 files changed, 207 insertions(+), 32 deletions(-) create mode 100644 middleware/cluster.go create mode 100644 pkg/slave/errors.go diff --git a/middleware/cluster.go b/middleware/cluster.go new file mode 100644 index 00000000..52ff1acf --- /dev/null +++ b/middleware/cluster.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" +) + +// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据 +func MasterMetadata() gin.HandlerFunc { + return func(c *gin.Context) { + c.Set("MasterSiteID", c.GetHeader("X-Site-ID")) + c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur")) + c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version")) + c.Next() + } +} diff --git a/models/migration.go b/models/migration.go index 9bbe3a95..3e869ef1 100644 --- a/models/migration.go +++ b/models/migration.go @@ -5,6 +5,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/fatih/color" + "github.com/gofrs/uuid" "github.com/jinzhu/gorm" ) @@ -74,6 +75,8 @@ func addDefaultPolicy() { } func addDefaultSettings() { + siteID, _ := uuid.NewV4() + defaultSettings := []Setting{ {Name: "siteURL", Value: `http://localhost`, Type: "basic"}, {Name: "siteName", Value: `Cloudreve`, Type: "basic"}, @@ -84,6 +87,7 @@ func addDefaultSettings() { {Name: "siteDes", Value: `Cloudreve`, Type: "basic"}, {Name: "siteTitle", Value: `平步云端`, Type: "basic"}, {Name: "siteScript", Value: ``, Type: "basic"}, + {Name: "siteID", Value: siteID.String(), Type: "basic"}, {Name: "fromName", Value: `Cloudreve`, Type: "mail"}, {Name: "mail_keepalive", Value: `30`, Type: "mail"}, {Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"}, diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index fad861fc..358e3914 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "net/http" "net/url" + "sort" "strings" "time" @@ -81,6 +82,7 @@ func getSignContent(r *http.Request) (rawSignString string) { signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k))) } } + sort.Strings(signedHeader) // 读取所有待签名Header rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body)) diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index 9389af20..ec303d6f 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -93,12 +93,13 @@ func (node *MasterNode) Kill() { // GetAria2Instance 获取主机Aria2实例 func (node *MasterNode) GetAria2Instance() common.Aria2 { + node.lock.RLock() + defer node.lock.RUnlock() + if !node.Model.Aria2Enabled { return &common.DummyAria2{} } - node.lock.RLock() - defer node.lock.RUnlock() if !node.aria2RPC.Initialized { return &common.DummyAria2{} } diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 8cdba660..7398fa8e 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -2,13 +2,14 @@ package cluster import ( "encoding/json" - "errors" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" + "io" "net/url" "path" "strings" @@ -19,20 +20,26 @@ import ( type SlaveNode struct { Model *model.Node AuthInstance auth.Auth - Client request.Client Active bool + caller slaveCaller callback func(bool, uint) close chan bool lock sync.RWMutex } +type slaveCaller struct { + parent *SlaveNode + Client request.Client +} + // Init 初始化节点 func (node *SlaveNode) Init(nodeModel *model.Node) { node.lock.Lock() node.Model = nodeModel node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)} - node.Client = request.HTTPClient{} + node.caller.Client = request.HTTPClient{} + node.caller.parent = node node.Active = true if node.close != nil { node.close <- true @@ -44,7 +51,12 @@ func (node *SlaveNode) Init(nodeModel *model.Node) { // IsFeatureEnabled 查询节点的某项功能是否启用 func (node *SlaveNode) IsFeatureEnabled(feature string) bool { + node.lock.RLock() + defer node.lock.RUnlock() + switch feature { + case "aria2": + return node.Model.Aria2Enabled default: return false } @@ -67,10 +79,12 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe bodyReader := strings.NewReader(string(reqBodyEncoded)) signTTL := model.GetIntSetting("slave_api_timeout", 60) - resp, err := node.Client.Request( + resp, err := node.caller.Client.Request( "POST", node.getAPIUrl("heartbeat"), bodyReader, + request.WithMasterMeta(), + request.WithTimeout(time.Duration(signTTL)*time.Second), request.WithCredential(node.AuthInstance, int64(signTTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { @@ -79,7 +93,7 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe // 处理列取结果 if resp.Code != 0 { - return nil, errors.New(resp.Error) + return nil, serializer.NewErrorFromResponse(resp) } var res serializer.NodePingResp @@ -96,6 +110,9 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe // IsActive 返回节点是否在线 func (node *SlaveNode) IsActive() bool { + node.lock.RLock() + defer node.lock.RUnlock() + return node.Active } @@ -111,7 +128,14 @@ func (node *SlaveNode) Kill() { // GetAria2Instance 获取从机Aria2实例 func (node *SlaveNode) GetAria2Instance() common.Aria2 { - return nil + node.lock.RLock() + defer node.lock.RUnlock() + + if !node.Model.Aria2Enabled { + return &common.DummyAria2{} + } + + return &node.caller } func (node *SlaveNode) ID() uint { @@ -210,8 +234,79 @@ loop: // getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq { return &serializer.NodePingReq{ - IsUpdate: isUpdate, - MasterURL: model.GetSiteURL().String(), - Node: node.Model, + IsUpdate: isUpdate, + SiteID: model.GetSettingByName("siteID"), + Node: node.Model, } } + +func (s *slaveCaller) Init() error { + return nil +} + +// SendAria2Call send remote aria2 call to slave node +func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) { + reqReader, err := getAria2RequestBody(body) + if err != nil { + return nil, err + } + + signTTL := model.GetIntSetting("slave_api_timeout", 60) + return s.Client.Request( + "POST", + s.parent.getAPIUrl("aria2/"+scope), + reqReader, + request.WithMasterMeta(), + request.WithTimeout(time.Duration(signTTL)*time.Second), + request.WithCredential(s.parent.AuthInstance, int64(signTTL)), + ).CheckHTTPResponse(200).DecodeResponse() +} + +func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + GroupOptions: options, + } + + res, err := s.SendAria2Call(req, "task") + if err != nil { + return "", err + } + + if res.Code != 0 { + return "", serializer.NewErrorFromResponse(res) + } + + return res.Data.(string), err +} + +func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) { + panic("implement me") +} + +func (s *slaveCaller) Cancel(task *model.Download) error { + panic("implement me") +} + +func (s *slaveCaller) Select(task *model.Download, files []int) error { + panic("implement me") +} + +func (s *slaveCaller) GetConfig() model.Aria2Option { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + return s.parent.Model.Aria2OptionsSerialized +} + +func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { + reqBodyEncoded, err := json.Marshal(body) + if err != nil { + return nil, err + } + + return strings.NewReader(string(reqBodyEncoded)), nil +} diff --git a/pkg/request/request.go b/pkg/request/request.go index d6f93afa..98a1baa4 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -154,6 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio if options.masterMeta { req.Header.Add("X-Site-Url", model.GetSiteURL().String()) + req.Header.Add("X-Site-ID", model.GetSettingByName("siteID")) req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) } diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index 0191ceeb..37e70c69 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -1,6 +1,9 @@ package serializer -import "github.com/gin-gonic/gin" +import ( + "errors" + "github.com/gin-gonic/gin" +) // Response 基础序列化器 type Response struct { @@ -17,7 +20,7 @@ type AppError struct { RawError error } -// NewError 返回新的错误对象 todo:测试 还有下面的 +// NewError 返回新的错误对象 func NewError(code int, msg string, err error) AppError { return AppError{ Code: code, @@ -26,6 +29,15 @@ func NewError(code int, msg string, err error) AppError { } } +// NewErrorFromResponse 从 serializer.Response 构建错误 +func NewErrorFromResponse(resp *Response) AppError { + return AppError{ + Code: resp.Code, + Msg: resp.Msg, + RawError: errors.New(resp.Error), + } +} + // WithError 将应用error携带标准库中的error func (err *AppError) WithError(raw error) AppError { err.RawError = raw @@ -66,6 +78,8 @@ const ( CodeGroupNotAllowed = 40007 // CodeAdminRequired 非管理用户组 CodeAdminRequired = 40008 + // CodeMasterNotFound 主机节点未注册 + CodeMasterNotFound = 40009 // CodeDBError 数据库操作失败 CodeDBError = 50001 // CodeEncryptError 加密失败 diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index 5d19adbf..af07c116 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -15,11 +15,19 @@ type ListRequest struct { // NodePingReq 从机节点Ping请求 type NodePingReq struct { - MasterURL string `json:"master_url"` - IsUpdate bool `json:"is_update"` - Node *model.Node `json:"node"` + SiteURL string `json:"site_url"` + SiteID string `json:"site_id"` + IsUpdate bool `json:"is_update"` + Node *model.Node `json:"node"` } // NodePingResp 从机节点Ping响应 type NodePingResp struct { } + +// SlaveAria2Call 从机有关Aria2的请求正文 +type SlaveAria2Call struct { + Task *model.Download `json:"task"` + GroupOptions map[string]interface{} `json:"group_options"` + Files []uint `json:"files"` +} diff --git a/pkg/slave/errors.go b/pkg/slave/errors.go new file mode 100644 index 00000000..2af6e13f --- /dev/null +++ b/pkg/slave/errors.go @@ -0,0 +1,7 @@ +package slave + +import "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + +var ( + ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil) +) diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index af900aed..55296663 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -2,6 +2,7 @@ package slave import ( model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -14,6 +15,9 @@ var DefaultController Controller type Controller interface { // Handle heartbeat sent from master HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error) + + // Get Aria2 instance by master node id + GetAria2Instance(string) (common.Aria2, error) } type slaveController struct { @@ -24,7 +28,7 @@ type slaveController struct { // info of master node type masterInfo struct { slaveID uint - url string + id string authClient auth.Auth // used to invoke aria2 rpc calls instance cluster.Node @@ -43,16 +47,16 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ req.Node.AfterFind() // close old node if exist - origin, ok := c.masters[req.MasterURL] + origin, ok := c.masters[req.SiteID] if (ok && req.IsUpdate) || !ok { if ok { origin.instance.Kill() } - c.masters[req.MasterURL] = masterInfo{ + c.masters[req.SiteID] = masterInfo{ slaveID: req.Node.ID, - url: req.MasterURL, + id: req.SiteID, authClient: auth.HMACAuth{ SecretKey: []byte(req.Node.MasterKey), }, @@ -66,3 +70,14 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ return serializer.NodePingResp{}, nil } + +func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if node, ok := c.masters[id]; ok { + return node.instance.GetAria2Instance(), nil + } + + return nil, ErrMasterNotFound +} diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 464092dd..1c282244 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -191,9 +191,9 @@ func SlaveHeartbeat(c *gin.Context) { // SlaveAria2Create 创建 Aria2 任务 func SlaveAria2Create(c *gin.Context) { - var service aria2.SlaveAria2Call + var service serializer.SlaveAria2Call if err := c.ShouldBindJSON(&service); err == nil { - res := service.Add(c) + res := aria2.Add(c, &service) c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) diff --git a/routers/router.go b/routers/router.go index 2656e422..2ec2f049 100644 --- a/routers/router.go +++ b/routers/router.go @@ -30,6 +30,8 @@ func InitSlaveRouter() *gin.Engine { v3 := r.Group("/api/v3/slave") // 鉴权中间件 v3.Use(middleware.SignRequired()) + // 主机信息解析 + v3.Use(middleware.MasterMetadata()) /* 路由 @@ -55,7 +57,7 @@ func InitSlaveRouter() *gin.Engine { // 离线下载 aria2 := v3.Group("aria2") { - aria2.POST("task", controllers.SlaveList) + aria2.POST("task", controllers.SlaveAria2Create) } } return r diff --git a/service/aria2/add.go b/service/aria2/add.go index fa48d439..75e8d20f 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -8,6 +8,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" ) @@ -17,13 +18,6 @@ type AddURLService struct { Dst string `json:"dst" binding:"required,min=1"` } -// SlaveAria2Call 从机有关Aria2的请求正文 -type SlaveAria2Call struct { - Task *model.Download `json:"task"` - GroupOptions map[string]interface{} `json:"group_options"` - Files []uint `json:"files"` -} - // Add 主机创建新的链接离线下载任务 func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response { // 创建文件系统 @@ -83,6 +77,23 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo } // Add 从机创建新的链接离线下载任务 -func (service *SlaveAria2Call) Add(c *gin.Context) serializer.Response { - return serializer.Response{} +func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + if siteID, exist := c.Get("MasterSiteID"); exist { + // 获取对应主机节点的从机Aria2实例 + caller, err := slave.DefaultController.GetAria2Instance(siteID.(string)) + if err != nil { + return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err) + } + + // 创建任务 + gid, err := caller.CreateTask(service.Task, service.GroupOptions) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err) + } + + // TODO: 创建监控 + return serializer.Response{Data: gid} + } + + return serializer.ParamErr("未知的主机节点ID", nil) }