diff --git a/assets b/assets index 59890e6..8a61a8e 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit 59890e6b22d69befa8b742a64967b6bab1bb4a3d +Subproject commit 8a61a8e4c238ed60a107ace23717cf8f03f957f6 diff --git a/bootstrap/app.go b/bootstrap/app.go index 093e3f4..d327b2a 100644 --- a/bootstrap/app.go +++ b/bootstrap/app.go @@ -34,7 +34,7 @@ type GitHubRelease struct { // CheckUpdate 检查更新 func CheckUpdate() { - client := request.HTTPClient{} + client := request.NewClient() res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse() if err != nil { util.Log().Warning("更新检查失败, %s", err) diff --git a/bootstrap/init.go b/bootstrap/init.go index 98f2c8d..6b43adb 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -5,9 +5,11 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/crontab" "github.com/cloudreve/Cloudreve/v3/pkg/email" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/gin-gonic/gin" ) @@ -20,14 +22,86 @@ func Init(path string) { if !conf.SystemConfig.Debug { gin.SetMode(gin.ReleaseMode) } - cache.Init() - if conf.SystemConfig.Mode == "master" { - model.Init() - task.Init() - aria2.Init(false) - email.Init() - crontab.Init() - InitStatic() + + dependencies := []struct { + mode string + factory func() + }{ + { + "both", + func() { + cache.Init() + }, + }, + { + "master", + func() { + model.Init() + }, + }, + { + "both", + func() { + task.Init() + }, + }, + { + "master", + func() { + cluster.Init() + }, + }, + { + "master", + func() { + aria2.Init(false) + }, + }, + { + "master", + func() { + email.Init() + }, + }, + { + "master", + func() { + crontab.Init() + }, + }, + { + "master", + func() { + InitStatic() + }, + }, + { + "slave", + func() { + slave.Init() + }, + }, + { + "both", + func() { + auth.Init() + }, + }, + } + + for _, dependency := range dependencies { + switch dependency.mode { + case "master": + if conf.SystemConfig.Mode == "master" { + dependency.factory() + } + case "slave": + if conf.SystemConfig.Mode == "slave" { + dependency.factory() + } + default: + dependency.factory() + } } - auth.Init() + } diff --git a/go.mod b/go.mod index e7233f5..64df0eb 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/gin-gonic/gin v1.5.0 github.com/go-ini/ini v1.50.0 github.com/go-mail/mail v2.3.1+incompatible + github.com/gofrs/uuid v4.0.0+incompatible github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-querystring v1.0.0 github.com/gorilla/websocket v1.4.1 diff --git a/go.sum b/go.sum index 280b0f8..8728064 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= diff --git a/middleware/auth.go b/middleware/auth.go index 69233ee..135dd8c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -22,16 +22,14 @@ import ( ) // SignRequired 验证请求签名 -func SignRequired() gin.HandlerFunc { +func SignRequired(authInstance auth.Auth) gin.HandlerFunc { return func(c *gin.Context) { var err error switch c.Request.Method { - case "PUT", "POST": - err = auth.CheckRequest(auth.General, c.Request) - // TODO 生产环境去掉下一行 - //err = nil + case "PUT", "POST", "PATCH": + err = auth.CheckRequest(authInstance, c.Request) default: - err = auth.CheckURI(auth.General, c.Request.URL) + err = auth.CheckURI(authInstance, c.Request.URL) } if err != nil { diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 95ab75c..a17602d 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -87,11 +87,10 @@ func TestAuthRequired(t *testing.T) { func TestSignRequired(t *testing.T) { asserts := assert.New(t) - auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request, _ = http.NewRequest("GET", "/test", nil) - SignRequiredFunc := SignRequired() + SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}) // 鉴权失败 SignRequiredFunc(c) diff --git a/middleware/cluster.go b/middleware/cluster.go new file mode 100644 index 0000000..079a4f4 --- /dev/null +++ b/middleware/cluster.go @@ -0,0 +1,62 @@ +package middleware + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/gin-gonic/gin" + "strconv" +) + +// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据 +func MasterMetadata() gin.HandlerFunc { + return func(c *gin.Context) { + c.Set("MasterSiteID", c.GetHeader("X-Site-Id")) + c.Set("MasterSiteURL", c.GetHeader("X-Site-Url")) + c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version")) + c.Next() + } +} + +// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例 +func UseSlaveAria2Instance() gin.HandlerFunc { + return func(c *gin.Context) { + if siteID, exist := c.Get("MasterSiteID"); exist { + // 获取对应主机节点的从机Aria2实例 + caller, err := slave.DefaultController.GetAria2Instance(siteID.(string)) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)) + c.Abort() + return + } + + c.Set("MasterAria2Instance", caller) + c.Next() + return + } + + c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil)) + c.Abort() + } +} + +func SlaveRPCSignRequired() gin.HandlerFunc { + return func(c *gin.Context) { + nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64) + if err != nil { + c.JSON(200, serializer.ParamErr("未知的主机节点ID", err)) + c.Abort() + return + } + + slaveNode := cluster.Default.GetNodeByID(uint(nodeID)) + if slaveNode == nil { + c.JSON(200, serializer.ParamErr("未知的主机节点ID", err)) + c.Abort() + return + } + + SignRequired(slaveNode.MasterAuthInstance())(c) + + } +} diff --git a/models/download.go b/models/download.go index 0989b14..40802ad 100644 --- a/models/download.go +++ b/models/download.go @@ -24,6 +24,7 @@ type Download struct { Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 UserID uint // 发起者UID TaskID uint // 对应的转存任务ID + NodeID uint // 处理任务的节点ID // 关联模型 User *User `gorm:"PRELOAD:false,association_autoupdate:false"` @@ -114,3 +115,13 @@ func (task *Download) GetOwner() *User { func (download *Download) Delete() error { return DB.Model(download).Delete(download).Error } + +// GetNodeID 返回任务所属节点ID +func (task *Download) GetNodeID() uint { + // 兼容3.4版本之前生成的下载记录 + if task.NodeID == 0 { + return 1 + } + + return task.NodeID +} diff --git a/models/migration.go b/models/migration.go index 42fb14d..277020d 100644 --- a/models/migration.go +++ b/models/migration.go @@ -5,6 +5,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/fatih/color" + "github.com/gofrs/uuid" "github.com/jinzhu/gorm" ) @@ -34,8 +35,9 @@ func migration() { if conf.DatabaseConfig.Type == "mysql" { DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") } + DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{}, - &Task{}, &Download{}, &Tag{}, &Webdav{}) + &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}) // 创建初始存储策略 addDefaultPolicy() @@ -73,6 +75,8 @@ func addDefaultPolicy() { } func addDefaultSettings() { + siteID, _ := uuid.NewV4() + defaultSettings := []Setting{ {Name: "siteURL", Value: `http://localhost`, Type: "basic"}, {Name: "siteName", Value: `Cloudreve`, Type: "basic"}, @@ -83,6 +87,7 @@ func addDefaultSettings() { {Name: "siteDes", Value: `Cloudreve`, Type: "basic"}, {Name: "siteTitle", Value: `平步云端`, Type: "basic"}, {Name: "siteScript", Value: ``, Type: "basic"}, + {Name: "siteID", Value: siteID.String(), Type: "basic"}, {Name: "fromName", Value: `Cloudreve`, Type: "mail"}, {Name: "mail_keepalive", Value: `30`, Type: "mail"}, {Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"}, @@ -100,10 +105,13 @@ func addDefaultSettings() { {Name: "upload_credential_timeout", Value: `1800`, Type: "timeout"}, {Name: "upload_session_timeout", Value: `86400`, Type: "timeout"}, {Name: "slave_api_timeout", Value: `60`, Type: "timeout"}, + {Name: "slave_node_retry", Value: `3`, Type: "slave"}, + {Name: "slave_ping_interval", Value: `60`, Type: "slave"}, + {Name: "slave_recover_interval", Value: `120`, Type: "slave"}, + {Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"}, {Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"}, {Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"}, {Name: "onedrive_callback_check", Value: `20`, Type: "timeout"}, - {Name: "aria2_call_timeout", Value: `5`, Type: "timeout"}, {Name: "folder_props_timeout", Value: `300`, Type: "timeout"}, {Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"}, {Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"}, @@ -131,11 +139,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"}, {Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"}, {Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"}, - {Name: "aria2_token", Value: ``, Type: "aria2"}, - {Name: "aria2_rpcurl", Value: ``, Type: "aria2"}, - {Name: "aria2_temp_path", Value: ``, Type: "aria2"}, - {Name: "aria2_options", Value: `{}`, Type: "aria2"}, - {Name: "aria2_interval", Value: `60`, Type: "aria2"}, {Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "max_parallel_transfer", Value: `4`, Type: "task"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, diff --git a/models/node.go b/models/node.go new file mode 100644 index 0000000..992a828 --- /dev/null +++ b/models/node.go @@ -0,0 +1,91 @@ +package model + +import ( + "encoding/json" + "github.com/jinzhu/gorm" +) + +// Node 从机节点信息模型 +type Node struct { + gorm.Model + Status NodeStatus // 节点状态 + Name string // 节点别名 + Type ModelType // 节点状态 + Server string // 服务器地址 + SlaveKey string `gorm:"type:text"` // 主->从 通信密钥 + MasterKey string `gorm:"type:text"` // 从->主 通信密钥 + Aria2Enabled bool // 是否支持用作离线下载节点 + Aria2Options string `gorm:"type:text"` // 离线下载配置 + Rank int // 负载均衡权重 + + // 数据库忽略字段 + Aria2OptionsSerialized Aria2Option `gorm:"-"` +} + +// Aria2Option 非公有的Aria2配置属性 +type Aria2Option struct { + // RPC 服务器地址 + Server string `json:"server,omitempty"` + // RPC 密钥 + Token string `json:"token,omitempty"` + // 临时下载目录 + TempPath string `json:"temp_path,omitempty"` + // 附加下载配置 + Options string `json:"options,omitempty"` + // 下载监控间隔 + Interval int `json:"interval,omitempty"` + // RPC API 请求超时 + Timeout int `json:"timeout,omitempty"` +} + +type NodeStatus int +type ModelType int + +const ( + NodeActive NodeStatus = iota + NodeSuspend +) + +const ( + SlaveNodeType ModelType = iota + MasterNodeType +) + +// GetNodeByID 用ID获取节点 +func GetNodeByID(ID interface{}) (Node, error) { + var node Node + result := DB.First(&node, ID) + return node, result.Error +} + +// GetNodesByStatus 根据给定状态获取节点 +func GetNodesByStatus(status ...NodeStatus) ([]Node, error) { + var nodes []Node + result := DB.Where("status in (?)", status).Find(&nodes) + return nodes, result.Error +} + +// AfterFind 找到节点后的钩子 +func (node *Node) AfterFind() (err error) { + // 解析离线下载设置到 Aria2OptionsSerialized + if node.Aria2Options != "" { + err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized) + } + + return err +} + +// BeforeSave Save策略前的钩子 +func (node *Node) BeforeSave() (err error) { + optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized) + node.Aria2Options = string(optionsValue) + return err +} + +// SetStatus 设置节点启用状态 +func (node *Node) SetStatus(status NodeStatus) error { + node.Status = status + return DB.Model(node).Updates(map[string]interface{}{ + "status": status, + }).Error +} diff --git a/models/policy.go b/models/policy.go index dfd068d..e9a3d6e 100644 --- a/models/policy.go +++ b/models/policy.go @@ -37,6 +37,7 @@ type Policy struct { // 数据库忽略字段 OptionsSerialized PolicyOption `gorm:"-"` + MasterID string `gorm:"-"` } // PolicyOption 非公有的存储策略属性 @@ -277,6 +278,13 @@ func (policy *Policy) SaveAndClearCache() error { return err } +// SaveAndClearCache 更新并清理缓存 +func (policy *Policy) UpdateAccessKeyAndClearCache(s string) error { + err := DB.Model(policy).UpdateColumn("access_key", s).Error + policy.ClearCache() + return err +} + // ClearCache 清空policy缓存 func (policy *Policy) ClearCache() { cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_") diff --git a/models/setting.go b/models/setting.go index f8157cf..1738c1d 100644 --- a/models/setting.go +++ b/models/setting.go @@ -30,12 +30,16 @@ func GetSettingByName(name string) string { if optionValue, ok := cache.Get(cacheKey); ok { return optionValue.(string) } + // 尝试数据库中查找 - result := DB.Where("name = ?", name).First(&setting) - if result.Error == nil { - _ = cache.Set(cacheKey, setting.Value, -1) - return setting.Value + if DB != nil { + result := DB.Where("name = ?", name).First(&setting) + if result.Error == nil { + _ = cache.Set(cacheKey, setting.Value, -1) + return setting.Value + } } + return "" } diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 40ce36a..d7f9abe 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -1,169 +1,65 @@ package aria2 import ( - "encoding/json" + "context" + "fmt" "net/url" "sync" + "time" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v3/pkg/balancer" ) // Instance 默认使用的Aria2处理实例 -var Instance Aria2 = &DummyAria2{} +var Instance common.Aria2 = &common.DummyAria2{} + +// LB 获取 Aria2 节点的负载均衡器 +var LB balancer.Balancer // Lock Instance的读写锁 var Lock sync.RWMutex -// EventNotifier 任务状态更新通知处理器 -var EventNotifier = &Notifier{} - -// Aria2 离线下载处理接口 -type Aria2 interface { - // CreateTask 创建新的任务 - CreateTask(task *model.Download, options map[string]interface{}) error - // 返回状态信息 - Status(task *model.Download) (rpc.StatusInfo, error) - // 取消任务 - Cancel(task *model.Download) error - // 选择要下载的文件 - Select(task *model.Download, files []int) error -} - -const ( - // URLTask 从URL添加的任务 - URLTask = iota - // TorrentTask 种子任务 - TorrentTask -) - -const ( - // Ready 准备就绪 - Ready = iota - // Downloading 下载中 - Downloading - // Paused 暂停中 - Paused - // Error 出错 - Error - // Complete 完成 - Complete - // Canceled 取消/停止 - Canceled - // Unknown 未知状态 - Unknown -) - -var ( - // ErrNotEnabled 功能未开启错误 - ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil) - // ErrUserNotFound 未找到下载任务创建者 - ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil) -) - -// DummyAria2 未开启Aria2功能时使用的默认处理器 -type DummyAria2 struct { -} - -// CreateTask 创建新任务,此处直接返回未开启错误 -func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error { - return ErrNotEnabled -} - -// Status 返回未开启错误 -func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) { - return rpc.StatusInfo{}, ErrNotEnabled -} - -// Cancel 返回未开启错误 -func (instance *DummyAria2) Cancel(task *model.Download) error { - return ErrNotEnabled -} - -// Select 返回未开启错误 -func (instance *DummyAria2) Select(task *model.Download, files []int) error { - return ErrNotEnabled +// GetLoadBalancer 返回供Aria2使用的负载均衡器 +func GetLoadBalancer() balancer.Balancer { + Lock.RLock() + defer Lock.RUnlock() + return LB } // Init 初始化 func Init(isReload bool) { Lock.Lock() - defer Lock.Unlock() - - // 关闭上个初始连接 - if previousClient, ok := Instance.(*RPCService); ok { - if previousClient.Caller != nil { - util.Log().Debug("关闭上个 aria2 连接") - previousClient.Caller.Close() - } - } - - options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") - timeout := model.GetIntSetting("aria2_call_timeout", 5) - if options["aria2_rpcurl"] == "" { - Instance = &DummyAria2{} - return - } - - util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"]) - client := &RPCService{} - - // 解析RPC服务地址 - server, err := url.Parse(options["aria2_rpcurl"]) - if err != nil { - util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err) - Instance = &DummyAria2{} - return - } - server.Path = "/jsonrpc" - - // 加载自定义下载配置 - var globalOptions map[string]interface{} - err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions) - if err != nil { - util.Log().Warning("无法解析 aria2 全局配置,%s", err) - Instance = &DummyAria2{} - return - } - - if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil { - util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err) - Instance = &DummyAria2{} - return - } - - Instance = client + LB = balancer.NewBalancer("RoundRobin") + Lock.Unlock() if !isReload { // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading) for i := 0; i < len(unfinished); i++ { // 创建任务监控 - NewMonitor(&unfinished[i]) + monitor.NewMonitor(&unfinished[i]) } } - } -// getStatus 将给定的状态字符串转换为状态标识数字 -func getStatus(status string) int { - switch status { - case "complete": - return Complete - case "active": - return Downloading - case "waiting": - return Ready - case "paused": - return Paused - case "error": - return Error - case "removed": - return Canceled - default: - return Unknown +// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性 +func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) { + // 解析RPC服务地址 + rpcServer, err := url.Parse(server) + if err != nil { + return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err) + } + + rpcServer.Path = "/jsonrpc" + caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil) + if err != nil { + return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err) } + + return caller.GetVersion() } diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go index 51605a1..dfd71a3 100644 --- a/pkg/aria2/aria2_test.go +++ b/pkg/aria2/aria2_test.go @@ -6,6 +6,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" @@ -37,7 +38,7 @@ func TestDummyAria2(t *testing.T) { } func TestInit(t *testing.T) { - MAX_RETRY = 0 + monitor.MAX_RETRY = 0 asserts := assert.New(t) cache.Set("setting_aria2_token", "1", 0) cache.Set("setting_aria2_call_timeout", "5", 0) @@ -81,11 +82,11 @@ func TestInit(t *testing.T) { func TestGetStatus(t *testing.T) { asserts := assert.New(t) - asserts.Equal(4, getStatus("complete")) - asserts.Equal(1, getStatus("active")) - asserts.Equal(0, getStatus("waiting")) - asserts.Equal(2, getStatus("paused")) - asserts.Equal(3, getStatus("error")) - asserts.Equal(5, getStatus("removed")) - asserts.Equal(6, getStatus("?")) + asserts.Equal(4, GetStatus("complete")) + asserts.Equal(1, GetStatus("active")) + asserts.Equal(0, GetStatus("waiting")) + asserts.Equal(2, GetStatus("paused")) + asserts.Equal(3, GetStatus("error")) + asserts.Equal(5, GetStatus("removed")) + asserts.Equal(6, GetStatus("?")) } diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 6e287a2..70e0bea 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -9,6 +9,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -33,7 +34,7 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s Options: options, } caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, - EventNotifier) + mq.GlobalMQ) client.Caller = caller return err } @@ -85,7 +86,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error { } // CreateTask 创建新任务 -func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error { +func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) { // 生成存储路径 path := filepath.Join( model.GetSettingByName("aria2_temp_path"), @@ -106,18 +107,8 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri gid, err := client.Caller.AddURI(task.Source, options) if err != nil || gid == "" { - return err + return "", err } - // 保存到数据库 - task.GID = gid - _, err = task.Create() - if err != nil { - return err - } - - // 创建任务监控 - NewMonitor(task) - - return nil + return gid, nil } diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go new file mode 100644 index 0000000..8f281d8 --- /dev/null +++ b/pkg/aria2/common/common.go @@ -0,0 +1,114 @@ +package common + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) + +// Aria2 离线下载处理接口 +type Aria2 interface { + // Init 初始化客户端连接 + Init() error + // CreateTask 创建新的任务 + CreateTask(task *model.Download, options map[string]interface{}) (string, error) + // 返回状态信息 + Status(task *model.Download) (rpc.StatusInfo, error) + // 取消任务 + Cancel(task *model.Download) error + // 选择要下载的文件 + Select(task *model.Download, files []int) error + // 获取离线下载配置 + GetConfig() model.Aria2Option + // 删除临时下载文件 + DeleteTempFile(*model.Download) error +} + +const ( + // URLTask 从URL添加的任务 + URLTask = iota + // TorrentTask 种子任务 + TorrentTask +) + +const ( + // Ready 准备就绪 + Ready = iota + // Downloading 下载中 + Downloading + // Paused 暂停中 + Paused + // Error 出错 + Error + // Complete 完成 + Complete + // Canceled 取消/停止 + Canceled + // Unknown 未知状态 + Unknown +) + +var ( + // ErrNotEnabled 功能未开启错误 + ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil) + // ErrUserNotFound 未找到下载任务创建者 + ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil) +) + +// DummyAria2 未开启Aria2功能时使用的默认处理器 +type DummyAria2 struct { +} + +func (instance *DummyAria2) Init() error { + return nil +} + +// CreateTask 创建新任务,此处直接返回未开启错误 +func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) { + return "", ErrNotEnabled +} + +// Status 返回未开启错误 +func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) { + return rpc.StatusInfo{}, ErrNotEnabled +} + +// Cancel 返回未开启错误 +func (instance *DummyAria2) Cancel(task *model.Download) error { + return ErrNotEnabled +} + +// Select 返回未开启错误 +func (instance *DummyAria2) Select(task *model.Download, files []int) error { + return ErrNotEnabled +} + +// GetConfig 返回空的 +func (instance *DummyAria2) GetConfig() model.Aria2Option { + return model.Aria2Option{} +} + +// GetConfig 返回空的 +func (instance *DummyAria2) DeleteTempFile(src *model.Download) error { + return ErrNotEnabled +} + +// GetStatus 将给定的状态字符串转换为状态标识数字 +func GetStatus(status string) int { + switch status { + case "complete": + return Complete + case "active": + return Downloading + case "waiting": + return Ready + case "paused": + return Paused + case "error": + return Error + case "removed": + return Canceled + default: + return Unknown + } +} diff --git a/pkg/aria2/monitor.go b/pkg/aria2/monitor/monitor.go similarity index 80% rename from pkg/aria2/monitor.go rename to pkg/aria2/monitor/monitor.go index 1667f47..7a04411 100644 --- a/pkg/aria2/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -1,19 +1,21 @@ -package aria2 +package monitor import ( "context" "encoding/json" "errors" - "os" "path/filepath" "strconv" "time" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -23,32 +25,34 @@ type Monitor struct { Task *model.Download Interval time.Duration - notifier chan StatusEvent + notifier <-chan mq.Message + node cluster.Node retried int } -// StatusEvent 状态改变事件 -type StatusEvent struct { - GID string - Status int -} - var MAX_RETRY = 10 -// NewMonitor 新建上传状态监控 +// NewMonitor 新建离线下载状态监控 func NewMonitor(task *model.Download) { monitor := &Monitor{ Task: task, - Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second, - notifier: make(chan StatusEvent), + notifier: make(chan mq.Message), + node: cluster.Default.GetNodeByID(task.GetNodeID()), + } + + if monitor.node != nil { + monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second + go monitor.Loop() + + monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0) + } else { + monitor.setErrorStatus(errors.New("节点不可用")) } - go monitor.Loop() - EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID) } // Loop 开启监控循环 func (monitor *Monitor) Loop() { - defer EventNotifier.Unsubscribe(monitor.Task.GID) + defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier) // 首次循环立即更新 interval := time.Duration(0) @@ -70,9 +74,7 @@ func (monitor *Monitor) Loop() { // Update 更新状态,返回值表示是否退出监控 func (monitor *Monitor) Update() bool { - Lock.RLock() - status, err := Instance.Status(monitor.Task) - Lock.RUnlock() + status, err := monitor.node.GetAria2Instance().Status(monitor.Task) if err != nil { monitor.retried++ @@ -102,6 +104,7 @@ func (monitor *Monitor) Update() bool { if err := monitor.UpdateTaskInfo(status); err != nil { util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err) monitor.setErrorStatus(err) + monitor.RemoveTempFolder() return true } @@ -115,7 +118,7 @@ func (monitor *Monitor) Update() bool { case "active", "waiting", "paused": return false case "removed": - monitor.Task.Status = Canceled + monitor.Task.Status = common.Canceled monitor.Task.Save() monitor.RemoveTempFolder() return true @@ -130,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { originSize := monitor.Task.TotalSize monitor.Task.GID = status.Gid - monitor.Task.Status = getStatus(status.Status) + monitor.Task.Status = common.GetStatus(status.Status) // 文件大小、已下载大小 total, err := strconv.ParseUint(status.TotalLength, 10, 64) @@ -164,9 +167,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { // 文件大小更新后,对文件限制等进行校验 if err := monitor.ValidateFile(); err != nil { // 验证失败时取消任务 - Lock.RLock() - Instance.Cancel(monitor.Task) - Lock.RUnlock() + monitor.node.GetAria2Instance().Cancel(monitor.Task) return err } } @@ -179,7 +180,7 @@ func (monitor *Monitor) ValidateFile() error { // 找到任务创建者 user := monitor.Task.GetOwner() if user == nil { - return ErrUserNotFound + return common.ErrUserNotFound } // 创建文件系统 @@ -230,28 +231,31 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool { // RemoveTempFolder 清理下载临时目录 func (monitor *Monitor) RemoveTempFolder() { - err := os.RemoveAll(monitor.Task.Parent) - if err != nil { - util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err) - } - + monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task) } // Complete 完成下载,返回是否中断监控 func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { // 创建中转任务 file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) + sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ { - if monitor.Task.StatusInfo.Files[i].Selected == "true" { - file = append(file, monitor.Task.StatusInfo.Files[i].Path) + fileInfo := monitor.Task.StatusInfo.Files[i] + if fileInfo.Selected == "true" { + file = append(file, fileInfo.Path) + size, _ := strconv.ParseUint(fileInfo.Length, 10, 64) + sizes[fileInfo.Path] = size } } + job, err := task.NewTransferTask( monitor.Task.UserID, file, monitor.Task.Dst, monitor.Task.Parent, true, + monitor.node.ID(), + sizes, ) if err != nil { monitor.setErrorStatus(err) @@ -269,7 +273,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { } func (monitor *Monitor) setErrorStatus(err error) { - monitor.Task.Status = Error + monitor.Task.Status = common.Error monitor.Task.Error = err.Error() monitor.Task.Save() } diff --git a/pkg/aria2/monitor_test.go b/pkg/aria2/monitor/monitor_test.go similarity index 57% rename from pkg/aria2/monitor_test.go rename to pkg/aria2/monitor/monitor_test.go index 9172894..9d45026 100644 --- a/pkg/aria2/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -1,4 +1,4 @@ -package aria2 +package monitor import ( "errors" @@ -7,6 +7,8 @@ import ( "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" @@ -44,13 +46,13 @@ func (m InstanceMock) Select(task *model.Download, files []int) error { func TestNewMonitor(t *testing.T) { asserts := assert.New(t) NewMonitor(&model.Download{GID: "gid"}) - _, ok := EventNotifier.Subscribes.Load("gid") + _, ok := common.EventNotifier.Subscribes.Load("gid") asserts.True(ok) } func TestMonitor_Loop(t *testing.T) { asserts := assert.New(t) - notifier := make(chan StatusEvent) + notifier := make(chan common.StatusEvent) MAX_RETRY = 0 monitor := &Monitor{ Task: &model.Download{GID: "gid"}, @@ -76,10 +78,10 @@ func TestMonitor_Update(t *testing.T) { { MAX_RETRY = 1 testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) file, _ := util.CreatNestedFile("TestMonitor_Update/1") file.Close() - Instance = testInstance + aria2.Instance = testInstance asserts.False(monitor.Update()) asserts.True(monitor.Update()) testInstance.AssertExpectations(t) @@ -89,16 +91,16 @@ func TestMonitor_Update(t *testing.T) { // 磁力链下载重定向 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{ FollowedBy: []string{"1"}, }, nil) monitor.Task.ID = 1 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.False(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) asserts.EqualValues("1", monitor.Task.GID) } @@ -106,82 +108,82 @@ func TestMonitor_Update(t *testing.T) { // 无法更新任务信息 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil) monitor.Task.ID = 1 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + aria2.mock.ExpectRollback() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } // 返回未知状态 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } // 返回被取消状态 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } // 返回活跃状态 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.False(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } // 返回错误状态 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } // 返回完成 { testInstance := new(InstanceMock) - testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } } @@ -198,34 +200,34 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) { // 失败 { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + aria2.mock.ExpectRollback() err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) asserts.Error(err) } // 更新成功,无需校验 { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) asserts.NoError(err) } // 更新成功,大小改变,需要校验,校验失败 { testInstance := new(InstanceMock) - testInstance.On("Cancel", testMock.Anything).Return(nil) - Instance = testInstance - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + testInstance.On("SlaveCancel", testMock.Anything).Return(nil) + aria2.Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"}) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) asserts.Error(err) testInstance.AssertExpectations(t) } @@ -308,17 +310,17 @@ func TestMonitor_Complete(t *testing.T) { } cache.Set("setting_max_worker_num", "1", 0) - mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) + aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) task.Init() - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() asserts.True(monitor.Complete(rpc.StatusInfo{})) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) } diff --git a/pkg/aria2/notification.go b/pkg/aria2/notification.go deleted file mode 100644 index e2ead91..0000000 --- a/pkg/aria2/notification.go +++ /dev/null @@ -1,64 +0,0 @@ -package aria2 - -import ( - "sync" - - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" -) - -// Notifier aria2实践通知处理 -type Notifier struct { - Subscribes sync.Map -} - -// Subscribe 订阅事件通知 -func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) { - notifier.Subscribes.Store(gid, target) -} - -// Unsubscribe 取消订阅事件通知 -func (notifier *Notifier) Unsubscribe(gid string) { - notifier.Subscribes.Delete(gid) -} - -// Notify 发送通知 -func (notifier *Notifier) Notify(events []rpc.Event, status int) { - for _, event := range events { - if target, ok := notifier.Subscribes.Load(event.Gid); ok { - target.(chan StatusEvent) <- StatusEvent{ - GID: event.Gid, - Status: status, - } - } - } -} - -// OnDownloadStart 下载开始 -func (notifier *Notifier) OnDownloadStart(events []rpc.Event) { - notifier.Notify(events, Downloading) -} - -// OnDownloadPause 下载暂停 -func (notifier *Notifier) OnDownloadPause(events []rpc.Event) { - notifier.Notify(events, Paused) -} - -// OnDownloadStop 下载停止 -func (notifier *Notifier) OnDownloadStop(events []rpc.Event) { - notifier.Notify(events, Canceled) -} - -// OnDownloadComplete 下载完成 -func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) { - notifier.Notify(events, Complete) -} - -// OnDownloadError 下载出错 -func (notifier *Notifier) OnDownloadError(events []rpc.Event) { - notifier.Notify(events, Error) -} - -// OnBtDownloadComplete BT下载完成 -func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) { - notifier.Notify(events, Complete) -} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 86d02c6..d8250e8 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -2,9 +2,11 @@ package auth import ( "bytes" + "fmt" "io/ioutil" "net/http" "net/url" + "sort" "strings" "time" @@ -30,9 +32,8 @@ type Auth interface { Check(body string, sign string) error } -// SignRequest 对PUT\POST等复杂HTTP请求签名,如果请求Header中 -// 包含 X-Policy, 则此请求会被认定为上传请求,只会对URI部分和 -// Policy部分进行签名。其他请求则会对URI和Body部分进行签名。 +// SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、 +// 请求正文、`X-`开头的header进行签名 func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request { // 处理有效期 if expires > 0 { @@ -61,20 +62,31 @@ func CheckRequest(instance Auth, r *http.Request) error { return instance.Check(getSignContent(r), sign[0]) } -// getSignContent 根据请求Header中是否包含X-Policy判断是否为上传请求, -// 返回待签名/验证的字符串 +// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果 Header 中包含 `X-Policy`, +// 则不对正文签名。返回待签名/验证的字符串 func getSignContent(r *http.Request) (rawSignString string) { - if policy, ok := r.Header["X-Policy"]; ok { - rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "") - } else { - var body = []byte{} + // 读取所有body正文 + var body = []byte{} + if _, ok := r.Header["X-Policy"]; !ok { if r.Body != nil { body, _ = ioutil.ReadAll(r.Body) _ = r.Body.Close() r.Body = ioutil.NopCloser(bytes.NewReader(body)) } - rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body)) } + + // 决定要签名的header + var signedHeader []string + for k, _ := range r.Header { + if strings.HasPrefix(k, "X-") && k != "X-Filename" { + signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k))) + } + } + sort.Strings(signedHeader) + + // 读取所有待签名Header + rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body)) + return rawSignString } diff --git a/pkg/balancer/balancer.go b/pkg/balancer/balancer.go new file mode 100644 index 0000000..5d5c028 --- /dev/null +++ b/pkg/balancer/balancer.go @@ -0,0 +1,15 @@ +package balancer + +type Balancer interface { + NextPeer(nodes interface{}) (error, interface{}) +} + +// NewBalancer 根据策略标识返回新的负载均衡器 +func NewBalancer(strategy string) Balancer { + switch strategy { + case "RoundRobin": + return &RoundRobin{} + default: + return &RoundRobin{} + } +} diff --git a/pkg/balancer/errors.go b/pkg/balancer/errors.go new file mode 100644 index 0000000..aef7b1f --- /dev/null +++ b/pkg/balancer/errors.go @@ -0,0 +1,8 @@ +package balancer + +import "errors" + +var ( + ErrInputNotSlice = errors.New("Input value is not silice") + ErrNoAvaliableNode = errors.New("No nodes avaliable") +) diff --git a/pkg/balancer/roundrobin.go b/pkg/balancer/roundrobin.go new file mode 100644 index 0000000..cf300f5 --- /dev/null +++ b/pkg/balancer/roundrobin.go @@ -0,0 +1,30 @@ +package balancer + +import ( + "reflect" + "sync/atomic" +) + +type RoundRobin struct { + current uint64 +} + +// NextPeer 返回轮盘的下一节点 +func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) { + v := reflect.ValueOf(nodes) + if v.Kind() != reflect.Slice { + return ErrInputNotSlice, nil + } + + if v.Len() == 0 { + return ErrNoAvaliableNode, nil + } + + next := r.NextIndex(v.Len()) + return nil, v.Index(next).Interface() +} + +// NextIndex 返回下一个节点下标 +func (r *RoundRobin) NextIndex(total int) int { + return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total)) +} diff --git a/pkg/cluster/errors.go b/pkg/cluster/errors.go new file mode 100644 index 0000000..9afdbef --- /dev/null +++ b/pkg/cluster/errors.go @@ -0,0 +1,8 @@ +package cluster + +import "errors" + +var ( + ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed") + ErrIlegalPath = errors.New("path out of boundary of setting temp folder") +) diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go new file mode 100644 index 0000000..e877920 --- /dev/null +++ b/pkg/cluster/master.go @@ -0,0 +1,265 @@ +package cluster + +import ( + "context" + "encoding/json" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/gofrs/uuid" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +const deleteTempFileDuration = 60 * time.Second + +type MasterNode struct { + Model *model.Node + aria2RPC rpcService + lock sync.RWMutex +} + +// RPCService 通过RPC服务的Aria2任务管理器 +type rpcService struct { + Caller rpc.Client + Initialized bool + + parent *MasterNode + options *clientOptions +} + +type clientOptions struct { + Options map[string]interface{} // 创建下载时额外添加的设置 +} + +// Init 初始化节点 +func (node *MasterNode) Init(nodeModel *model.Node) { + node.lock.Lock() + node.Model = nodeModel + node.aria2RPC.parent = node + node.lock.Unlock() + + node.lock.RLock() + if node.Model.Aria2Enabled { + node.lock.RUnlock() + node.aria2RPC.Init() + return + } + node.lock.RUnlock() +} + +func (node *MasterNode) ID() uint { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Model.ID +} + +func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { + return &serializer.NodePingResp{}, nil +} + +// IsFeatureEnabled 查询节点的某项功能是否启用 +func (node *MasterNode) IsFeatureEnabled(feature string) bool { + node.lock.RLock() + defer node.lock.RUnlock() + + switch feature { + case "aria2": + return node.Model.Aria2Enabled + default: + return false + } +} + +func (node *MasterNode) MasterAuthInstance() auth.Auth { + node.lock.RLock() + defer node.lock.RUnlock() + + return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} +} + +func (node *MasterNode) SlaveAuthInstance() auth.Auth { + node.lock.RLock() + defer node.lock.RUnlock() + + return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)} +} + +// SubscribeStatusChange 订阅节点状态更改 +func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) { +} + +// IsActive 返回节点是否在线 +func (node *MasterNode) IsActive() bool { + return true +} + +// Kill 结束aria2请求 +func (node *MasterNode) Kill() { + if node.aria2RPC.Caller != nil { + node.aria2RPC.Caller.Close() + } +} + +// GetAria2Instance 获取主机Aria2实例 +func (node *MasterNode) GetAria2Instance() common.Aria2 { + node.lock.RLock() + + if !node.Model.Aria2Enabled { + node.lock.RUnlock() + return &common.DummyAria2{} + } + + if !node.aria2RPC.Initialized { + node.lock.RUnlock() + node.aria2RPC.Init() + return &common.DummyAria2{} + } + + defer node.lock.RUnlock() + return &node.aria2RPC +} + +func (node *MasterNode) IsMater() bool { + return true +} + +func (node *MasterNode) DBModel() *model.Node { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Model +} + +func (r *rpcService) Init() error { + r.parent.lock.Lock() + defer r.parent.lock.Unlock() + r.Initialized = false + + // 客户端已存在,则关闭先前连接 + if r.Caller != nil { + r.Caller.Close() + } + + // 解析RPC服务地址 + server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server) + if err != nil { + util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err) + return err + } + server.Path = "/jsonrpc" + + // 加载自定义下载配置 + var globalOptions map[string]interface{} + if r.parent.Model.Aria2OptionsSerialized.Options != "" { + err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions) + if err != nil { + util.Log().Warning("无法解析主机 Aria2 配置,%s", err) + return err + } + } + + r.options = &clientOptions{ + Options: globalOptions, + } + timeout := r.parent.Model.Aria2OptionsSerialized.Timeout + caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ) + + r.Caller = caller + r.Initialized = err == nil + return err +} + +func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) { + r.parent.lock.RLock() + // 生成存储路径 + guid, _ := uuid.NewV4() + path := filepath.Join( + r.parent.Model.Aria2OptionsSerialized.TempPath, + "aria2", + guid.String(), + ) + r.parent.lock.RUnlock() + + // 创建下载任务 + options := map[string]interface{}{ + "dir": path, + } + for k, v := range r.options.Options { + options[k] = v + } + for k, v := range groupOptions { + options[k] = v + } + + gid, err := r.Caller.AddURI(task.Source, options) + if err != nil || gid == "" { + return "", err + } + + return gid, nil +} + +func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { + res, err := r.Caller.TellStatus(task.GID) + if err != nil { + // 失败后重试 + util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err) + time.Sleep(time.Duration(10) * time.Second) + res, err = r.Caller.TellStatus(task.GID) + } + + return res, err +} + +func (r *rpcService) Cancel(task *model.Download) error { + // 取消下载任务 + _, err := r.Caller.Remove(task.GID) + if err != nil { + util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err) + } + + return err +} + +func (r *rpcService) Select(task *model.Download, files []int) error { + var selected = make([]string, len(files)) + for i := 0; i < len(files); i++ { + selected[i] = strconv.Itoa(files[i]) + } + _, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) + return err +} + +func (r *rpcService) GetConfig() model.Aria2Option { + r.parent.lock.RLock() + defer r.parent.lock.RUnlock() + + return r.parent.Model.Aria2OptionsSerialized +} + +func (s *rpcService) DeleteTempFile(task *model.Download) error { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + // 避免被aria2占用,异步执行删除 + go func(src string) { + time.Sleep(deleteTempFileDuration) + err := os.RemoveAll(src) + if err != nil { + util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err) + } + }(task.Parent) + + return nil +} diff --git a/pkg/cluster/node.go b/pkg/cluster/node.go new file mode 100644 index 0000000..745dd25 --- /dev/null +++ b/pkg/cluster/node.go @@ -0,0 +1,60 @@ +package cluster + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) + +type Node interface { + // Init a node from database model + Init(node *model.Node) + + // Check if given feature is enabled + IsFeatureEnabled(feature string) bool + + // Subscribe node status change to a callback function + SubscribeStatusChange(callback func(isActive bool, id uint)) + + // Ping the node + Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) + + // Returns if the node is active + IsActive() bool + + // Get instances for aria2 calls + GetAria2Instance() common.Aria2 + + // Returns unique id of this node + ID() uint + + // Kill node and recycle resources + Kill() + + // Returns if current node is master node + IsMater() bool + + // Get auth instance used to check RPC call from slave to master + MasterAuthInstance() auth.Auth + + // Get auth instance used to check RPC call from master to slave + SlaveAuthInstance() auth.Auth + + // Get node DB model + DBModel() *model.Node +} + +// Create new node from DB model +func NewNodeFromDBModel(node *model.Node) Node { + switch node.Type { + case model.SlaveNodeType: + slave := &SlaveNode{} + slave.Init(node) + return slave + default: + master := &MasterNode{} + master.Init(node) + return master + } +} diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go new file mode 100644 index 0000000..4526f4a --- /dev/null +++ b/pkg/cluster/pool.go @@ -0,0 +1,176 @@ +package cluster + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/balancer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "sync" +) + +var Default *NodePool + +// 需要分类的节点组 +var featureGroup = []string{"aria2"} + +// Pool 节点池 +type Pool interface { + // Returns active node selected by given feature and load balancer + BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) + + // Returns node by ID + GetNodeByID(id uint) Node + + // Add given node into pool. If node existed, refresh node. + Add(node *model.Node) + + // Delete and kill node from pool by given node id + Delete(id uint) +} + +// NodePool 通用节点池 +type NodePool struct { + active map[uint]Node + inactive map[uint]Node + + featureMap map[string][]Node + + lock sync.RWMutex +} + +// Init 初始化从机节点池 +func Init() { + Default = &NodePool{ + featureMap: make(map[string][]Node), + } + if err := Default.initFromDB(); err != nil { + util.Log().Warning("节点池初始化失败, %s", err) + } +} + +func (pool *NodePool) buildIndexMap() { + pool.lock.Lock() + for _, feature := range featureGroup { + pool.featureMap[feature] = make([]Node, 0) + } + + for _, v := range pool.active { + for _, feature := range featureGroup { + if v.IsFeatureEnabled(feature) { + pool.featureMap[feature] = append(pool.featureMap[feature], v) + } + } + } + pool.lock.Unlock() +} + +func (pool *NodePool) GetNodeByID(id uint) Node { + pool.lock.RLock() + defer pool.lock.RUnlock() + + if node, ok := pool.active[id]; ok { + return node + } + + return pool.inactive[id] +} + +func (pool *NodePool) nodeStatusChange(isActive bool, id uint) { + util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive) + pool.lock.Lock() + if isActive { + node := pool.inactive[id] + delete(pool.inactive, id) + pool.active[id] = node + } else { + node := pool.active[id] + delete(pool.active, id) + pool.inactive[id] = node + } + pool.lock.Unlock() + + pool.buildIndexMap() +} + +func (pool *NodePool) initFromDB() error { + nodes, err := model.GetNodesByStatus(model.NodeActive) + if err != nil { + return err + } + + pool.lock.Lock() + pool.active = make(map[uint]Node) + pool.inactive = make(map[uint]Node) + for i := 0; i < len(nodes); i++ { + pool.add(&nodes[i]) + } + pool.lock.Unlock() + + pool.buildIndexMap() + return nil +} + +func (pool *NodePool) add(node *model.Node) { + newNode := NewNodeFromDBModel(node) + if newNode.IsActive() { + pool.active[node.ID] = newNode + } else { + pool.inactive[node.ID] = newNode + } + + // 订阅节点状态变更 + newNode.SubscribeStatusChange(func(isActive bool, id uint) { + pool.nodeStatusChange(isActive, id) + }) +} + +func (pool *NodePool) Add(node *model.Node) { + pool.lock.Lock() + defer pool.buildIndexMap() + defer pool.lock.Unlock() + + if _, ok := pool.active[node.ID]; ok { + // TODO: refresh node + return + } + + if _, ok := pool.inactive[node.ID]; ok { + return + } + + pool.add(node) +} + +func (pool *NodePool) Delete(id uint) { + pool.lock.Lock() + defer pool.buildIndexMap() + defer pool.lock.Unlock() + + if node, ok := pool.active[id]; ok { + node.Kill() + delete(pool.active, id) + return + } + + if node, ok := pool.inactive[id]; ok { + node.Kill() + delete(pool.inactive, id) + return + } + +} + +// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点 +func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) { + pool.lock.RLock() + defer pool.lock.RUnlock() + if nodes, ok := pool.featureMap[feature]; ok { + err, res := lb.NextPeer(nodes) + if err == nil { + return nil, res.(Node) + } + + return err, nil + } + + return ErrFeatureNotExist, nil +} diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go new file mode 100644 index 0000000..a76f238 --- /dev/null +++ b/pkg/cluster/slave.go @@ -0,0 +1,405 @@ +package cluster + +import ( + "encoding/json" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "io" + "net/url" + "strings" + "sync" + "time" +) + +type SlaveNode struct { + Model *model.Node + Active bool + + caller slaveCaller + callback func(bool, uint) + close chan bool + lock sync.RWMutex +} + +type slaveCaller struct { + parent *SlaveNode + Client request.Client +} + +// Init 初始化节点 +func (node *SlaveNode) Init(nodeModel *model.Node) { + node.lock.Lock() + defer node.lock.Unlock() + node.Model = nodeModel + + // Init http request client + var endpoint *url.URL + if serverURL, err := url.Parse(node.Model.Server); err == nil { + var controller *url.URL + controller, _ = url.Parse("/api/v3/slave") + endpoint = serverURL.ResolveReference(controller) + } + + signTTL := model.GetIntSetting("slave_api_timeout", 60) + node.caller.Client = request.NewClient( + request.WithMasterMeta(), + request.WithTimeout(time.Duration(signTTL)*time.Second), + request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)), + request.WithEndpoint(endpoint.String()), + ) + + node.caller.parent = node + node.Active = true + if node.close != nil { + node.close <- true + } + + go node.StartPingLoop() +} + +// IsFeatureEnabled 查询节点的某项功能是否启用 +func (node *SlaveNode) IsFeatureEnabled(feature string) bool { + node.lock.RLock() + defer node.lock.RUnlock() + + switch feature { + case "aria2": + return node.Model.Aria2Enabled + default: + return false + } +} + +// SubscribeStatusChange 订阅节点状态更改 +func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) { + node.lock.Lock() + node.callback = callback + node.lock.Unlock() +} + +// Ping 从机节点,返回从机负载 +func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { + reqBodyEncoded, err := json.Marshal(req) + if err != nil { + return nil, err + } + + bodyReader := strings.NewReader(string(reqBodyEncoded)) + + resp, err := node.caller.Client.Request( + "POST", + "heartbeat", + bodyReader, + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return nil, err + } + + // 处理列取结果 + if resp.Code != 0 { + return nil, serializer.NewErrorFromResponse(resp) + } + + var res serializer.NodePingResp + + if resStr, ok := resp.Data.(string); ok { + err = json.Unmarshal([]byte(resStr), &res) + if err != nil { + return nil, err + } + } + + return &res, nil +} + +// IsActive 返回节点是否在线 +func (node *SlaveNode) IsActive() bool { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Active +} + +// Kill 结束节点内相关循环 +func (node *SlaveNode) Kill() { + node.lock.RLock() + defer node.lock.RUnlock() + + if node.close != nil { + close(node.close) + } +} + +// GetAria2Instance 获取从机Aria2实例 +func (node *SlaveNode) GetAria2Instance() common.Aria2 { + node.lock.RLock() + defer node.lock.RUnlock() + + if !node.Model.Aria2Enabled { + return &common.DummyAria2{} + } + + return &node.caller +} + +func (node *SlaveNode) ID() uint { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Model.ID +} + +func (node *SlaveNode) StartPingLoop() { + node.lock.Lock() + node.close = make(chan bool) + node.lock.Unlock() + + tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second + recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second + pingTicker := time.Duration(0) + + util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name) + retry := 0 + recoverMode := false + isFirstLoop := true + +loop: + for { + select { + case <-time.After(pingTicker): + if pingTicker == 0 { + pingTicker = tickDuration + } + + util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name) + res, err := node.Ping(node.getHeartbeatContent(isFirstLoop)) + isFirstLoop = false + + if err != nil { + util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err) + retry++ + if retry >= model.GetIntSetting("slave_node_retry", 3) { + util.Log().Debug("从机节点 [%s] Ping 重试已达到最大限制,将从机节点标记为不可用", node.Model.Name) + node.changeStatus(false) + + if !recoverMode { + // 启动恢复监控循环 + util.Log().Debug("从机节点 [%s] 进入恢复模式", node.Model.Name) + pingTicker = recoverDuration + recoverMode = true + } + } + } else { + if recoverMode { + util.Log().Debug("从机节点 [%s] 复活", node.Model.Name) + pingTicker = tickDuration + recoverMode = false + isFirstLoop = true + } + + util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res) + node.changeStatus(true) + retry = 0 + } + + case <-node.close: + util.Log().Debug("从机节点 [%s] 收到关闭信号", node.Model.Name) + break loop + } + } +} + +func (node *SlaveNode) IsMater() bool { + return false +} + +func (node *SlaveNode) MasterAuthInstance() auth.Auth { + node.lock.RLock() + defer node.lock.RUnlock() + + return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} +} + +func (node *SlaveNode) SlaveAuthInstance() auth.Auth { + node.lock.RLock() + defer node.lock.RUnlock() + + return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)} +} + +func (node *SlaveNode) DBModel() *model.Node { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Model +} + +// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave +func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq { + return &serializer.NodePingReq{ + SiteURL: model.GetSiteURL().String(), + IsUpdate: isUpdate, + SiteID: model.GetSettingByName("siteID"), + Node: node.Model, + CredentialTTL: model.GetIntSetting("slave_api_timeout", 60), + } +} + +func (node *SlaveNode) changeStatus(isActive bool) { + node.lock.RLock() + id := node.Model.ID + if isActive != node.Active { + node.lock.RUnlock() + node.lock.Lock() + node.Active = isActive + node.lock.Unlock() + node.callback(isActive, id) + } else { + node.lock.RUnlock() + } + +} + +func (s *slaveCaller) Init() error { + return nil +} + +// SendAria2Call send remote aria2 call to slave node +func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) { + reqReader, err := getAria2RequestBody(body) + if err != nil { + return nil, err + } + + return s.Client.Request( + "POST", + "aria2/"+scope, + reqReader, + ).CheckHTTPResponse(200).DecodeResponse() +} + +func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + GroupOptions: options, + } + + res, err := s.SendAria2Call(req, "task") + if err != nil { + return "", err + } + + if res.Code != 0 { + return "", serializer.NewErrorFromResponse(res) + } + + return res.Data.(string), err +} + +func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + } + + res, err := s.SendAria2Call(req, "status") + if err != nil { + return rpc.StatusInfo{}, err + } + + if res.Code != 0 { + return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res) + } + + var status rpc.StatusInfo + res.GobDecode(&status) + + return status, err +} + +func (s *slaveCaller) Cancel(task *model.Download) error { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + } + + res, err := s.SendAria2Call(req, "cancel") + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + return nil +} + +func (s *slaveCaller) Select(task *model.Download, files []int) error { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + Files: files, + } + + res, err := s.SendAria2Call(req, "select") + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + return nil +} + +func (s *slaveCaller) GetConfig() model.Aria2Option { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + return s.parent.Model.Aria2OptionsSerialized +} + +func (s *slaveCaller) DeleteTempFile(task *model.Download) error { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + } + + res, err := s.SendAria2Call(req, "delete") + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + return nil +} + +func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { + reqBodyEncoded, err := json.Marshal(body) + if err != nil { + return nil, err + } + + return strings.NewReader(string(reqBodyEncoded)), nil +} diff --git a/pkg/filesystem/driver/handler.go b/pkg/filesystem/driver/handler.go new file mode 100644 index 0000000..758d386 --- /dev/null +++ b/pkg/filesystem/driver/handler.go @@ -0,0 +1,39 @@ +package driver + +import ( + "context" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "io" + "net/url" +) + +// Handler 存储策略适配器 +type Handler interface { + // 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭 + // 时,应取消上传并清理临时文件 + Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error + + // 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误 + Delete(ctx context.Context, files []string) ([]string, error) + + // 获取文件内容 + Get(ctx context.Context, path string) (response.RSCloser, error) + + // 获取缩略图,可直接在ContentResponse中返回文件数据流,也可指 + // 定为重定向 + Thumb(ctx context.Context, path string) (*response.ContentResponse, error) + + // 获取外链/下载地址, + // url - 站点本身地址, + // isDownload - 是否直接下载 + Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) + + // Token 获取有效期为ttl的上传凭证和签名,同时回调会话有效期为sessionTTL + Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) + + // List 递归列取远程端path路径下文件、目录,不包含path本身, + // 返回的对象路径以path作为起始根目录. + // recursive - 是否递归列出 + List(ctx context.Context, path string, recursive bool) ([]response.Object, error) +} diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go index 101f9c3..dbbca3c 100644 --- a/pkg/filesystem/driver/onedrive/client.go +++ b/pkg/filesystem/driver/onedrive/client.go @@ -55,7 +55,7 @@ func NewClient(policy *model.Policy) (*Client, error) { ClientID: policy.BucketName, ClientSecret: policy.SecretKey, Redirect: policy.OptionsSerialized.OdRedirect, - Request: request.HTTPClient{}, + Request: request.NewClient(), } if client.Endpoints.DriverResource == "" { diff --git a/pkg/filesystem/driver/onedrive/handler.go b/pkg/filesystem/driver/onedrive/handler.go index 0820764..609fee7 100644 --- a/pkg/filesystem/driver/onedrive/handler.go +++ b/pkg/filesystem/driver/onedrive/handler.go @@ -14,6 +14,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/request" @@ -27,6 +28,16 @@ type Driver struct { HTTPClient request.Client } +// NewDriver 从存储策略初始化新的Driver实例 +func NewDriver(policy *model.Policy) (driver.Handler, error) { + client, err := NewClient(policy) + return Driver{ + Policy: policy, + Client: client, + HTTPClient: request.NewClient(), + }, err +} + // List 列取项目 func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { base = strings.TrimPrefix(base, "/") diff --git a/pkg/filesystem/driver/onedrive/lock.go b/pkg/filesystem/driver/onedrive/lock.go new file mode 100644 index 0000000..655936b --- /dev/null +++ b/pkg/filesystem/driver/onedrive/lock.go @@ -0,0 +1,25 @@ +package onedrive + +import "sync" + +// CredentialLock 针对存储策略凭证的锁 +type CredentialLock interface { + Lock(uint) + Unlock(uint) +} + +var GlobalMutex = mutexMap{} + +type mutexMap struct { + locks sync.Map +} + +func (m *mutexMap) Lock(id uint) { + lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) + lock.(*sync.Mutex).Lock() +} + +func (m *mutexMap) Unlock(id uint) { + lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) + lock.(*sync.Mutex).Unlock() +} diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index 9b33d7a..49170fe 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -10,7 +10,9 @@ import ( "time" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -124,6 +126,13 @@ func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credent // UpdateCredential 更新凭证,并检查有效期 func (client *Client) UpdateCredential(ctx context.Context) error { + if conf.SystemConfig.Mode == "slave" { + return client.fetchCredentialFromMaster(ctx) + } + + GlobalMutex.Lock(client.Policy.ID) + defer GlobalMutex.Unlock(client.Policy.ID) + // 如果已存在凭证 if client.Credential != nil && client.Credential.AccessToken != "" { // 检查已有凭证是否过期 @@ -160,11 +169,21 @@ func (client *Client) UpdateCredential(ctx context.Context) error { client.Credential = credential // 更新存储策略的 RefreshToken - client.Policy.AccessKey = credential.RefreshToken - client.Policy.SaveAndClearCache() + client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken) // 更新缓存 cache.Set("onedrive_"+client.ClientID, *credential, int(expires)) return nil } + +// UpdateCredential 更新凭证,并检查有效期 +func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { + res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) + if err != nil { + return err + } + + client.Credential = &Credential{AccessToken: res} + return nil +} diff --git a/pkg/filesystem/driver/oss/callback.go b/pkg/filesystem/driver/oss/callback.go index 7ca1e23..e5b41bb 100644 --- a/pkg/filesystem/driver/oss/callback.go +++ b/pkg/filesystem/driver/oss/callback.go @@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) { } // 获取公钥 - client := request.HTTPClient{} + client := request.NewClient() body, err := client.Request("GET", string(pubURL), nil). CheckHTTPResponse(200). GetResponse() diff --git a/pkg/filesystem/driver/oss/handler_test.go b/pkg/filesystem/driver/oss/handler_test.go index 5be01f2..58401f3 100644 --- a/pkg/filesystem/driver/oss/handler_test.go +++ b/pkg/filesystem/driver/oss/handler_test.go @@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) { BucketName: "test", Server: "oss-cn-shanghai.aliyuncs.com", }, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } cache.Set("setting_preview_timeout", "3600", 0) diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go index 5b9b965..3f77700 100644 --- a/pkg/filesystem/driver/remote/handler.go +++ b/pkg/filesystem/driver/remote/handler.go @@ -49,6 +49,7 @@ func (handler Driver) List(ctx context.Context, path string, recursive bool) ([] handler.getAPIUrl("list"), bodyReader, request.WithCredential(handler.AuthInstance, int64(signTTL)), + request.WithMasterMeta(), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { return res, err @@ -97,7 +98,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string { // Get 获取文件内容 func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 尝试获取速度限制 TODO 是否需要在这里限制? + // 尝试获取速度限制 speedLimit := 0 if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { speedLimit = user.Group.SpeedLimit @@ -116,6 +117,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, nil, request.WithContext(ctx), request.WithTimeout(time.Duration(0)), + request.WithMasterMeta(), ).CheckHTTPResponse(200).GetRSCloser() if err != nil { return nil, err @@ -168,13 +170,15 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s handler.Policy.GetUploadURL(), file, request.WithHeader(map[string][]string{ - "Authorization": {credential.Token}, - "X-Policy": {credential.Policy}, - "X-FileName": {fileName}, - "X-Overwrite": {overwrite}, + "X-Policy": {credential.Policy}, + "X-FileName": {fileName}, + "X-Overwrite": {overwrite}, }), request.WithContentLength(int64(size)), request.WithTimeout(time.Duration(0)), + request.WithMasterMeta(), + request.WithSlaveMeta(handler.Policy.AccessKey), + request.WithCredential(handler.AuthInstance, int64(credentialTTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { return err @@ -206,6 +210,8 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err handler.getAPIUrl("delete"), bodyReader, request.WithCredential(handler.AuthInstance, int64(signTTL)), + request.WithMasterMeta(), + request.WithSlaveMeta(handler.Policy.AccessKey), ).CheckHTTPResponse(200).GetResponse() if err != nil { return files, err diff --git a/pkg/filesystem/driver/s3/handler.go b/pkg/filesystem/driver/s3/handler.go index b338d8f..4502196 100644 --- a/pkg/filesystem/driver/s3/handler.go +++ b/pkg/filesystem/driver/s3/handler.go @@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // 获取文件数据流 - client := request.HTTPClient{} + client := request.NewClient() resp, err := client.Request( "GET", downloadURL, diff --git a/pkg/filesystem/driver/shadow/masterinslave/errors.go b/pkg/filesystem/driver/shadow/masterinslave/errors.go new file mode 100644 index 0000000..27d0428 --- /dev/null +++ b/pkg/filesystem/driver/shadow/masterinslave/errors.go @@ -0,0 +1,7 @@ +package masterinslave + +import "errors" + +var ( + ErrNotImplemented = errors.New("this method of shadowed policy is not implemented") +) diff --git a/pkg/filesystem/driver/shadow/masterinslave/handler.go b/pkg/filesystem/driver/shadow/masterinslave/handler.go new file mode 100644 index 0000000..485a9b2 --- /dev/null +++ b/pkg/filesystem/driver/shadow/masterinslave/handler.go @@ -0,0 +1,56 @@ +package masterinslave + +import ( + "context" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "io" + "net/url" +) + +// Driver 影子存储策略,用于在从机端上传文件 +type Driver struct { + master cluster.Node + handler driver.Handler + policy *model.Policy +} + +// NewDriver 返回新的处理器 +func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler { + return &Driver{ + master: master, + handler: handler, + policy: policy, + } +} + +func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { + return d.handler.Put(ctx, file, dst, size) +} + +func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { + return d.handler.Delete(ctx, files) +} + +func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { + return nil, ErrNotImplemented +} + +func (d *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { + return nil, ErrNotImplemented +} + +func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { + return "", ErrNotImplemented +} + +func (d *Driver) Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) { + return serializer.UploadCredential{}, ErrNotImplemented +} + +func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { + return nil, ErrNotImplemented +} diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/errors.go b/pkg/filesystem/driver/shadow/slaveinmaster/errors.go new file mode 100644 index 0000000..6acadc8 --- /dev/null +++ b/pkg/filesystem/driver/shadow/slaveinmaster/errors.go @@ -0,0 +1,9 @@ +package slaveinmaster + +import "errors" + +var ( + ErrNotImplemented = errors.New("this method of shadowed policy is not implemented") + ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node") + ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result") +) diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go new file mode 100644 index 0000000..9d13247 --- /dev/null +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -0,0 +1,121 @@ +package slaveinmaster + +import ( + "bytes" + "context" + "encoding/json" + "errors" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "io" + "net/url" + "time" +) + +// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果 +type Driver struct { + node cluster.Node + handler driver.Handler + policy *model.Policy + client request.Client +} + +// NewDriver 返回新的从机指派处理器 +func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler { + var endpoint *url.URL + if serverURL, err := url.Parse(node.DBModel().Server); err == nil { + var controller *url.URL + controller, _ = url.Parse("/api/v3/slave") + endpoint = serverURL.ResolveReference(controller) + } + + signTTL := model.GetIntSetting("slave_api_timeout", 60) + return &Driver{ + node: node, + handler: handler, + policy: policy, + client: request.NewClient( + request.WithMasterMeta(), + request.WithTimeout(time.Duration(signTTL)*time.Second), + request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)), + request.WithEndpoint(endpoint.String()), + ), + } +} + +// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略 +func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { + src, ok := ctx.Value(fsctx.SlaveSrcPath).(string) + if !ok { + return ErrSlaveSrcPathNotExist + } + + req := serializer.SlaveTransferReq{ + Src: src, + Dst: dst, + Policy: d.policy, + } + + body, err := json.Marshal(req) + if err != nil { + return err + } + + // 订阅转存结果 + resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0) + defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan) + + res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)). + CheckHTTPResponse(200). + DecodeResponse() + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + // 等待转存结果或者超时 + waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800) + select { + case <-time.After(time.Duration(waitTimeout) * time.Second): + return ErrWaitResultTimeout + case msg := <-resChan: + if msg.Event != serializer.SlaveTransferSuccess { + return errors.New(msg.Content.(serializer.SlaveTransferResult).Error) + } + } + + return nil +} + +func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { + return d.handler.Delete(ctx, files) +} + +func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { + return nil, ErrNotImplemented +} + +func (d *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { + return nil, ErrNotImplemented +} + +func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { + return "", ErrNotImplemented +} + +func (d *Driver) Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) { + return serializer.UploadCredential{}, ErrNotImplemented +} + +func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { + return nil, ErrNotImplemented +} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index af62983..7f176c0 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -1,8 +1,12 @@ package filesystem import ( - "context" "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" "io" "net/http" "net/url" @@ -19,7 +23,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" @@ -43,36 +46,6 @@ type FileHeader interface { GetVirtualPath() string } -// Handler 存储策略适配器 -type Handler interface { - // 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭 - // 时,应取消上传并清理临时文件 - Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error - - // 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误 - Delete(ctx context.Context, files []string) ([]string, error) - - // 获取文件内容 - Get(ctx context.Context, path string) (response.RSCloser, error) - - // 获取缩略图,可直接在ContentResponse中返回文件数据流,也可指 - // 定为重定向 - Thumb(ctx context.Context, path string) (*response.ContentResponse, error) - - // 获取外链/下载地址, - // url - 站点本身地址, - // isDownload - 是否直接下载 - Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) - - // Token 获取有效期为ttl的上传凭证和签名,同时回调会话有效期为sessionTTL - Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) - - // List 递归列取远程端path路径下文件、目录,不包含path本身, - // 返回的对象路径以path作为起始根目录. - // recursive - 是否递归列出 - List(ctx context.Context, path string, recursive bool) ([]response.Object, error) -} - // FileSystem 管理文件的文件系统 type FileSystem struct { // 文件系统所有者 @@ -96,7 +69,7 @@ type FileSystem struct { /* 文件系统处理适配器 */ - Handler Handler + Handler driver.Handler // 回收锁 recycleLock sync.Mutex @@ -134,7 +107,6 @@ func NewFileSystem(user *model.User) (*FileSystem, error) { // 分配存储策略适配器 err := fs.DispatchHandler() - // TODO 分配默认钩子 return fs, err } @@ -159,7 +131,6 @@ func NewAnonymousFileSystem() (*FileSystem, error) { } // DispatchHandler 根据存储策略分配文件适配器 -// TODO 完善测试 func (fs *FileSystem) DispatchHandler() error { var policyType string var currentPolicy *model.Policy @@ -184,7 +155,7 @@ func (fs *FileSystem) DispatchHandler() error { case "remote": fs.Handler = remote.Driver{ Policy: currentPolicy, - Client: request.HTTPClient{}, + Client: request.NewClient(), AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)}, } return nil @@ -196,7 +167,7 @@ func (fs *FileSystem) DispatchHandler() error { case "oss": fs.Handler = oss.Driver{ Policy: currentPolicy, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } return nil case "upyun": @@ -205,13 +176,9 @@ func (fs *FileSystem) DispatchHandler() error { } return nil case "onedrive": - client, err := onedrive.NewClient(currentPolicy) - fs.Handler = onedrive.Driver{ - Policy: currentPolicy, - Client: client, - HTTPClient: request.HTTPClient{}, - } - return err + var odErr error + fs.Handler, odErr = onedrive.NewDriver(currentPolicy) + return odErr case "cos": u, _ := url.Parse(currentPolicy.Server) b := &cossdk.BaseURL{BucketURL: u} @@ -223,7 +190,7 @@ func (fs *FileSystem) DispatchHandler() error { SecretKey: currentPolicy.SecretKey, }, }), - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } return nil case "s3": @@ -272,6 +239,30 @@ func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) { return fs, err } +// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点 +func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) { + fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, &fs.User.Policy) +} + +// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器 +func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) { + switch fs.Policy.Type { + case "remote": + fs.Policy.Type = "local" + fs.DispatchHandler() + case "local": + fs.Policy.Type = "remote" + fs.Policy.Server = masterURL + fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID()) + fs.Policy.SecretKey = master.DBModel().MasterKey + fs.DispatchHandler() + case "onedrive": + fs.Policy.MasterID = masterID + } + + fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy) +} + // SetTargetFile 设置当前处理的目标文件 func (fs *FileSystem) SetTargetFile(files *[]model.File) { if len(fs.FileTarget) == 0 { diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go index 28a2653..d280638 100644 --- a/pkg/filesystem/fsctx/context.go +++ b/pkg/filesystem/fsctx/context.go @@ -41,4 +41,6 @@ const ( ValidateCapacityOnceCtx // 禁止上传时同名覆盖操作 DisableOverwrite + // 文件在从机节点中的路径 + SlaveSrcPath ) diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 2ff8997..e2b092a 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -228,12 +228,14 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, src io.ReadCloser, d } // UploadFromPath 将本机已有文件上传到用户的文件系统 -func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string) error { +func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool) error { // 重设存储策略 - fs.Policy = &fs.User.Policy - err := fs.DispatchHandler() - if err != nil { - return err + if resetPolicy { + fs.Policy = &fs.User.Policy + err := fs.DispatchHandler() + if err != nil { + return err + } } file, err := os.Open(util.RelativePath(src)) diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go index 2c0d827..8473e4f 100644 --- a/pkg/filesystem/upload_test.go +++ b/pkg/filesystem/upload_test.go @@ -226,13 +226,13 @@ func TestFileSystem_UploadFromPath(t *testing.T) { // 文件不存在 { - err := fs.UploadFromPath(ctx, "test/not_exist", "/") + err := fs.UploadFromPath(ctx, "test/not_exist", "/", true) asserts.Error(err) } // 文存在,上传失败 { - err := fs.UploadFromPath(ctx, "tests/test.zip", "/") + err := fs.UploadFromPath(ctx, "tests/test.zip", "/", true) asserts.Error(err) } } diff --git a/pkg/mq/mq.go b/pkg/mq/mq.go new file mode 100644 index 0000000..e7a8a34 --- /dev/null +++ b/pkg/mq/mq.go @@ -0,0 +1,160 @@ +package mq + +import ( + "encoding/gob" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "strconv" + "sync" + "time" +) + +// Message 消息事件正文 +type Message struct { + // 消息触发者 + TriggeredBy string + + // 事件标识 + Event string + + // 消息正文 + Content interface{} +} + +type CallbackFunc func(Message) + +// MQ 消息队列 +type MQ interface { + rpc.Notifier + + // 发布一个消息 + Publish(string, Message) + + // 订阅一个消息主题 + Subscribe(string, int) <-chan Message + + // 订阅一个消息主题,注册触发回调函数 + SubscribeCallback(string, CallbackFunc) + + // 取消订阅一个消息主题 + Unsubscribe(string, <-chan Message) +} + +var GlobalMQ = NewMQ() + +func NewMQ() MQ { + return &inMemoryMQ{ + topics: make(map[string][]chan Message), + callbacks: make(map[string][]CallbackFunc), + } +} + +func init() { + gob.Register(Message{}) + gob.Register([]rpc.Event{}) +} + +type inMemoryMQ struct { + topics map[string][]chan Message + callbacks map[string][]CallbackFunc + sync.RWMutex +} + +func (i *inMemoryMQ) Publish(topic string, message Message) { + i.RLock() + subscribersChan, okChan := i.topics[topic] + subscribersCallback, okCallback := i.callbacks[topic] + i.RUnlock() + + if okChan { + go func(subscribersChan []chan Message) { + for i := 0; i < len(subscribersChan); i++ { + select { + case subscribersChan[i] <- message: + case <-time.After(time.Millisecond * 500): + } + } + }(subscribersChan) + + } + + if okCallback { + for i := 0; i < len(subscribersCallback); i++ { + go subscribersCallback[i](message) + } + } +} + +func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan Message { + ch := make(chan Message, buffer) + i.Lock() + i.topics[topic] = append(i.topics[topic], ch) + i.Unlock() + return ch +} + +func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) { + i.Lock() + i.callbacks[topic] = append(i.callbacks[topic], callbackFunc) + i.Unlock() +} + +func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) { + i.Lock() + defer i.Unlock() + + subscribers, ok := i.topics[topic] + if !ok { + return + } + + var newSubs []chan Message + for _, subscriber := range subscribers { + if subscriber == sub { + continue + } + newSubs = append(newSubs, subscriber) + } + + i.topics[topic] = newSubs +} + +func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) { + for _, event := range events { + i.Publish(event.Gid, Message{ + TriggeredBy: event.Gid, + Event: strconv.FormatInt(int64(status), 10), + Content: events, + }) + } +} + +// OnDownloadStart 下载开始 +func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) { + i.Aria2Notify(events, common.Downloading) +} + +// OnDownloadPause 下载暂停 +func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) { + i.Aria2Notify(events, common.Paused) +} + +// OnDownloadStop 下载停止 +func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) { + i.Aria2Notify(events, common.Canceled) +} + +// OnDownloadComplete 下载完成 +func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) { + i.Aria2Notify(events, common.Complete) +} + +// OnDownloadError 下载出错 +func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) { + i.Aria2Notify(events, common.Error) +} + +// OnBtDownloadComplete BT下载完成 +func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) { + i.Aria2Notify(events, common.Complete) +} diff --git a/pkg/mq/mq_test.go b/pkg/mq/mq_test.go new file mode 100644 index 0000000..9acdd3f --- /dev/null +++ b/pkg/mq/mq_test.go @@ -0,0 +1,149 @@ +package mq + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" +) + +func TestPublishAndSubscribe(t *testing.T) { + t.Parallel() + asserts := assert.New(t) + mq := NewMQ() + + // No subscriber + { + asserts.NotPanics(func() { + mq.Publish("No subscriber", Message{}) + }) + } + + // One channel subscriber + { + topic := "One channel subscriber" + msg := Message{TriggeredBy: "Tester"} + notifier := mq.Subscribe(topic, 0) + mq.Publish(topic, msg) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + msgRecv := <-notifier + asserts.Equal(msg, msgRecv) + }() + wg.Wait() + } + + // two channel subscriber + { + topic := "two channel subscriber" + msg := Message{TriggeredBy: "Tester"} + notifier := mq.Subscribe(topic, 0) + notifier2 := mq.Subscribe(topic, 0) + mq.Publish(topic, msg) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + wg.Done() + msgRecv := <-notifier + asserts.Equal(msg, msgRecv) + }() + go func() { + wg.Done() + msgRecv := <-notifier2 + asserts.Equal(msg, msgRecv) + }() + wg.Wait() + } + + // two channel subscriber, one timeout + { + topic := "two channel subscriber, one timeout" + msg := Message{TriggeredBy: "Tester"} + mq.Subscribe(topic, 0) + notifier2 := mq.Subscribe(topic, 0) + mq.Publish(topic, msg) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + msgRecv := <-notifier2 + asserts.Equal(msg, msgRecv) + }() + wg.Wait() + } + + // two channel subscriber, one unsubscribe + { + topic := "two channel subscriber, one unsubscribe" + msg := Message{TriggeredBy: "Tester"} + mq.Subscribe(topic, 0) + notifier2 := mq.Subscribe(topic, 0) + notifier := mq.Subscribe(topic, 0) + mq.Unsubscribe(topic, notifier) + mq.Publish(topic, msg) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + msgRecv := <-notifier2 + asserts.Equal(msg, msgRecv) + }() + wg.Wait() + + select { + case <-notifier: + t.Error() + default: + } + } +} + +func TestAria2Interface(t *testing.T) { + t.Parallel() + asserts := assert.New(t) + mq := NewMQ() + var ( + OnDownloadStart int + OnDownloadPause int + OnDownloadStop int + OnDownloadComplete int + OnDownloadError int + ) + l := sync.Mutex{} + + mq.SubscribeCallback("TestAria2Interface", func(message Message) { + asserts.Equal("TestAria2Interface", message.TriggeredBy) + l.Lock() + defer l.Unlock() + switch message.Event { + case "1": + OnDownloadStart++ + case "2": + OnDownloadPause++ + case "5": + OnDownloadStop++ + case "4": + OnDownloadComplete++ + case "3": + OnDownloadError++ + } + }) + + mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) + mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) + mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) + mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) + mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) + mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) + + time.Sleep(time.Duration(500) * time.Millisecond) + + asserts.Equal(2, OnDownloadStart) + asserts.Equal(2, OnDownloadPause) + asserts.Equal(2, OnDownloadStop) + asserts.Equal(4, OnDownloadComplete) + asserts.Equal(2, OnDownloadError) +} diff --git a/pkg/request/options.go b/pkg/request/options.go new file mode 100644 index 0000000..d495757 --- /dev/null +++ b/pkg/request/options.go @@ -0,0 +1,110 @@ +package request + +import ( + "context" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "net/http" + "net/url" + "time" +) + +// Option 发送请求的额外设置 +type Option interface { + apply(*options) +} + +type options struct { + timeout time.Duration + header http.Header + sign auth.Auth + signTTL int64 + ctx context.Context + contentLength int64 + masterMeta bool + endpoint *url.URL + slaveNodeID string +} + +type optionFunc func(*options) + +func (f optionFunc) apply(o *options) { + f(o) +} + +func newDefaultOption() *options { + return &options{ + header: http.Header{}, + timeout: time.Duration(30) * time.Second, + contentLength: -1, + } +} + +// WithTimeout 设置请求超时 +func WithTimeout(t time.Duration) Option { + return optionFunc(func(o *options) { + o.timeout = t + }) +} + +// WithContext 设置请求上下文 +func WithContext(c context.Context) Option { + return optionFunc(func(o *options) { + o.ctx = c + }) +} + +// WithCredential 对请求进行签名 +func WithCredential(instance auth.Auth, ttl int64) Option { + return optionFunc(func(o *options) { + o.sign = instance + o.signTTL = ttl + }) +} + +// WithHeader 设置请求Header +func WithHeader(header http.Header) Option { + return optionFunc(func(o *options) { + for k, v := range header { + o.header[k] = v + } + }) +} + +// WithoutHeader 设置清除请求Header +func WithoutHeader(header []string) Option { + return optionFunc(func(o *options) { + for _, v := range header { + delete(o.header, v) + } + + }) +} + +// WithContentLength 设置请求大小 +func WithContentLength(s int64) Option { + return optionFunc(func(o *options) { + o.contentLength = s + }) +} + +// WithMasterMeta 请求时携带主机信息 +func WithMasterMeta() Option { + return optionFunc(func(o *options) { + o.masterMeta = true + }) +} + +// WithSlaveMeta 请求时携带从机信息 +func WithSlaveMeta(s string) Option { + return optionFunc(func(o *options) { + o.slaveNodeID = s + }) +} + +// Endpoint 使用同一的请求Endpoint +func WithEndpoint(endpoint string) Option { + endpointURL, _ := url.Parse(endpoint) + return optionFunc(func(o *options) { + o.endpoint = endpointURL + }) +} diff --git a/pkg/request/request.go b/pkg/request/request.go index 36195c8..c543c2f 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -1,23 +1,25 @@ package request import ( - "context" "encoding/json" "errors" "fmt" "io" "io/ioutil" "net/http" - "time" + "path" + "strings" + "sync" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) // GeneralClient 通用 HTTP Client -var GeneralClient Client = HTTPClient{} +var GeneralClient Client = NewClient() // Response 请求的响应或错误信息 type Response struct { @@ -32,90 +34,30 @@ type Client interface { // HTTPClient 实现 Client 接口 type HTTPClient struct { + mu sync.Mutex + options *options } -// Option 发送请求的额外设置 -type Option interface { - apply(*options) -} - -type options struct { - timeout time.Duration - header http.Header - sign auth.Auth - signTTL int64 - ctx context.Context - contentLength int64 -} - -type optionFunc func(*options) - -func (f optionFunc) apply(o *options) { - f(o) -} - -func newDefaultOption() *options { - return &options{ - header: http.Header{}, - timeout: time.Duration(30) * time.Second, - contentLength: -1, +func NewClient(opts ...Option) Client { + client := &HTTPClient{ + options: newDefaultOption(), } -} - -// WithTimeout 设置请求超时 -func WithTimeout(t time.Duration) Option { - return optionFunc(func(o *options) { - o.timeout = t - }) -} - -// WithContext 设置请求上下文 -func WithContext(c context.Context) Option { - return optionFunc(func(o *options) { - o.ctx = c - }) -} - -// WithCredential 对请求进行签名 -func WithCredential(instance auth.Auth, ttl int64) Option { - return optionFunc(func(o *options) { - o.sign = instance - o.signTTL = ttl - }) -} - -// WithHeader 设置请求Header -func WithHeader(header http.Header) Option { - return optionFunc(func(o *options) { - for k, v := range header { - o.header[k] = v - } - }) -} - -// WithoutHeader 设置清除请求Header -func WithoutHeader(header []string) Option { - return optionFunc(func(o *options) { - for _, v := range header { - delete(o.header, v) - } - }) -} + for _, o := range opts { + o.apply(client.options) + } -// WithContentLength 设置请求大小 -func WithContentLength(s int64) Option { - return optionFunc(func(o *options) { - o.contentLength = s - }) + return client } // Request 发送HTTP请求 func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 - options := newDefaultOption() + c.mu.Lock() + options := *c.options + c.mu.Unlock() for _, o := range opts { - o.apply(options) + o.apply(&options) } // 创建请求客户端 @@ -126,6 +68,13 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio body = nil } + // 确定请求URL + if options.endpoint != nil { + targetURL := *options.endpoint + targetURL.Path = path.Join(targetURL.Path, target) + target = targetURL.String() + } + // 创建请求 var ( req *http.Request @@ -141,14 +90,36 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio } // 添加请求相关设置 - req.Header = options.header + if options.header != nil { + for k, v := range options.header { + req.Header.Add(k, strings.Join(v, " ")) + } + } + + if options.masterMeta && conf.SystemConfig.Mode == "master" { + req.Header.Add("X-Site-Url", model.GetSiteURL().String()) + req.Header.Add("X-Site-Id", model.GetSettingByName("siteID")) + req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) + } + + if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" { + req.Header.Add("X-Node-Id", options.slaveNodeID) + } + if options.contentLength != -1 { req.ContentLength = options.contentLength } // 签名请求 if options.sign != nil { - auth.SignRequest(options.sign, req, options.signTTL) + switch method { + case "PUT", "POST", "PATCH": + auth.SignRequest(options.sign, req, options.signTTL) + default: + if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil { + req.URL = resURL + } + } } // 发送请求 diff --git a/pkg/request/slave.go b/pkg/request/slave.go index 0bd1ca3..2948250 100644 --- a/pkg/request/slave.go +++ b/pkg/request/slave.go @@ -11,6 +11,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) +// TODO: move to slave pkg // RemoteCallback 发送远程存储策略上传回调请求 func RemoteCallback(url string, body serializer.UploadCallback) error { callbackBody, err := json.Marshal(struct { diff --git a/pkg/serializer/auth.go b/pkg/serializer/auth.go index 7d11dff..c8b348e 100644 --- a/pkg/serializer/auth.go +++ b/pkg/serializer/auth.go @@ -5,16 +5,15 @@ import "encoding/json" // RequestRawSign 待签名的HTTP请求 type RequestRawSign struct { Path string - Policy string + Header string Body string } // NewRequestSignString 返回JSON格式的待签名字符串 -// TODO 测试 -func NewRequestSignString(path, policy, body string) string { +func NewRequestSignString(path, header, body string) string { req := RequestRawSign{ Path: path, - Policy: policy, + Header: header, Body: body, } res, _ := json.Marshal(req) diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index 0191cee..d5e971c 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -1,14 +1,9 @@ package serializer -import "github.com/gin-gonic/gin" - -// Response 基础序列化器 -type Response struct { - Code int `json:"code"` - Data interface{} `json:"data,omitempty"` - Msg string `json:"msg"` - Error string `json:"error,omitempty"` -} +import ( + "errors" + "github.com/gin-gonic/gin" +) // AppError 应用错误,实现了error接口 type AppError struct { @@ -17,7 +12,7 @@ type AppError struct { RawError error } -// NewError 返回新的错误对象 todo:测试 还有下面的 +// NewError 返回新的错误对象 func NewError(code int, msg string, err error) AppError { return AppError{ Code: code, @@ -26,6 +21,15 @@ func NewError(code int, msg string, err error) AppError { } } +// NewErrorFromResponse 从 serializer.Response 构建错误 +func NewErrorFromResponse(resp *Response) AppError { + return AppError{ + Code: resp.Code, + Msg: resp.Msg, + RawError: errors.New(resp.Error), + } +} + // WithError 将应用error携带标准库中的error func (err *AppError) WithError(raw error) AppError { err.RawError = raw @@ -66,6 +70,8 @@ const ( CodeGroupNotAllowed = 40007 // CodeAdminRequired 非管理用户组 CodeAdminRequired = 40008 + // CodeMasterNotFound 主机节点未注册 + CodeMasterNotFound = 40009 // CodeDBError 数据库操作失败 CodeDBError = 50001 // CodeEncryptError 加密失败 diff --git a/pkg/serializer/response.go b/pkg/serializer/response.go new file mode 100644 index 0000000..91aae47 --- /dev/null +++ b/pkg/serializer/response.go @@ -0,0 +1,35 @@ +package serializer + +import ( + "bytes" + "encoding/base64" + "encoding/gob" +) + +// Response 基础序列化器 +type Response struct { + Code int `json:"code"` + Data interface{} `json:"data,omitempty"` + Msg string `json:"msg"` + Error string `json:"error,omitempty"` +} + +// NewResponseWithGobData 返回Data字段使用gob编码的Response +func NewResponseWithGobData(data interface{}) Response { + var w bytes.Buffer + encoder := gob.NewEncoder(&w) + if err := encoder.Encode(data); err != nil { + return Err(CodeInternalSetting, "无法编码返回结果", err) + } + + return Response{Data: w.Bytes()} +} + +// GobDecode 将 Response 正文解码至目标指针 +func (r *Response) GobDecode(target interface{}) { + src := r.Data.(string) + raw := make([]byte, len(src)*len(src)/base64.StdEncoding.DecodedLen(len(src))) + base64.StdEncoding.Decode(raw, []byte(src)) + decoder := gob.NewDecoder(bytes.NewBuffer(raw)) + decoder.Decode(target) +} diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index e23e809..245767a 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -1,5 +1,12 @@ package serializer +import ( + "crypto/sha1" + "encoding/gob" + "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" +) + // RemoteDeleteRequest 远程策略删除接口请求正文 type RemoteDeleteRequest struct { Files []string `json:"files"` @@ -10,3 +17,51 @@ type ListRequest struct { Path string `json:"path"` Recursive bool `json:"recursive"` } + +// NodePingReq 从机节点Ping请求 +type NodePingReq struct { + SiteURL string `json:"site_url"` + SiteID string `json:"site_id"` + IsUpdate bool `json:"is_update"` + CredentialTTL int `json:"credential_ttl"` + Node *model.Node `json:"node"` +} + +// NodePingResp 从机节点Ping响应 +type NodePingResp struct { +} + +// SlaveAria2Call 从机有关Aria2的请求正文 +type SlaveAria2Call struct { + Task *model.Download `json:"task"` + GroupOptions map[string]interface{} `json:"group_options"` + Files []int `json:"files"` +} + +// SlaveTransferReq 从机中转任务创建请求 +type SlaveTransferReq struct { + Src string `json:"src"` + Dst string `json:"dst"` + Policy *model.Policy `json:"policy"` +} + +// Hash 返回创建请求的唯一标识,保持创建请求幂等 +func (s *SlaveTransferReq) Hash(id string) string { + h := sha1.New() + h.Write([]byte(fmt.Sprintf("transfer-%s-%s-%s-%d", id, s.Src, s.Dst, s.Policy.ID))) + bs := h.Sum(nil) + return fmt.Sprintf("%x", bs) +} + +const ( + SlaveTransferSuccess = "success" + SlaveTransferFailed = "failed" +) + +type SlaveTransferResult struct { + Error string +} + +func init() { + gob.Register(SlaveTransferResult{}) +} diff --git a/pkg/slave/errors.go b/pkg/slave/errors.go new file mode 100644 index 0000000..2af6e13 --- /dev/null +++ b/pkg/slave/errors.go @@ -0,0 +1,7 @@ +package slave + +import "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + +var ( + ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil) +) diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go new file mode 100644 index 0000000..aa457b5 --- /dev/null +++ b/pkg/slave/slave.go @@ -0,0 +1,209 @@ +package slave + +import ( + "bytes" + "encoding/gob" + "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/jinzhu/gorm" + "net/url" + "sync" +) + +var DefaultController Controller + +// Controller controls communications between master and slave +type Controller interface { + // Handle heartbeat sent from master + HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error) + + // Get Aria2 Instance by master node ID + GetAria2Instance(string) (common.Aria2, error) + + // Send event change message to master node + SendNotification(string, string, mq.Message) error + + // Submit async task into task pool + SubmitTask(string, interface{}, string, func(interface{})) error + + // Get master node info + GetMasterInfo(string) (*MasterInfo, error) + + // Get master OneDrive policy credential + GetOneDriveToken(string, uint) (string, error) +} + +type slaveController struct { + masters map[string]MasterInfo + lock sync.RWMutex +} + +// info of master node +type MasterInfo struct { + ID string + TTL int + URL *url.URL + // used to invoke aria2 rpc calls + Instance cluster.Node + Client request.Client + + jobTracker map[string]bool +} + +func Init() { + DefaultController = &slaveController{ + masters: make(map[string]MasterInfo), + } + gob.Register(rpc.StatusInfo{}) +} + +func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) { + c.lock.Lock() + defer c.lock.Unlock() + + req.Node.AfterFind() + + // close old node if exist + origin, ok := c.masters[req.SiteID] + + if (ok && req.IsUpdate) || !ok { + if ok { + origin.Instance.Kill() + } + + masterUrl, err := url.Parse(req.SiteURL) + if err != nil { + return serializer.NodePingResp{}, err + } + + c.masters[req.SiteID] = MasterInfo{ + ID: req.SiteID, + URL: masterUrl, + TTL: req.CredentialTTL, + Client: request.NewClient( + request.WithEndpoint(masterUrl.String()), + request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)), + request.WithCredential(auth.HMACAuth{ + SecretKey: []byte(req.Node.MasterKey), + }, int64(req.CredentialTTL)), + ), + jobTracker: make(map[string]bool), + Instance: cluster.NewNodeFromDBModel(&model.Node{ + Model: gorm.Model{ID: req.Node.ID}, + MasterKey: req.Node.MasterKey, + Type: model.MasterNodeType, + Aria2Enabled: req.Node.Aria2Enabled, + Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized, + }), + } + } + + return serializer.NodePingResp{}, nil +} + +func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if node, ok := c.masters[id]; ok { + return node.Instance.GetAria2Instance(), nil + } + + return nil, ErrMasterNotFound +} + +func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error { + c.lock.RLock() + + if node, ok := c.masters[id]; ok { + c.lock.RUnlock() + + body := bytes.Buffer{} + enc := gob.NewEncoder(&body) + if err := enc.Encode(&msg); err != nil { + return err + } + + res, err := node.Client.Request( + "PUT", + fmt.Sprintf("/api/v3/slave/notification/%s", subject), + &body, + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + return nil + } + + c.lock.RUnlock() + return ErrMasterNotFound +} + +// SubmitTask 提交异步任务 +func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error { + c.lock.RLock() + defer c.lock.RUnlock() + + if node, ok := c.masters[id]; ok { + if _, ok := node.jobTracker[hash]; ok { + // 任务已存在,直接返回 + return nil + } + + submitter(job) + return nil + } + + return ErrMasterNotFound +} + +// GetMasterInfo 获取主机节点信息 +func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if node, ok := c.masters[id]; ok { + return &node, nil + } + + return nil, ErrMasterNotFound +} + +// GetOneDriveToken 获取主机OneDrive凭证 +func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, error) { + c.lock.RLock() + + if node, ok := c.masters[id]; ok { + c.lock.RUnlock() + + res, err := node.Client.Request( + "GET", + fmt.Sprintf("/api/v3/slave/credential/onedrive/%d", policyID), + nil, + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return "", err + } + + if res.Code != 0 { + return "", serializer.NewErrorFromResponse(res) + } + + return res.Data.(string), nil + } + + c.lock.RUnlock() + return "", ErrMasterNotFound +} diff --git a/pkg/task/compress.go b/pkg/task/compress.go index b134922..95b06d5 100644 --- a/pkg/task/compress.go +++ b/pkg/task/compress.go @@ -106,7 +106,7 @@ func (job *CompressTask) Do() { job.TaskModel.SetProgress(TransferringProgress) // 上传文件 - err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst) + err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, true) if err != nil { job.SetErrorMsg(err.Error()) return diff --git a/pkg/task/job.go b/pkg/task/job.go index 22adc79..064b078 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -96,7 +96,9 @@ func Resume() { continue } - TaskPoll.Submit(job) + if job != nil { + TaskPoll.Submit(job) + } } } diff --git a/pkg/task/pool.go b/pkg/task/pool.go index d44877d..4fe550f 100644 --- a/pkg/task/pool.go +++ b/pkg/task/pool.go @@ -2,6 +2,7 @@ package task import ( model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -56,5 +57,7 @@ func Init() { TaskPoll.Add(maxWorker) util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker) - Resume() + if conf.SystemConfig.Mode == "master" { + Resume() + } } diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go new file mode 100644 index 0000000..c312742 --- /dev/null +++ b/pkg/task/slavetask/transfer.go @@ -0,0 +1,145 @@ +package slavetask + +import ( + "context" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/cloudreve/Cloudreve/v3/pkg/task" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "os" + "path/filepath" +) + +// TransferTask 文件中转任务 +type TransferTask struct { + Err *task.JobError + Req *serializer.SlaveTransferReq + MasterID string +} + +// Props 获取任务属性 +func (job *TransferTask) Props() string { + return "" +} + +// Type 获取任务类型 +func (job *TransferTask) Type() int { + return 0 +} + +// Creator 获取创建者ID +func (job *TransferTask) Creator() uint { + return 0 +} + +// Model 获取任务的数据库模型 +func (job *TransferTask) Model() *model.Task { + return nil +} + +// SetStatus 设定状态 +func (job *TransferTask) SetStatus(status int) { +} + +// SetError 设定任务失败信息 +func (job *TransferTask) SetError(err *task.JobError) { + job.Err = err + +} + +// SetErrorMsg 设定任务失败信息 +func (job *TransferTask) SetErrorMsg(msg string, err error) { + jobErr := &task.JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + + job.SetError(jobErr) + + notifyMsg := mq.Message{ + TriggeredBy: job.MasterID, + Event: serializer.SlaveTransferFailed, + Content: serializer.SlaveTransferResult{ + Error: err.Error(), + }, + } + + if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { + util.Log().Warning("无法发送转存失败通知到从机, ", err) + } +} + +// GetError 返回任务失败信息 +func (job *TransferTask) GetError() *task.JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *TransferTask) Do() { + defer job.Recycle() + + fs, err := filesystem.NewAnonymousFileSystem() + if err != nil { + job.SetErrorMsg("无法初始化匿名文件系统", err) + return + } + + fs.Policy = job.Req.Policy + if err := fs.DispatchHandler(); err != nil { + job.SetErrorMsg("无法分发存储策略", err) + return + } + + master, err := slave.DefaultController.GetMasterInfo(job.MasterID) + if err != nil { + job.SetErrorMsg("找不到主机节点", err) + return + } + + fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID) + ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) + file, err := os.Open(util.RelativePath(job.Req.Src)) + if err != nil { + job.SetErrorMsg("无法读取源文件", err) + return + } + + defer file.Close() + + // 获取源文件大小 + fi, err := file.Stat() + if err != nil { + job.SetErrorMsg("无法获取源文件大小", err) + return + } + + size := fi.Size() + + err = fs.Handler.Put(ctx, file, job.Req.Dst, uint64(size)) + if err != nil { + job.SetErrorMsg("文件上传失败", err) + return + } + + msg := mq.Message{ + TriggeredBy: job.MasterID, + Event: serializer.SlaveTransferSuccess, + Content: serializer.SlaveTransferResult{}, + } + + if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { + util.Log().Warning("无法发送转存成功通知到从机, ", err) + } +} + +// Recycle 回收临时文件 +func (job *TransferTask) Recycle() { + err := os.RemoveAll(filepath.Dir(job.Req.Src)) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.Req.Src, err) + } +} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 8cdc247..5db638d 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -9,6 +9,7 @@ import ( "strings" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -26,11 +27,14 @@ type TransferTask struct { // TransferProps 中转任务属性 type TransferProps struct { - Src []string `json:"src"` // 原始文件 - Parent string `json:"parent"` // 父目录 - Dst string `json:"dst"` // 目的目录ID + Src []string `json:"src"` // 原始文件 + SrcSizes map[string]uint64 `json:"src_size"` // 原始文件的大小信息,从机转存时使用 + Parent string `json:"parent"` // 父目录 + Dst string `json:"dst"` // 目的目录ID // 将会保留原始文件的目录结构,Src 除去 Parent 开头作为最终路径 TrimPath bool `json:"trim_path"` + // 负责处理中专任务的节点ID + NodeID uint `json:"node_id"` } // Props 获取任务属性 @@ -104,7 +108,24 @@ func (job *TransferTask) Do() { } ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) - err = fs.UploadFromPath(ctx, file, dst) + ctx = context.WithValue(ctx, fsctx.SlaveSrcPath, file) + if job.TaskProps.NodeID > 1 { + // 指定为从机中转 + + // 获取从机节点 + node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) + if node == nil { + job.SetErrorMsg("从机节点不可用", nil) + } + + // 切换为从机节点处理上传 + fs.SwitchToSlaveHandler(node) + err = fs.UploadFromStream(ctx, nil, dst, job.TaskProps.SrcSizes[file]) + } else { + // 主机节点中转 + err = fs.UploadFromPath(ctx, file, dst, true) + } + if err != nil { job.SetErrorMsg("文件转存失败", err) } @@ -114,15 +135,16 @@ func (job *TransferTask) Do() { // Recycle 回收临时文件 func (job *TransferTask) Recycle() { - err := os.RemoveAll(job.TaskProps.Parent) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) + if job.TaskProps.NodeID == 1 { + err := os.RemoveAll(job.TaskProps.Parent) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) + } } - } // NewTransferTask 新建中转任务 -func NewTransferTask(user uint, src []string, dst, parent string, trim bool) (Job, error) { +func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) { creator, err := model.GetActiveUserByID(user) if err != nil { return nil, err @@ -135,6 +157,8 @@ func NewTransferTask(user uint, src []string, dst, parent string, trim bool) (Jo Parent: parent, Dst: dst, TrimPath: trim, + NodeID: node, + SrcSizes: sizes, }, } diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index c0dc4c9..a3ebfa5 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -3,6 +3,7 @@ package controllers import ( "io" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/request" @@ -24,7 +25,7 @@ func AdminSummary(c *gin.Context) { // AdminNews 获取社区新闻 func AdminNews(c *gin.Context) { - r := request.HTTPClient{} + r := request.NewClient() res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil) if res.Err == nil { io.Copy(c.Writer, res.Response.Body) @@ -92,7 +93,13 @@ func AdminSendTestMail(c *gin.Context) { func AdminTestAria2(c *gin.Context) { var service admin.Aria2TestService if err := c.ShouldBindJSON(&service); err == nil { - res := service.Test() + var res serializer.Response + if service.Type == model.MasterNodeType { + res = service.TestMaster() + } else { + res = service.TestSlave() + } + c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) @@ -425,3 +432,58 @@ func AdminListFolders(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// AdminListNodes 列出从机节点 +func AdminListNodes(c *gin.Context) { + var service admin.AdminListService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Nodes() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminAddNode 新建节点 +func AdminAddNode(c *gin.Context) { + var service admin.AddNodeService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Add() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminToggleNode 启用/暂停节点 +func AdminToggleNode(c *gin.Context) { + var service admin.ToggleNodeService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Toggle() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminDeleteGroup 删除用户组 +func AdminDeleteNode(c *gin.Context) { + var service admin.NodeService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Delete() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminGetNode 获取节点详情 +func AdminGetNode(c *gin.Context) { + var service admin.NodeService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Get() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/controllers/aria2.go b/routers/controllers/aria2.go index b2bc6d6..25a8fb0 100644 --- a/routers/controllers/aria2.go +++ b/routers/controllers/aria2.go @@ -3,7 +3,7 @@ package controllers import ( "context" - ariaCall "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/service/aria2" "github.com/cloudreve/Cloudreve/v3/service/explorer" "github.com/gin-gonic/gin" @@ -13,7 +13,7 @@ import ( func AddAria2URL(c *gin.Context) { var addService aria2.AddURLService if err := c.ShouldBindJSON(&addService); err == nil { - res := addService.Add(c, ariaCall.URLTask) + res := addService.Add(c, common.URLTask) c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) @@ -52,7 +52,7 @@ func AddAria2Torrent(c *gin.Context) { if err := c.ShouldBindJSON(&addService); err == nil { addService.URL = res.Data.(string) - res := addService.Add(c, ariaCall.URLTask) + res := addService.Add(c, common.URLTask) c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index e10e2b0..10c46ff 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -10,7 +10,9 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/service/admin" + "github.com/cloudreve/Cloudreve/v3/service/aria2" "github.com/cloudreve/Cloudreve/v3/service/explorer" + "github.com/cloudreve/Cloudreve/v3/service/node" "github.com/gin-gonic/gin" ) @@ -175,3 +177,102 @@ func SlaveList(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveHeartbeat 接受主机心跳包 +func SlaveHeartbeat(c *gin.Context) { + var service serializer.NodePingReq + if err := c.ShouldBindJSON(&service); err == nil { + res := node.HandleMasterHeartbeat(&service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveAria2Create 创建 Aria2 任务 +func SlaveAria2Create(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.Add(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveAria2Status 查询从机 Aria2 任务状态 +func SlaveAria2Status(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.SlaveStatus(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveCancelAria2Task 取消从机离线下载任务 +func SlaveCancelAria2Task(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.SlaveCancel(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveSelectTask 从机选取离线下载文件 +func SlaveSelectTask(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.SlaveSelect(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveCreateTransferTask 从机创建中转任务 +func SlaveCreateTransferTask(c *gin.Context) { + var service serializer.SlaveTransferReq + if err := c.ShouldBindJSON(&service); err == nil { + res := explorer.CreateTransferTask(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveNotificationPush 处理从机发送的消息推送 +func SlaveNotificationPush(c *gin.Context) { + var service node.SlaveNotificationService + if err := c.ShouldBindUri(&service); err == nil { + res := service.HandleSlaveNotificationPush(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveGetOneDriveCredential 从机获取主机的OneDrive存储策略凭证 +func SlaveGetOneDriveCredential(c *gin.Context) { + var service node.OneDriveCredentialService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Get(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// SlaveSelectTask 从机删除离线下载临时文件 +func SlaveDeleteTempFile(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.SlaveDeleteTemp(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 3ae63e9..a7204c4 100644 --- a/routers/router.go +++ b/routers/router.go @@ -2,6 +2,7 @@ package routers import ( "github.com/cloudreve/Cloudreve/v3/middleware" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -29,7 +30,9 @@ func InitSlaveRouter() *gin.Engine { InitCORS(r) v3 := r.Group("/api/v3/slave") // 鉴权中间件 - v3.Use(middleware.SignRequired()) + v3.Use(middleware.SignRequired(auth.General)) + // 主机信息解析 + v3.Use(middleware.MasterMetadata()) /* 路由 @@ -37,6 +40,10 @@ func InitSlaveRouter() *gin.Engine { { // Ping v3.POST("ping", controllers.SlavePing) + // 测试 Aria2 RPC 连接 + v3.POST("ping/aria2", controllers.AdminTestAria2) + // 接收主机心跳包 + v3.POST("heartbeat", controllers.SlaveHeartbeat) // 上传 v3.POST("upload", controllers.SlaveUpload) // 下载 @@ -49,6 +56,28 @@ func InitSlaveRouter() *gin.Engine { v3.POST("delete", controllers.SlaveDelete) // 列出文件 v3.POST("list", controllers.SlaveList) + + // 离线下载 + aria2 := v3.Group("aria2") + aria2.Use(middleware.UseSlaveAria2Instance()) + { + // 创建离线下载任务 + aria2.POST("task", controllers.SlaveAria2Create) + // 获取任务状态 + aria2.POST("status", controllers.SlaveAria2Status) + // 取消离线下载任务 + aria2.POST("cancel", controllers.SlaveCancelAria2Task) + // 选取任务文件 + aria2.POST("select", controllers.SlaveSelectTask) + // 删除任务临时文件 + aria2.POST("delete", controllers.SlaveDeleteTempFile) + } + + // 异步任务 + task := v3.Group("task") + { + task.PUT("transfer", controllers.SlaveCreateTransferTask) + } } return r } @@ -131,7 +160,7 @@ func InitMasterRouter() *gin.Engine { user.PATCH("reset", controllers.UserReset) // 邮件激活 user.GET("activate/:id", - middleware.SignRequired(), + middleware.SignRequired(auth.General), middleware.HashID(hashid.UserID), controllers.UserActivate, ) @@ -159,7 +188,7 @@ func InitMasterRouter() *gin.Engine { // 需要携带签名验证的 sign := v3.Group("") - sign.Use(middleware.SignRequired()) + sign.Use(middleware.SignRequired(auth.General)) { file := sign.Group("file") { @@ -174,6 +203,18 @@ func InitMasterRouter() *gin.Engine { } } + // 从机的 RPC 通信 + slave := v3.Group("slave") + slave.Use(middleware.SlaveRPCSignRequired()) + { + // 事件通知 + slave.PUT("notification/:subject", controllers.SlaveNotificationPush) + // 上传 + slave.POST("upload", controllers.SlaveUpload) + // OneDrive 存储策略凭证 + slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential) + } + // 回调接口 callback := v3.Group("callback") { @@ -405,6 +446,22 @@ func InitMasterRouter() *gin.Engine { task.POST("import", controllers.AdminCreateImportTask) } + node := admin.Group("node") + { + // 列出从机节点 + node.POST("list", controllers.AdminListNodes) + // 列出从机节点 + node.POST("aria2/test", controllers.AdminTestAria2) + // 创建/保存节点 + node.POST("", controllers.AdminAddNode) + // 启用/暂停节点 + node.PATCH("enable/:id/:desired", controllers.AdminToggleNode) + // 删除节点 + node.DELETE(":id", controllers.AdminDeleteNode) + // 获取节点 + node.GET(":id", controllers.AdminGetNode) + } + } // 用户 diff --git a/service/admin/aria2.go b/service/admin/aria2.go index 8801c96..0df3275 100644 --- a/service/admin/aria2.go +++ b/service/admin/aria2.go @@ -1,43 +1,71 @@ package admin import ( + "bytes" + "encoding/json" + model "github.com/cloudreve/Cloudreve/v3/models" "net/url" + "time" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) // Aria2TestService aria2连接测试服务 type Aria2TestService struct { - Server string `json:"server" binding:"required"` - Token string `json:"token"` + Server string `json:"server" binding:"required"` + RPC string `json:"rpc" binding:"required"` + Secret string `json:"secret" binding:"required"` + Token string `json:"token"` + Type model.ModelType `json:"type"` } // Test 测试aria2连接 -func (service *Aria2TestService) Test() serializer.Response { - testRPC := aria2.RPCService{} - - // 解析RPC服务地址 - server, err := url.Parse(service.Server) +func (service *Aria2TestService) TestMaster() serializer.Response { + res, err := aria2.TestRPCConnection(service.RPC, service.Token, 5) if err != nil { - return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) + return serializer.ParamErr(err.Error(), err) } - server.Path = "/jsonrpc" - if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { - return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) + if res.Version == "" { + return serializer.ParamErr("RPC 服务返回非预期响应", nil) } - defer testRPC.Caller.Close() + return serializer.Response{Data: res.Version} +} - info, err := testRPC.Caller.GetVersion() +func (service *Aria2TestService) TestSlave() serializer.Response { + slave, err := url.Parse(service.Server) if err != nil { - return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) + return serializer.ParamErr("无法解析从机端地址,"+err.Error(), nil) } - if info.Version == "" { - return serializer.ParamErr("RPC 服务返回非预期响应", nil) + controller, _ := url.Parse("/api/v3/slave/ping/aria2") + + // 请求正文 + service.Type = model.MasterNodeType + bodyByte, _ := json.Marshal(service) + + r := request.NewClient() + res, err := r.Request( + "POST", + slave.ResolveReference(controller).String(), + bytes.NewReader(bodyByte), + request.WithTimeout(time.Duration(10)*time.Second), + request.WithCredential( + auth.HMACAuth{SecretKey: []byte(service.Secret)}, + int64(model.GetIntSetting("slave_api_timeout", 60)), + ), + ).DecodeResponse() + if err != nil { + return serializer.ParamErr("无连接到从机,"+err.Error(), nil) + } + + if res.Code != 0 { + return serializer.ParamErr("成功接到从机,但是从机返回:"+res.Msg, nil) } - return serializer.Response{Data: info.Version} + return serializer.Response{Data: res.Data.(string)} } diff --git a/service/admin/node.go b/service/admin/node.go new file mode 100644 index 0000000..7d52dbd --- /dev/null +++ b/service/admin/node.go @@ -0,0 +1,138 @@ +package admin + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "strings" +) + +// AddNodeService 节点添加服务 +type AddNodeService struct { + Node model.Node `json:"node" binding:"required"` +} + +// Add 添加节点 +func (service *AddNodeService) Add() serializer.Response { + if service.Node.ID > 0 { + if err := model.DB.Save(&service.Node).Error; err != nil { + return serializer.ParamErr("节点保存失败", err) + } + } else { + if err := model.DB.Create(&service.Node).Error; err != nil { + return serializer.ParamErr("节点添加失败", err) + } + } + + return serializer.Response{Data: service.Node.ID} +} + +// Nodes 列出从机节点 +func (service *AdminListService) Nodes() serializer.Response { + var res []model.Node + total := 0 + + tx := model.DB.Model(&model.Node{}) + if service.OrderBy != "" { + tx = tx.Order(service.OrderBy) + } + + for k, v := range service.Conditions { + tx = tx.Where(k+" = ?", v) + } + + if len(service.Searches) > 0 { + search := "" + for k, v := range service.Searches { + search += k + " like '%" + v + "%' OR " + } + search = strings.TrimSuffix(search, " OR ") + tx = tx.Where(search) + } + + // 计算总数用于分页 + tx.Count(&total) + + // 查询记录 + tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + + isActive := make(map[uint]bool) + for i := 0; i < len(res); i++ { + if node := cluster.Default.GetNodeByID(res[i].ID); node != nil { + isActive[res[i].ID] = node.IsActive() + } + } + + return serializer.Response{Data: map[string]interface{}{ + "total": total, + "items": res, + "active": isActive, + }} +} + +// ToggleNodeService 开关节点服务 +type ToggleNodeService struct { + ID uint `uri:"id"` + Desired model.NodeStatus `uri:"desired"` +} + +// Toggle 开关节点 +func (service *ToggleNodeService) Toggle() serializer.Response { + node, err := model.GetNodeByID(service.ID) + if err != nil { + return serializer.DBErr("找不到节点", err) + } + + // 是否为系统节点 + if node.ID <= 1 { + return serializer.Err(serializer.CodeNoPermissionErr, "系统节点无法更改", err) + } + + if err = node.SetStatus(service.Desired); err != nil { + return serializer.DBErr("无法更改节点状态", err) + } + + if service.Desired == model.NodeActive { + cluster.Default.Add(&node) + } else { + cluster.Default.Delete(node.ID) + } + + return serializer.Response{} +} + +// NodeService 节点ID服务 +type NodeService struct { + ID uint `uri:"id" json:"id" binding:"required"` +} + +// Delete 删除节点 +func (service *NodeService) Delete() serializer.Response { + // 查找用户组 + node, err := model.GetNodeByID(service.ID) + if err != nil { + return serializer.Err(serializer.CodeNotFound, "节点不存在", err) + } + + // 是否为系统节点 + if node.ID <= 1 { + return serializer.Err(serializer.CodeNoPermissionErr, "系统节点无法删除", err) + } + + cluster.Default.Delete(node.ID) + if err := model.DB.Delete(&node).Error; err != nil { + return serializer.DBErr("无法删除节点", err) + } + + return serializer.Response{} +} + +// Get 获取节点详情 +func (service *NodeService) Get() serializer.Response { + node, err := model.GetNodeByID(service.ID) + if err != nil { + return serializer.Err(serializer.CodeNotFound, "节点不存在", err) + } + + return serializer.Response{Data: node} +} diff --git a/service/admin/policy.go b/service/admin/policy.go index 7cd55fd..5c648ec 100644 --- a/service/admin/policy.go +++ b/service/admin/policy.go @@ -151,7 +151,7 @@ func (service *PolicyService) AddCORS() serializer.Response { case "oss": handler := oss.Driver{ Policy: &policy, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } if err := handler.CORS(); err != nil { return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) @@ -161,7 +161,7 @@ func (service *PolicyService) AddCORS() serializer.Response { b := &cossdk.BaseURL{BucketURL: u} handler := cos.Driver{ Policy: &policy, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), Client: cossdk.NewClient(b, &http.Client{ Transport: &cossdk.AuthorizationTransport{ SecretID: policy.AccessKey, @@ -195,7 +195,7 @@ func (service *SlavePingService) Test() serializer.Response { controller, _ := url.Parse("/api/v3/site/ping") - r := request.HTTPClient{} + r := request.NewClient() res, err := r.Request( "GET", master.ResolveReference(controller).String(), @@ -229,7 +229,7 @@ func (service *SlaveTestService) Test() serializer.Response { } bodyByte, _ := json.Marshal(body) - r := request.HTTPClient{} + r := request.NewClient() res, err := r.Request( "POST", slave.ResolveReference(controller).String(), @@ -245,7 +245,7 @@ func (service *SlaveTestService) Test() serializer.Response { } if res.Code != 0 { - return serializer.ParamErr("成功接到从机,但是"+res.Msg, nil) + return serializer.ParamErr("成功接到从机,但是从机返回:"+res.Msg, nil) } return serializer.Response{} diff --git a/service/aria2/add.go b/service/aria2/add.go index be7213a..26b6baa 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -3,8 +3,14 @@ package aria2 import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" ) @@ -14,7 +20,7 @@ type AddURLService struct { Dst string `json:"dst" binding:"required,min=1"` } -// Add 创建新的链接离线下载任务 +// Add 主机创建新的链接离线下载任务 func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response { // 创建文件系统 fs, err := filesystem.NewFileSystemFromContext(c) @@ -35,19 +41,60 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo // 创建任务 task := &model.Download{ - Status: aria2.Ready, + Status: common.Ready, Type: taskType, Dst: service.Dst, UserID: fs.User.ID, Source: service.URL, } + // 获取 Aria2 负载均衡器 aria2.Lock.RLock() - if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil { - aria2.Lock.RUnlock() + lb := aria2.LB + aria2.Lock.RUnlock() + + // 获取 Aria2 实例 + err, node := cluster.Default.BalanceNodeByFeature("aria2", lb) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Aria2 实例获取失败", err) + } + + // 创建任务 + gid, err := node.GetAria2Instance().CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options) + if err != nil { return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) } - aria2.Lock.RUnlock() + + task.GID = gid + task.NodeID = node.ID() + _, err = task.Create() + if err != nil { + return serializer.DBErr("任务创建失败", err) + } + + // 创建任务监控 + monitor.NewMonitor(task) return serializer.Response{} } + +// Add 从机创建新的链接离线下载任务 +func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + caller, _ := c.Get("MasterAria2Instance") + + // 创建任务 + gid, err := caller.(common.Aria2).CreateTask(service.Task, service.GroupOptions) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err) + } + + // 创建事件通知回调 + siteID, _ := c.Get("MasterSiteID") + mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) { + if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { + util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err) + } + }) + + return serializer.Response{Data: gid} +} diff --git a/service/aria2/manage.go b/service/aria2/manage.go index d93bbca..f3ed47d 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -2,7 +2,8 @@ package aria2 import ( model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" ) @@ -25,14 +26,14 @@ type DownloadListService struct { // Finished 获取已完成的任务 func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Error, aria2.Complete, aria2.Canceled, aria2.Unknown) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown) return serializer.BuildFinishedListResponse(downloads) } // Downloading 获取正在下载中的任务 func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Downloading, aria2.Paused, aria2.Ready) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready) return serializer.BuildDownloadingResponse(downloads) } @@ -47,7 +48,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) } - if download.Status >= aria2.Error { + if download.Status >= common.Error { // 如果任务已完成,则删除任务记录 if err := download.Delete(); err != nil { return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err) @@ -56,9 +57,12 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { } // 取消任务 - aria2.Lock.RLock() - defer aria2.Lock.RUnlock() - if err := aria2.Instance.Cancel(download); err != nil { + node := cluster.Default.GetNodeByID(download.GetNodeID()) + if node == nil { + return serializer.Err(serializer.CodeInternalSetting, "目标节点不可用", err) + } + + if err := node.GetAria2Instance().Cancel(download); err != nil { return serializer.Err(serializer.CodeNotSet, "操作失败", err) } @@ -76,17 +80,72 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) } - if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != aria2.Downloading && download.Status != aria2.Paused) { + if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != common.Downloading && download.Status != common.Paused) { return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err) } // 选取下载 - aria2.Lock.RLock() - defer aria2.Lock.RUnlock() - if err := aria2.Instance.Select(download, service.Indexes); err != nil { + node := cluster.Default.GetNodeByID(download.GetNodeID()) + if err := node.GetAria2Instance().Select(download, service.Indexes); err != nil { return serializer.Err(serializer.CodeNotSet, "操作失败", err) } return serializer.Response{} } + +// SlaveStatus 从机查询离线任务状态 +func SlaveStatus(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + caller, _ := c.Get("MasterAria2Instance") + + // 查询任务 + status, err := caller.(common.Aria2).Status(service.Task) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err) + } + + return serializer.NewResponseWithGobData(status) + +} + +// SlaveCancel 取消从机离线下载任务 +func SlaveCancel(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + caller, _ := c.Get("MasterAria2Instance") + + // 查询任务 + err := caller.(common.Aria2).Cancel(service.Task) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "任务取消失败", err) + } + + return serializer.Response{} + +} + +// SlaveSelect 从机选取离线下载任务文件 +func SlaveSelect(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + caller, _ := c.Get("MasterAria2Instance") + + // 查询任务 + err := caller.(common.Aria2).Select(service.Task, service.Files) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "任务选取失败", err) + } + + return serializer.Response{} + +} + +// SlaveSelect 从机选取离线下载任务文件 +func SlaveDeleteTemp(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + caller, _ := c.Get("MasterAria2Instance") + + // 查询任务 + err := caller.(common.Aria2).DeleteTempFile(service.Task) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "临时文件删除失败", err) + } + + return serializer.Response{} + +} diff --git a/service/explorer/file.go b/service/explorer/file.go index d128c86..8b56871 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -2,7 +2,6 @@ package explorer import ( "context" - "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -20,7 +19,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" ) // SingleFileService 对单文件进行操作的五福,path为文件完整路径 @@ -43,29 +41,6 @@ type DownloadService struct { ID string `uri:"id" binding:"required"` } -// SlaveDownloadService 从机文件下載服务 -type SlaveDownloadService struct { - PathEncoded string `uri:"path" binding:"required"` - Name string `uri:"name" binding:"required"` - Speed int `uri:"speed" binding:"min=0"` -} - -// SlaveFileService 从机单文件文件相关服务 -type SlaveFileService struct { - PathEncoded string `uri:"path" binding:"required"` -} - -// SlaveFilesService 从机多文件相关服务 -type SlaveFilesService struct { - Files []string `json:"files" binding:"required,gt=0"` -} - -// SlaveListService 从机列表服务 -type SlaveListService struct { - Path string `json:"path" binding:"required,min=1,max=65535"` - Recursive bool `json:"recursive"` -} - // New 创建新文件 func (service *SingleFileService) Create(c *gin.Context) serializer.Response { // 创建文件系统 @@ -449,106 +424,3 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se Code: 0, } } - -// ServeFile 通过签名的URL下载从机文件 -func (service *SlaveDownloadService) ServeFile(ctx context.Context, c *gin.Context, isDownload bool) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) - } - defer fs.Recycle() - - // 解码文件路径 - fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded) - if err != nil { - return serializer.ParamErr("无法解析的文件地址", err) - } - - // 根据URL里的信息创建一个文件对象和用户对象 - file := model.File{ - Name: service.Name, - SourceName: string(fileSource), - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "local", - }, - } - fs.User = &model.User{ - Group: model.Group{SpeedLimit: service.Speed}, - } - fs.FileTarget = []model.File{file} - - // 开始处理下载 - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - rs, err := fs.GetDownloadContent(ctx, 0) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - defer rs.Close() - - // 设置下载文件名 - if isDownload { - c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"") - } - - // 发送文件 - http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, time.Now(), rs) - - return serializer.Response{ - Code: 0, - } -} - -// Delete 通过签名的URL删除从机文件 -func (service *SlaveFilesService) Delete(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) - } - defer fs.Recycle() - - // 删除文件 - failed, err := fs.Handler.Delete(ctx, service.Files) - if err != nil { - // 将Data字段写为字符串方便主控端解析 - data, _ := json.Marshal(serializer.RemoteDeleteRequest{Files: failed}) - - return serializer.Response{ - Code: serializer.CodeNotFullySuccess, - Data: string(data), - Msg: fmt.Sprintf("有 %d 个文件未能成功删除", len(failed)), - Error: err.Error(), - } - } - return serializer.Response{Code: 0} -} - -// Thumb 通过签名URL获取从机文件缩略图 -func (service *SlaveFileService) Thumb(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) - } - defer fs.Recycle() - - // 解码文件路径 - fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded) - if err != nil { - return serializer.ParamErr("无法解析的文件地址", err) - } - fs.FileTarget = []model.File{{SourceName: string(fileSource), PicInfo: "1,1"}} - - // 获取缩略图 - resp, err := fs.GetThumb(ctx, 0) - if err != nil { - return serializer.Err(serializer.CodeNotSet, "无法获取缩略图", err) - } - - defer resp.Content.Close() - http.ServeContent(c.Writer, c.Request, "thumb.png", time.Now(), resp.Content) - - return serializer.Response{Code: 0} -} diff --git a/service/explorer/slave.go b/service/explorer/slave.go new file mode 100644 index 0000000..8beb15b --- /dev/null +++ b/service/explorer/slave.go @@ -0,0 +1,166 @@ +package explorer + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/cloudreve/Cloudreve/v3/pkg/task" + "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "net/http" + "net/url" + "time" +) + +// SlaveDownloadService 从机文件下載服务 +type SlaveDownloadService struct { + PathEncoded string `uri:"path" binding:"required"` + Name string `uri:"name" binding:"required"` + Speed int `uri:"speed" binding:"min=0"` +} + +// SlaveFileService 从机单文件文件相关服务 +type SlaveFileService struct { + PathEncoded string `uri:"path" binding:"required"` +} + +// SlaveFilesService 从机多文件相关服务 +type SlaveFilesService struct { + Files []string `json:"files" binding:"required,gt=0"` +} + +// SlaveListService 从机列表服务 +type SlaveListService struct { + Path string `json:"path" binding:"required,min=1,max=65535"` + Recursive bool `json:"recursive"` +} + +// ServeFile 通过签名的URL下载从机文件 +func (service *SlaveDownloadService) ServeFile(ctx context.Context, c *gin.Context, isDownload bool) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewAnonymousFileSystem() + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + // 解码文件路径 + fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded) + if err != nil { + return serializer.ParamErr("无法解析的文件地址", err) + } + + // 根据URL里的信息创建一个文件对象和用户对象 + file := model.File{ + Name: service.Name, + SourceName: string(fileSource), + Policy: model.Policy{ + Model: gorm.Model{ID: 1}, + Type: "local", + }, + } + fs.User = &model.User{ + Group: model.Group{SpeedLimit: service.Speed}, + } + fs.FileTarget = []model.File{file} + + // 开始处理下载 + ctx = context.WithValue(ctx, fsctx.GinCtx, c) + rs, err := fs.GetDownloadContent(ctx, 0) + if err != nil { + return serializer.Err(serializer.CodeNotSet, err.Error(), err) + } + defer rs.Close() + + // 设置下载文件名 + if isDownload { + c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"") + } + + // 发送文件 + http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, time.Now(), rs) + + return serializer.Response{ + Code: 0, + } +} + +// Delete 通过签名的URL删除从机文件 +func (service *SlaveFilesService) Delete(ctx context.Context, c *gin.Context) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewAnonymousFileSystem() + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + // 删除文件 + failed, err := fs.Handler.Delete(ctx, service.Files) + if err != nil { + // 将Data字段写为字符串方便主控端解析 + data, _ := json.Marshal(serializer.RemoteDeleteRequest{Files: failed}) + + return serializer.Response{ + Code: serializer.CodeNotFullySuccess, + Data: string(data), + Msg: fmt.Sprintf("有 %d 个文件未能成功删除", len(failed)), + Error: err.Error(), + } + } + return serializer.Response{Code: 0} +} + +// Thumb 通过签名URL获取从机文件缩略图 +func (service *SlaveFileService) Thumb(ctx context.Context, c *gin.Context) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewAnonymousFileSystem() + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + // 解码文件路径 + fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded) + if err != nil { + return serializer.ParamErr("无法解析的文件地址", err) + } + fs.FileTarget = []model.File{{SourceName: string(fileSource), PicInfo: "1,1"}} + + // 获取缩略图 + resp, err := fs.GetThumb(ctx, 0) + if err != nil { + return serializer.Err(serializer.CodeNotSet, "无法获取缩略图", err) + } + + defer resp.Content.Close() + http.ServeContent(c.Writer, c.Request, "thumb.png", time.Now(), resp.Content) + + return serializer.Response{Code: 0} +} + +// CreateTransferTask 创建从机文件转存任务 +func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serializer.Response { + if id, ok := c.Get("MasterSiteID"); ok { + job := &slavetask.TransferTask{ + Req: req, + MasterID: id.(string), + } + + if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { + task.TaskPoll.Submit(job.(task.Job)) + }); err != nil { + return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err) + } + + return serializer.Response{} + } + + return serializer.ParamErr("未知的主机节点ID", nil) +} diff --git a/service/node/fabric.go b/service/node/fabric.go new file mode 100644 index 0000000..79dfb29 --- /dev/null +++ b/service/node/fabric.go @@ -0,0 +1,62 @@ +package node + +import ( + "encoding/gob" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/gin-gonic/gin" +) + +type SlaveNotificationService struct { + Subject string `uri:"subject" binding:"required"` +} + +type OneDriveCredentialService struct { + PolicyID uint `uri:"id" binding:"required"` +} + +func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { + res, err := slave.DefaultController.HandleHeartBeat(req) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err) + } + + return serializer.Response{ + Code: 0, + Data: res, + } +} + +// HandleSlaveNotificationPush 转发从机的消息通知到本机消息队列 +func (s *SlaveNotificationService) HandleSlaveNotificationPush(c *gin.Context) serializer.Response { + var msg mq.Message + dec := gob.NewDecoder(c.Request.Body) + if err := dec.Decode(&msg); err != nil { + return serializer.ParamErr("Cannot parse notification message", err) + } + + mq.GlobalMQ.Publish(s.Subject, msg) + return serializer.Response{} +} + +// Get 获取主机OneDrive策略的AccessToken +func (s *OneDriveCredentialService) Get(c *gin.Context) serializer.Response { + policy, err := model.GetPolicyByID(s.PolicyID) + if err != nil { + return serializer.Err(serializer.CodeNotFound, "Cannot found storage policy", err) + } + + client, err := onedrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize OneDrive client", err) + } + + if err := client.UpdateCredential(c); err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot refresh OneDrive credential", err) + } + + return serializer.Response{Data: client.Credential.AccessToken} +}