From adaa3290dd75f867fa69031c3a6f87ee4b8e6447 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 16 Sep 2021 21:19:07 +0800 Subject: [PATCH] Feat: slave transfer file in OneDrive policy --- models/policy.go | 8 +++ pkg/filesystem/driver/onedrive/handler.go | 11 ++++ pkg/filesystem/driver/onedrive/lock.go | 25 ++++++++++ pkg/filesystem/driver/onedrive/oauth.go | 23 ++++++++- pkg/filesystem/filesystem.go | 20 ++++---- pkg/request/request.go | 9 +++- pkg/slave/slave.go | 61 +++++++++++++++++------ pkg/task/slavetask/transfer.go | 2 +- routers/controllers/slave.go | 11 ++++ routers/router.go | 3 ++ service/explorer/slave.go | 5 +- service/node/fabric.go | 29 ++++++++++- 12 files changed, 175 insertions(+), 32 deletions(-) create mode 100644 pkg/filesystem/driver/onedrive/lock.go diff --git a/models/policy.go b/models/policy.go index dfd068d9..e9a3d6ea 100644 --- a/models/policy.go +++ b/models/policy.go @@ -37,6 +37,7 @@ type Policy struct { // 数据库忽略字段 OptionsSerialized PolicyOption `gorm:"-"` + MasterID string `gorm:"-"` } // PolicyOption 非公有的存储策略属性 @@ -277,6 +278,13 @@ func (policy *Policy) SaveAndClearCache() error { return err } +// SaveAndClearCache 更新并清理缓存 +func (policy *Policy) UpdateAccessKeyAndClearCache(s string) error { + err := DB.Model(policy).UpdateColumn("access_key", s).Error + policy.ClearCache() + return err +} + // ClearCache 清空policy缓存 func (policy *Policy) ClearCache() { cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_") diff --git a/pkg/filesystem/driver/onedrive/handler.go b/pkg/filesystem/driver/onedrive/handler.go index 08207642..609fee7c 100644 --- a/pkg/filesystem/driver/onedrive/handler.go +++ b/pkg/filesystem/driver/onedrive/handler.go @@ -14,6 +14,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/request" @@ -27,6 +28,16 @@ type Driver struct { HTTPClient request.Client } +// NewDriver 从存储策略初始化新的Driver实例 +func NewDriver(policy *model.Policy) (driver.Handler, error) { + client, err := NewClient(policy) + return Driver{ + Policy: policy, + Client: client, + HTTPClient: request.NewClient(), + }, err +} + // List 列取项目 func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { base = strings.TrimPrefix(base, "/") diff --git a/pkg/filesystem/driver/onedrive/lock.go b/pkg/filesystem/driver/onedrive/lock.go new file mode 100644 index 00000000..655936bd --- /dev/null +++ b/pkg/filesystem/driver/onedrive/lock.go @@ -0,0 +1,25 @@ +package onedrive + +import "sync" + +// CredentialLock 针对存储策略凭证的锁 +type CredentialLock interface { + Lock(uint) + Unlock(uint) +} + +var GlobalMutex = mutexMap{} + +type mutexMap struct { + locks sync.Map +} + +func (m *mutexMap) Lock(id uint) { + lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) + lock.(*sync.Mutex).Lock() +} + +func (m *mutexMap) Unlock(id uint) { + lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) + lock.(*sync.Mutex).Unlock() +} diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index 9b33d7a4..49170fe7 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -10,7 +10,9 @@ import ( "time" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -124,6 +126,13 @@ func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credent // UpdateCredential 更新凭证,并检查有效期 func (client *Client) UpdateCredential(ctx context.Context) error { + if conf.SystemConfig.Mode == "slave" { + return client.fetchCredentialFromMaster(ctx) + } + + GlobalMutex.Lock(client.Policy.ID) + defer GlobalMutex.Unlock(client.Policy.ID) + // 如果已存在凭证 if client.Credential != nil && client.Credential.AccessToken != "" { // 检查已有凭证是否过期 @@ -160,11 +169,21 @@ func (client *Client) UpdateCredential(ctx context.Context) error { client.Credential = credential // 更新存储策略的 RefreshToken - client.Policy.AccessKey = credential.RefreshToken - client.Policy.SaveAndClearCache() + client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken) // 更新缓存 cache.Set("onedrive_"+client.ClientID, *credential, int(expires)) return nil } + +// UpdateCredential 更新凭证,并检查有效期 +func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { + res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) + if err != nil { + return err + } + + client.Credential = &Credential{AccessToken: res} + return nil +} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 073a2e4b..7f176c0c 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -176,13 +176,9 @@ func (fs *FileSystem) DispatchHandler() error { } return nil case "onedrive": - client, err := onedrive.NewClient(currentPolicy) - fs.Handler = onedrive.Driver{ - Policy: currentPolicy, - Client: client, - HTTPClient: request.NewClient(), - } - return err + var odErr error + fs.Handler, odErr = onedrive.NewDriver(currentPolicy) + return odErr case "cos": u, _ := url.Parse(currentPolicy.Server) b := &cossdk.BaseURL{BucketURL: u} @@ -249,17 +245,19 @@ func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) { } // SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器 -func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL string) { - // 交换主从存储策略 - if fs.Policy.Type == "remote" { +func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) { + switch fs.Policy.Type { + case "remote": fs.Policy.Type = "local" fs.DispatchHandler() - } else if fs.Policy.Type == "local" { + case "local": fs.Policy.Type = "remote" fs.Policy.Server = masterURL fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID()) fs.Policy.SecretKey = master.DBModel().MasterKey fs.DispatchHandler() + case "onedrive": + fs.Policy.MasterID = masterID } fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy) diff --git a/pkg/request/request.go b/pkg/request/request.go index dabf48b6..c543c2f8 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -112,7 +112,14 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio // 签名请求 if options.sign != nil { - auth.SignRequest(options.sign, req, options.signTTL) + switch method { + case "PUT", "POST", "PATCH": + auth.SignRequest(options.sign, req, options.signTTL) + default: + if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil { + req.URL = resURL + } + } } // 发送请求 diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index b38a0e55..aa457b5b 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -7,11 +7,11 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/jinzhu/gorm" "net/url" "sync" @@ -31,15 +31,17 @@ type Controller interface { SendNotification(string, string, mq.Message) error // Submit async task into task pool - SubmitTask(string, task.Job, string) error + SubmitTask(string, interface{}, string, func(interface{})) error // Get master node info GetMasterInfo(string) (*MasterInfo, error) + + // Get master OneDrive policy credential + GetOneDriveToken(string, uint) (string, error) } type slaveController struct { masters map[string]MasterInfo - client request.Client lock sync.RWMutex } @@ -50,6 +52,7 @@ type MasterInfo struct { URL *url.URL // used to invoke aria2 rpc calls Instance cluster.Node + Client request.Client jobTracker map[string]bool } @@ -57,7 +60,6 @@ type MasterInfo struct { func Init() { DefaultController = &slaveController{ masters: make(map[string]MasterInfo), - client: request.NewClient(), } gob.Register(rpc.StatusInfo{}) } @@ -82,9 +84,16 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ } c.masters[req.SiteID] = MasterInfo{ - ID: req.SiteID, - URL: masterUrl, - TTL: req.CredentialTTL, + ID: req.SiteID, + URL: masterUrl, + TTL: req.CredentialTTL, + Client: request.NewClient( + request.WithEndpoint(masterUrl.String()), + request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)), + request.WithCredential(auth.HMACAuth{ + SecretKey: []byte(req.Node.MasterKey), + }, int64(req.CredentialTTL)), + ), jobTracker: make(map[string]bool), Instance: cluster.NewNodeFromDBModel(&model.Node{ Model: gorm.Model{ID: req.Node.ID}, @@ -116,19 +125,16 @@ func (c *slaveController) SendNotification(id, subject string, msg mq.Message) e if node, ok := c.masters[id]; ok { c.lock.RUnlock() - apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/notification/%s", subject)) body := bytes.Buffer{} enc := gob.NewEncoder(&body) if err := enc.Encode(&msg); err != nil { return err } - res, err := c.client.Request( + res, err := node.Client.Request( "PUT", - node.URL.ResolveReference(apiPath).String(), + fmt.Sprintf("/api/v3/slave/notification/%s", subject), &body, - request.WithSlaveMeta(fmt.Sprintf("%d", node.Instance.ID())), - request.WithCredential(node.Instance.MasterAuthInstance(), int64(node.TTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { return err @@ -146,7 +152,7 @@ func (c *slaveController) SendNotification(id, subject string, msg mq.Message) e } // SubmitTask 提交异步任务 -func (c *slaveController) SubmitTask(id string, job task.Job, hash string) error { +func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error { c.lock.RLock() defer c.lock.RUnlock() @@ -156,7 +162,7 @@ func (c *slaveController) SubmitTask(id string, job task.Job, hash string) error return nil } - task.TaskPoll.Submit(job) + submitter(job) return nil } @@ -174,3 +180,30 @@ func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) { return nil, ErrMasterNotFound } + +// GetOneDriveToken 获取主机OneDrive凭证 +func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, error) { + c.lock.RLock() + + if node, ok := c.masters[id]; ok { + c.lock.RUnlock() + + res, err := node.Client.Request( + "GET", + fmt.Sprintf("/api/v3/slave/credential/onedrive/%d", policyID), + nil, + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return "", err + } + + if res.Code != 0 { + return "", serializer.NewErrorFromResponse(res) + } + + return res.Data.(string), nil + } + + c.lock.RUnlock() + return "", ErrMasterNotFound +} diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 8db10f93..7aecd85f 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -97,7 +97,7 @@ func (job *TransferTask) Do() { return } - fs.SwitchToShadowHandler(master.Instance, master.URL.String()) + fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID) ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) file, err := os.Open(util.RelativePath(job.Req.Src)) if err != nil { diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index cc64bf62..267ae409 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -254,3 +254,14 @@ func SlaveNotificationPush(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveGetOneDriveCredential 从机获取主机的OneDrive存储策略凭证 +func SlaveGetOneDriveCredential(c *gin.Context) { + var service node.OneDriveCredentialService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Get(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 9110bdb5..78bd96a2 100644 --- a/routers/router.go +++ b/routers/router.go @@ -203,9 +203,12 @@ func InitMasterRouter() *gin.Engine { slave := v3.Group("slave") slave.Use(middleware.SlaveRPCSignRequired()) { + // 事件通知 slave.PUT("notification/:subject", controllers.SlaveNotificationPush) // 上传 slave.POST("upload", controllers.SlaveUpload) + // OneDrive 存储策略凭证 + slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential) } // 回调接口 diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 56fe3bfd..8beb15b8 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -10,6 +10,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" @@ -152,7 +153,9 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial MasterID: id.(string), } - if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID)); err != nil { + if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { + task.TaskPoll.Submit(job.(task.Job)) + }); err != nil { return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err) } diff --git a/service/node/fabric.go b/service/node/fabric.go index 6904bb0d..79dfb29d 100644 --- a/service/node/fabric.go +++ b/service/node/fabric.go @@ -2,6 +2,8 @@ package node import ( "encoding/gob" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/slave" @@ -12,10 +14,14 @@ type SlaveNotificationService struct { Subject string `uri:"subject" binding:"required"` } +type OneDriveCredentialService struct { + PolicyID uint `uri:"id" binding:"required"` +} + func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { res, err := slave.DefaultController.HandleHeartBeat(req) if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "无法初始化从机控制器", err) + return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err) } return serializer.Response{ @@ -29,9 +35,28 @@ func (s *SlaveNotificationService) HandleSlaveNotificationPush(c *gin.Context) s var msg mq.Message dec := gob.NewDecoder(c.Request.Body) if err := dec.Decode(&msg); err != nil { - return serializer.ParamErr("无法解析通知消息", err) + return serializer.ParamErr("Cannot parse notification message", err) } mq.GlobalMQ.Publish(s.Subject, msg) return serializer.Response{} } + +// Get 获取主机OneDrive策略的AccessToken +func (s *OneDriveCredentialService) Get(c *gin.Context) serializer.Response { + policy, err := model.GetPolicyByID(s.PolicyID) + if err != nil { + return serializer.Err(serializer.CodeNotFound, "Cannot found storage policy", err) + } + + client, err := onedrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize OneDrive client", err) + } + + if err := client.UpdateCredential(c); err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot refresh OneDrive credential", err) + } + + return serializer.Response{Data: client.Credential.AccessToken} +}