diff --git a/models/node.go b/models/node.go index 0c49aad1..fde6d3ea 100644 --- a/models/node.go +++ b/models/node.go @@ -31,7 +31,9 @@ type Aria2Option struct { // 附加下载配置 Options string `json:"options,omitempty"` // 下载监控间隔 - Interval string `json:"interval,omitempty"` + Interval int `json:"interval,omitempty"` + // RPC API 请求超时 + Timeout int `json:"timeout,omitempty"` } type NodeStatus int diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 83131b68..9aef4953 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -5,12 +5,16 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) // Instance 默认使用的Aria2处理实例 var Instance Aria2 = &DummyAria2{} +// LB 获取 Aria2 节点的负载均衡器 +var LB balancer.Balancer + // Lock Instance的读写锁 var Lock sync.RWMutex @@ -92,6 +96,10 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error { // Init 初始化 func Init(isReload bool) { + Lock.Lock() + LB = balancer.NewBalancer("RoundRobin") + Lock.Unlock() + if !isReload { // 从数据库中读取未完成任务,创建监控 unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) @@ -101,7 +109,6 @@ func Init(isReload bool) { NewMonitor(&unfinished[i]) } } - } // getStatus 将给定的状态字符串转换为状态标识数字 diff --git a/pkg/balancer/balancer.go b/pkg/balancer/balancer.go new file mode 100644 index 00000000..5d5c028c --- /dev/null +++ b/pkg/balancer/balancer.go @@ -0,0 +1,15 @@ +package balancer + +type Balancer interface { + NextPeer(nodes interface{}) (error, interface{}) +} + +// NewBalancer 根据策略标识返回新的负载均衡器 +func NewBalancer(strategy string) Balancer { + switch strategy { + case "RoundRobin": + return &RoundRobin{} + default: + return &RoundRobin{} + } +} diff --git a/pkg/balancer/errors.go b/pkg/balancer/errors.go new file mode 100644 index 00000000..5285478c --- /dev/null +++ b/pkg/balancer/errors.go @@ -0,0 +1,7 @@ +package balancer + +import "errors" + +var ( + ErrInputNotSlice = errors.New("Input value is not silice") +) diff --git a/pkg/balancer/roundrobin.go b/pkg/balancer/roundrobin.go new file mode 100644 index 00000000..26f4ccc6 --- /dev/null +++ b/pkg/balancer/roundrobin.go @@ -0,0 +1,26 @@ +package balancer + +import ( + "reflect" + "sync/atomic" +) + +type RoundRobin struct { + current uint64 +} + +// NextPeer 返回轮盘的下一节点 +func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) { + v := reflect.ValueOf(nodes) + if v.Kind() != reflect.Slice { + return ErrInputNotSlice, nil + } + + next := r.NextIndex(v.Len()) + return nil, v.Index(next).Interface() +} + +// NextIndex 返回下一个节点下标 +func (r *RoundRobin) NextIndex(total int) int { + return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total)) +} diff --git a/pkg/cluster/errors.go b/pkg/cluster/errors.go new file mode 100644 index 00000000..0a19f6ef --- /dev/null +++ b/pkg/cluster/errors.go @@ -0,0 +1,7 @@ +package cluster + +import "errors" + +var ( + ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed") +) diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index 6b2cf080..0b6ce3fc 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -3,6 +3,7 @@ package cluster import ( "context" "encoding/json" + "fmt" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" @@ -54,7 +55,12 @@ func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingR // IsFeatureEnabled 查询节点的某项功能是否启用 func (node *MasterNode) IsFeatureEnabled(feature string) bool { + node.lock.RLock() + defer node.lock.RUnlock() + switch feature { + case "aria2": + return node.Model.Aria2Enabled default: return false } @@ -70,18 +76,18 @@ func (node *MasterNode) IsActive() bool { } // GetAria2Instance 获取主机Aria2实例 -func (node *MasterNode) GetAria2Instance() (aria2.Aria2, error) { +func (node *MasterNode) GetAria2Instance() aria2.Aria2 { if !node.Model.Aria2Enabled { - return &aria2.DummyAria2{}, nil + return &aria2.DummyAria2{} } node.lock.RLock() defer node.lock.RUnlock() if !node.aria2RPC.Initialized { - return &aria2.DummyAria2{}, nil + return &aria2.DummyAria2{} } - return &node.aria2RPC, nil + return &node.aria2RPC } func (r *rpcService) Init() error { @@ -104,16 +110,18 @@ func (r *rpcService) Init() error { // 加载自定义下载配置 var globalOptions map[string]interface{} - err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions) - if err != nil { - util.Log().Warning("无法解析主机 Aria2 配置,%s", err) - return err + if r.parent.Model.Aria2OptionsSerialized.Options != "" { + err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions) + if err != nil { + util.Log().Warning("无法解析主机 Aria2 配置,%s", err) + return err + } } r.options = &clientOptions{ Options: globalOptions, } - timeout := model.GetIntSetting("aria2_call_timeout", 5) + timeout := r.parent.Model.Aria2OptionsSerialized.Timeout caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, aria2.EventNotifier) r.Caller = caller @@ -122,7 +130,7 @@ func (r *rpcService) Init() error { } func (r *rpcService) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { - panic("implement me") + return "", fmt.Errorf("some error #%d", r.parent.Model.ID) } func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { diff --git a/pkg/cluster/node.go b/pkg/cluster/node.go index 09256239..041c11ae 100644 --- a/pkg/cluster/node.go +++ b/pkg/cluster/node.go @@ -12,7 +12,7 @@ type Node interface { SubscribeStatusChange(callback func(isActive bool, id uint)) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) IsActive() bool - GetAria2Instance() (aria2.Aria2, error) + GetAria2Instance() aria2.Aria2 } func getNodeFromDBModel(node *model.Node) Node { diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index 09b04ee8..de364de0 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -2,6 +2,7 @@ package cluster import ( model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "sync" ) @@ -9,11 +10,11 @@ import ( var Default *NodePool // 需要分类的节点组 -var featureGroup = []string{"Aria2"} +var featureGroup = []string{"aria2"} // Pool 节点池 type Pool interface { - Select() + BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) } // NodePool 通用节点池 @@ -26,10 +27,6 @@ type NodePool struct { lock sync.RWMutex } -func (pool *NodePool) Select() { - -} - // Init 初始化从机节点池 func Init() { Default = &NodePool{ @@ -100,3 +97,15 @@ func (pool *NodePool) initFromDB() error { pool.buildIndexMap() return nil } + +// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点 +func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) { + pool.lock.RLock() + defer pool.lock.RUnlock() + if nodes, ok := pool.featureMap[feature]; ok { + err, res := lb.NextPeer(nodes) + return err, res.(Node) + } + + return ErrFeatureNotExist, nil +} diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 347d34c3..247942a1 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -185,6 +185,6 @@ loop: } // GetAria2Instance 获取从机Aria2实例 -func (node *SlaveNode) GetAria2Instance() (aria2.Aria2, error) { - return nil, nil +func (node *SlaveNode) GetAria2Instance() aria2.Aria2 { + return nil } diff --git a/service/aria2/add.go b/service/aria2/add.go index 9e0726d0..07865bfa 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -3,6 +3,7 @@ package aria2 import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" @@ -42,10 +43,20 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo Source: service.URL, } + // 获取 Aria2 负载均衡器 aria2.Lock.RLock() - gid, err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options) + lb := aria2.LB + aria2.Lock.RUnlock() + + // 获取 Aria2 实例 + err, node := cluster.Default.BalanceNodeByFeature("aria2", lb) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Aria2 实例获取失败", err) + } + + // 创建任务 + gid, err := node.GetAria2Instance().CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options) if err != nil { - aria2.Lock.RUnlock() return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) }