Feat(remotearia2): add task

pull/801/head
Cian John 5 years ago
parent a1252c810b
commit 19c770a075

@ -28,6 +28,11 @@ func Init(path string) {
email.Init() email.Init()
crontab.Init() crontab.Init()
InitStatic() InitStatic()
} else {
if conf.SlaveConfig.Aria2 {
model.Init()
aria2.Init(false)
}
} }
auth.Init() auth.Init()
} }

@ -100,6 +100,13 @@ func GetDownloadByGid(gid string, uid uint) (*Download, error) {
return download, result.Error return download, result.Error
} }
// GetDownloadById 根据ID查找下载
func GetDownloadById(id uint) (*Download, error) {
var download Download
result := DB.First(&download, id)
return &download, result.Error
}
// GetOwner 获取下载任务所属用户 // GetOwner 获取下载任务所属用户
func (task *Download) GetOwner() *User { func (task *Download) GetOwner() *User {
if task.User == nil { if task.User == nil {

@ -136,6 +136,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "aria2_temp_path", Value: ``, Type: "aria2"}, {Name: "aria2_temp_path", Value: ``, Type: "aria2"},
{Name: "aria2_options", Value: `{}`, Type: "aria2"}, {Name: "aria2_options", Value: `{}`, Type: "aria2"},
{Name: "aria2_interval", Value: `60`, Type: "aria2"}, {Name: "aria2_interval", Value: `60`, Type: "aria2"},
{Name: "aria2_remote_enabled", Value: `0`, Type: "aria2"},
{Name: "aria2_remote_id", Value: `0`, Type: "aria2"},
{Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "max_worker_num", Value: `10`, Type: "task"},
{Name: "max_parallel_transfer", Value: `4`, Type: "task"}, {Name: "max_parallel_transfer", Value: `4`, Type: "task"},
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},

@ -2,6 +2,7 @@ package aria2
import ( import (
"encoding/json" "encoding/json"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"net/url" "net/url"
"sync" "sync"
@ -89,6 +90,21 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error {
// Init 初始化 // Init 初始化
func Init(isReload bool) { func Init(isReload bool) {
if conf.SystemConfig.Mode == "master" {
MasterInit(isReload)
} else {
SlaveInit(isReload)
}
}
// SlaveInit 从机初始化
func SlaveInit(isReload bool) {
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
return
}
if conf.SlaveConfig.SlaveId == 0 || model.GetIntSetting("aria2_remote_id", 0) != int(conf.SlaveConfig.SlaveId) {
return
}
Lock.Lock() Lock.Lock()
defer Lock.Unlock() defer Lock.Unlock()
@ -136,6 +152,59 @@ func Init(isReload bool) {
Instance = client Instance = client
// monitor
}
// MasterInit 主机初始化
func MasterInit(isReload bool) {
Lock.Lock()
defer Lock.Unlock()
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
// 关闭上个初始连接
if previousClient, ok := Instance.(*RPCService); ok {
if previousClient.Caller != nil {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.Caller.Close()
}
}
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
timeout := model.GetIntSetting("aria2_call_timeout", 5)
if options["aria2_rpcurl"] == "" {
Instance = &DummyAria2{}
return
}
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
client := &RPCService{}
// 解析RPC服务地址
server, err := url.Parse(options["aria2_rpcurl"])
if err != nil {
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
Instance = &DummyAria2{}
return
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
if err != nil {
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
Instance = &DummyAria2{}
return
}
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
Instance = &DummyAria2{}
return
}
Instance = client
if !isReload { if !isReload {
// 从数据库中读取未完成任务,创建监控 // 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
@ -145,7 +214,20 @@ func Init(isReload bool) {
NewMonitor(&unfinished[i]) NewMonitor(&unfinished[i])
} }
} }
} else {
util.Log().Info("初始化 从机 aria2 RPC 服务")
remote, err := model.GetPolicyByID(uint(model.GetIntSetting("aria2_remote_id", 0)))
if err != nil {
util.Log().Warning("初始化 从机 aria2 RPC 服务失败,%s", err)
Instance = &DummyAria2{}
return
}
client := &RemoteService{}
client.Init(&remote)
Instance = client
}
} }
// getStatus 将给定的状态字符串转换为状态标识数字 // getStatus 将给定的状态字符串转换为状态标识数字

@ -111,7 +111,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
// 保存到数据库 // 保存到数据库
task.GID = gid task.GID = gid
_, err = task.Create() err = task.Save()
if err != nil { if err != nil {
return err return err
} }

@ -0,0 +1,103 @@
package aria2
import (
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
"path"
"strings"
)
// RemoteService 通过从机RPC服务的Aria2任务管理器
type RemoteService struct {
Policy *model.Policy
Client request.Client
AuthInstance auth.Auth
}
func (client *RemoteService) Init(policy *model.Policy) {
client.Policy = policy
client.Client = request.HTTPClient{}
client.AuthInstance = auth.HMACAuth{SecretKey: []byte(client.Policy.SecretKey)}
}
func (client *RemoteService) CreateTask(task *model.Download, options map[string]interface{}) error {
reqBody := serializer.RemoteAria2AddRequest{
TaskId: task.ID,
Options: options,
}
reqBodyEncoded, err := json.Marshal(reqBody)
if err != nil {
return err
}
// 发送列表请求
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := client.Client.Request(
"POST",
client.getAPIUrl("add"),
bodyReader,
request.WithCredential(client.AuthInstance, int64(signTTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
// 处理列取结果
if resp.Code != 0 {
return errors.New(resp.Error)
}
if resStr, ok := resp.Data.(string); ok {
var res serializer.Response
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return err
}
if res.Code != 0 {
return errors.New(res.Msg)
}
}
return nil
}
func (client *RemoteService) Status(task *model.Download) (rpc.StatusInfo, error) {
panic("implement me")
}
func (client *RemoteService) Cancel(task *model.Download) error {
panic("implement me")
}
func (client *RemoteService) Select(task *model.Download, files []int) error {
panic("implement me")
}
// getAPIUrl 获取接口请求地址
func (client *RemoteService) getAPIUrl(scope string, routes ...string) string {
serverURL, err := url.Parse(client.Policy.Server)
if err != nil {
return ""
}
var controller *url.URL
switch scope {
case "add":
controller, _ = url.Parse("/api/v3/slave/aria2/add")
default:
controller = serverURL
}
for _, r := range routes {
controller.Path = path.Join(controller.Path, r)
}
return serverURL.ResolveReference(controller).String()
}

@ -42,6 +42,8 @@ type slave struct {
Secret string `validate:"omitempty,gte=64"` Secret string `validate:"omitempty,gte=64"`
CallbackTimeout int `validate:"omitempty,gte=1"` CallbackTimeout int `validate:"omitempty,gte=1"`
SignatureTTL int `validate:"omitempty,gte=1"` SignatureTTL int `validate:"omitempty,gte=1"`
SlaveId uint `validate:"omitempty"`
Aria2 bool `validate:"omitempty"`
} }
// captcha 验证码配置 // captcha 验证码配置

@ -10,3 +10,8 @@ type ListRequest struct {
Path string `json:"path"` Path string `json:"path"`
Recursive bool `json:"recursive"` Recursive bool `json:"recursive"`
} }
type RemoteAria2AddRequest struct {
TaskId uint `json:"task_id"`
Options map[string]interface{} `json:"options"`
}

@ -2,6 +2,7 @@ package controllers
import ( import (
"context" "context"
"github.com/cloudreve/Cloudreve/v3/service/slave"
"net/url" "net/url"
"strconv" "strconv"
@ -175,3 +176,14 @@ func SlaveList(c *gin.Context) {
c.JSON(200, ErrorResponse(err)) c.JSON(200, ErrorResponse(err))
} }
} }
// SlaveAria2Add 从机创建远程下载任务
func SlaveAria2Add(c *gin.Context) {
var service slave.Aria2AddService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Add()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}

@ -50,6 +50,12 @@ func InitSlaveRouter() *gin.Engine {
// 列出文件 // 列出文件
v3.POST("list", controllers.SlaveList) v3.POST("list", controllers.SlaveList)
} }
aria2 := v3.Group("aria2")
aria2.Use(middleware.SignRequired())
{
aria2.POST("add", controllers.SlaveAria2Add)
}
return r return r
} }

@ -1,6 +1,7 @@
package admin package admin
import ( import (
model "github.com/cloudreve/Cloudreve/v3/models"
"net/url" "net/url"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2"
@ -15,6 +16,7 @@ type Aria2TestService struct {
// Test 测试aria2连接 // Test 测试aria2连接
func (service *Aria2TestService) Test() serializer.Response { func (service *Aria2TestService) Test() serializer.Response {
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
testRPC := aria2.RPCService{} testRPC := aria2.RPCService{}
// 解析RPC服务地址 // 解析RPC服务地址
@ -40,4 +42,8 @@ func (service *Aria2TestService) Test() serializer.Response {
} }
return serializer.Response{Data: info.Version} return serializer.Response{Data: info.Version}
} else {
// TODO
return serializer.Response{Data: "TODO"}
}
} }

@ -42,6 +42,11 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
Source: service.URL, Source: service.URL,
} }
_, err = task.Create()
if err != nil {
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
}
aria2.Lock.RLock() aria2.Lock.RLock()
if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil { if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil {
aria2.Lock.RUnlock() aria2.Lock.RUnlock()

@ -0,0 +1,28 @@
package slave
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
type Aria2AddService struct {
TaskId uint `json:"task_id"`
Options map[string]interface{} `json:"options"`
}
func (service *Aria2AddService) Add() serializer.Response {
task, err := model.GetDownloadById(service.TaskId)
if err != nil {
util.Log().Warning("无法获取记录, %s", err)
return serializer.Err(serializer.CodeNotSet, "任务创建失败, 无法获取记录", err)
}
aria2.Lock.RLock()
if err := aria2.Instance.CreateTask(task, service.Options); err != nil {
aria2.Lock.RUnlock()
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
}
aria2.Lock.RUnlock()
return serializer.Response{}
}
Loading…
Cancel
Save