From fa56d81381b5697714af7b44da3e31e3d3fb4f33 Mon Sep 17 00:00:00 2001 From: Cian John Date: Sat, 20 Mar 2021 16:30:48 +0800 Subject: [PATCH 1/5] Feat(remotearia2): add task --- bootstrap/init.go | 5 ++ models/download.go | 7 +++ models/migration.go | 2 + pkg/aria2/aria2.go | 96 +++++++++++++++++++++++++++++--- pkg/aria2/caller.go | 2 +- pkg/aria2/remote_caller.go | 103 +++++++++++++++++++++++++++++++++++ pkg/conf/conf.go | 2 + pkg/serializer/slave.go | 5 ++ routers/controllers/slave.go | 12 ++++ routers/router.go | 6 ++ service/admin/aria2.go | 54 ++++++++++-------- service/aria2/add.go | 5 ++ service/slave/aria2.go | 28 ++++++++++ 13 files changed, 295 insertions(+), 32 deletions(-) create mode 100644 pkg/aria2/remote_caller.go create mode 100644 service/slave/aria2.go diff --git a/bootstrap/init.go b/bootstrap/init.go index 98f2c8da..c71e0d6a 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -28,6 +28,11 @@ func Init(path string) { email.Init() crontab.Init() InitStatic() + } else { + if conf.SlaveConfig.Aria2 { + model.Init() + aria2.Init(false) + } } auth.Init() } diff --git a/models/download.go b/models/download.go index 8b6599d3..d981e385 100644 --- a/models/download.go +++ b/models/download.go @@ -100,6 +100,13 @@ func GetDownloadByGid(gid string, uid uint) (*Download, error) { return download, result.Error } +// GetDownloadById 根据ID查找下载 +func GetDownloadById(id uint) (*Download, error) { + var download Download + result := DB.First(&download, id) + return &download, result.Error +} + // GetOwner 获取下载任务所属用户 func (task *Download) GetOwner() *User { if task.User == nil { diff --git a/models/migration.go b/models/migration.go index 4522c8eb..c8d0e17d 100644 --- a/models/migration.go +++ b/models/migration.go @@ -136,6 +136,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "aria2_temp_path", Value: ``, Type: "aria2"}, {Name: "aria2_options", Value: `{}`, Type: "aria2"}, {Name: "aria2_interval", Value: `60`, Type: "aria2"}, + {Name: "aria2_remote_enabled", Value: `0`, Type: "aria2"}, + {Name: "aria2_remote_id", Value: `0`, Type: "aria2"}, {Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "max_parallel_transfer", Value: `4`, Type: "task"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 40ce36ae..a2513a38 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -2,6 +2,7 @@ package aria2 import ( "encoding/json" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "net/url" "sync" @@ -89,6 +90,21 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error { // Init 初始化 func Init(isReload bool) { + if conf.SystemConfig.Mode == "master" { + MasterInit(isReload) + } else { + SlaveInit(isReload) + } +} + +// SlaveInit 从机初始化 +func SlaveInit(isReload bool) { + if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) { + return + } + if conf.SlaveConfig.SlaveId == 0 || model.GetIntSetting("aria2_remote_id", 0) != int(conf.SlaveConfig.SlaveId) { + return + } Lock.Lock() defer Lock.Unlock() @@ -136,16 +152,82 @@ func Init(isReload bool) { Instance = client - if !isReload { - // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + // monitor +} + +// MasterInit 主机初始化 +func MasterInit(isReload bool) { + Lock.Lock() + defer Lock.Unlock() - for i := 0; i < len(unfinished); i++ { - // 创建任务监控 - NewMonitor(&unfinished[i]) + if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) { + // 关闭上个初始连接 + if previousClient, ok := Instance.(*RPCService); ok { + if previousClient.Caller != nil { + util.Log().Debug("关闭上个 aria2 连接") + previousClient.Caller.Close() + } + } + + options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") + timeout := model.GetIntSetting("aria2_call_timeout", 5) + if options["aria2_rpcurl"] == "" { + Instance = &DummyAria2{} + return } - } + util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"]) + client := &RPCService{} + + // 解析RPC服务地址 + server, err := url.Parse(options["aria2_rpcurl"]) + if err != nil { + util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err) + Instance = &DummyAria2{} + return + } + server.Path = "/jsonrpc" + + // 加载自定义下载配置 + var globalOptions map[string]interface{} + err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions) + if err != nil { + util.Log().Warning("无法解析 aria2 全局配置,%s", err) + Instance = &DummyAria2{} + return + } + + if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil { + util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err) + Instance = &DummyAria2{} + return + } + + Instance = client + + if !isReload { + // 从数据库中读取未完成任务,创建监控 + unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + + for i := 0; i < len(unfinished); i++ { + // 创建任务监控 + NewMonitor(&unfinished[i]) + } + } + } else { + util.Log().Info("初始化 从机 aria2 RPC 服务") + remote, err := model.GetPolicyByID(uint(model.GetIntSetting("aria2_remote_id", 0))) + if err != nil { + util.Log().Warning("初始化 从机 aria2 RPC 服务失败,%s", err) + Instance = &DummyAria2{} + return + } + + client := &RemoteService{} + + client.Init(&remote) + Instance = client + } } // getStatus 将给定的状态字符串转换为状态标识数字 diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 6e287a2a..63151f82 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -111,7 +111,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri // 保存到数据库 task.GID = gid - _, err = task.Create() + err = task.Save() if err != nil { return err } diff --git a/pkg/aria2/remote_caller.go b/pkg/aria2/remote_caller.go new file mode 100644 index 00000000..fd11eff0 --- /dev/null +++ b/pkg/aria2/remote_caller.go @@ -0,0 +1,103 @@ +package aria2 + +import ( + "encoding/json" + "errors" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "net/url" + "path" + "strings" +) + +// RemoteService 通过从机RPC服务的Aria2任务管理器 +type RemoteService struct { + Policy *model.Policy + Client request.Client + AuthInstance auth.Auth +} + +func (client *RemoteService) Init(policy *model.Policy) { + client.Policy = policy + client.Client = request.HTTPClient{} + client.AuthInstance = auth.HMACAuth{SecretKey: []byte(client.Policy.SecretKey)} +} + +func (client *RemoteService) CreateTask(task *model.Download, options map[string]interface{}) error { + reqBody := serializer.RemoteAria2AddRequest{ + TaskId: task.ID, + Options: options, + } + reqBodyEncoded, err := json.Marshal(reqBody) + if err != nil { + return err + } + + // 发送列表请求 + bodyReader := strings.NewReader(string(reqBodyEncoded)) + signTTL := model.GetIntSetting("slave_api_timeout", 60) + resp, err := client.Client.Request( + "POST", + client.getAPIUrl("add"), + bodyReader, + request.WithCredential(client.AuthInstance, int64(signTTL)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return errors.New(resp.Error) + } + + if resStr, ok := resp.Data.(string); ok { + var res serializer.Response + err = json.Unmarshal([]byte(resStr), &res) + if err != nil { + return err + } + if res.Code != 0 { + return errors.New(res.Msg) + } + } + + return nil +} + +func (client *RemoteService) Status(task *model.Download) (rpc.StatusInfo, error) { + panic("implement me") +} + +func (client *RemoteService) Cancel(task *model.Download) error { + panic("implement me") +} + +func (client *RemoteService) Select(task *model.Download, files []int) error { + panic("implement me") +} + +// getAPIUrl 获取接口请求地址 +func (client *RemoteService) getAPIUrl(scope string, routes ...string) string { + serverURL, err := url.Parse(client.Policy.Server) + if err != nil { + return "" + } + var controller *url.URL + + switch scope { + case "add": + controller, _ = url.Parse("/api/v3/slave/aria2/add") + default: + controller = serverURL + } + + for _, r := range routes { + controller.Path = path.Join(controller.Path, r) + } + + return serverURL.ResolveReference(controller).String() +} diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index b9be97c4..da6d3256 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -42,6 +42,8 @@ type slave struct { Secret string `validate:"omitempty,gte=64"` CallbackTimeout int `validate:"omitempty,gte=1"` SignatureTTL int `validate:"omitempty,gte=1"` + SlaveId uint `validate:"omitempty"` + Aria2 bool `validate:"omitempty"` } // captcha 验证码配置 diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index e23e809d..1cf835bb 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -10,3 +10,8 @@ type ListRequest struct { Path string `json:"path"` Recursive bool `json:"recursive"` } + +type RemoteAria2AddRequest struct { + TaskId uint `json:"task_id"` + Options map[string]interface{} `json:"options"` +} diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index e10e2b0c..71652fe1 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -2,6 +2,7 @@ package controllers import ( "context" + "github.com/cloudreve/Cloudreve/v3/service/slave" "net/url" "strconv" @@ -175,3 +176,14 @@ func SlaveList(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveAria2Add 从机创建远程下载任务 +func SlaveAria2Add(c *gin.Context) { + var service slave.Aria2AddService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Add() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 8cf5a9d2..cbffe306 100644 --- a/routers/router.go +++ b/routers/router.go @@ -50,6 +50,12 @@ func InitSlaveRouter() *gin.Engine { // 列出文件 v3.POST("list", controllers.SlaveList) } + + aria2 := v3.Group("aria2") + aria2.Use(middleware.SignRequired()) + { + aria2.POST("add", controllers.SlaveAria2Add) + } return r } diff --git a/service/admin/aria2.go b/service/admin/aria2.go index 8801c962..6c3308d9 100644 --- a/service/admin/aria2.go +++ b/service/admin/aria2.go @@ -1,6 +1,7 @@ package admin import ( + model "github.com/cloudreve/Cloudreve/v3/models" "net/url" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" @@ -15,29 +16,34 @@ type Aria2TestService struct { // Test 测试aria2连接 func (service *Aria2TestService) Test() serializer.Response { - testRPC := aria2.RPCService{} - - // 解析RPC服务地址 - server, err := url.Parse(service.Server) - if err != nil { - return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) - } - server.Path = "/jsonrpc" - - if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { - return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) + if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) { + testRPC := aria2.RPCService{} + + // 解析RPC服务地址 + server, err := url.Parse(service.Server) + if err != nil { + return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) + } + server.Path = "/jsonrpc" + + if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { + return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) + } + + defer testRPC.Caller.Close() + + info, err := testRPC.Caller.GetVersion() + if err != nil { + return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) + } + + if info.Version == "" { + return serializer.ParamErr("RPC 服务返回非预期响应", nil) + } + + return serializer.Response{Data: info.Version} + } else { + // TODO + return serializer.Response{Data: "TODO"} } - - defer testRPC.Caller.Close() - - info, err := testRPC.Caller.GetVersion() - if err != nil { - return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) - } - - if info.Version == "" { - return serializer.ParamErr("RPC 服务返回非预期响应", nil) - } - - return serializer.Response{Data: info.Version} } diff --git a/service/aria2/add.go b/service/aria2/add.go index be7213ac..5b95fd64 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -42,6 +42,11 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo Source: service.URL, } + _, err = task.Create() + if err != nil { + return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) + } + aria2.Lock.RLock() if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil { aria2.Lock.RUnlock() diff --git a/service/slave/aria2.go b/service/slave/aria2.go new file mode 100644 index 00000000..6e222085 --- /dev/null +++ b/service/slave/aria2.go @@ -0,0 +1,28 @@ +package slave + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +type Aria2AddService struct { + TaskId uint `json:"task_id"` + Options map[string]interface{} `json:"options"` +} + +func (service *Aria2AddService) Add() serializer.Response { + task, err := model.GetDownloadById(service.TaskId) + if err != nil { + util.Log().Warning("无法获取记录, %s", err) + return serializer.Err(serializer.CodeNotSet, "任务创建失败, 无法获取记录", err) + } + aria2.Lock.RLock() + if err := aria2.Instance.CreateTask(task, service.Options); err != nil { + aria2.Lock.RUnlock() + return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) + } + aria2.Lock.RUnlock() + return serializer.Response{} +} From 08ebfa60b28754ee01578739ad9007becce2d2e8 Mon Sep 17 00:00:00 2001 From: Cian John Date: Sat, 20 Mar 2021 17:59:06 +0800 Subject: [PATCH 2/5] =?UTF-8?q?Feat:=20=E4=BB=8E=E6=9C=BA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E5=A4=8D=E5=88=B6=E4=B8=8A=E4=BC=A0=E5=88=B0=E8=87=AA?= =?UTF-8?q?=E5=B7=B1=E7=9A=84=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/filesystem/upload.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 2ff8997f..41a302c7 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -2,6 +2,7 @@ package filesystem import ( "context" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "io" "os" "path" @@ -42,6 +43,26 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { } ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath) + if conf.SystemConfig.Mode == "slave" && fs.Policy.Type == "remote" && fs.Policy.ID == conf.SlaveConfig.SlaveId { + fs.Handler = &local.Driver{} + fs.Policy.Type = "remote-local" + // 生成上传策略 + policy := serializer.UploadPolicy{ + SavePath: path.Dir(savePath), + FileName: path.Base(savePath), + AutoRename: false, + MaxSize: file.GetSize(), + } + ctx = context.WithValue(ctx, fsctx.UploadPolicyCtx, policy) + + // 执行上传 + err = fs.Upload(ctx, file) + if err != nil { + return err + } + return nil + } + // 处理客户端未完成上传时,关闭连接 go fs.CancelUpload(ctx, savePath, file) From 194338a89352f8b86c19db5c145aa1a73892af01 Mon Sep 17 00:00:00 2001 From: Cian John Date: Sat, 20 Mar 2021 17:59:55 +0800 Subject: [PATCH 3/5] =?UTF-8?q?Feat(remotearia2):=20=E8=BD=AC=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bootstrap/init.go | 1 + pkg/aria2/aria2.go | 10 +++++++++- pkg/task/tranfer.go | 8 ++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/bootstrap/init.go b/bootstrap/init.go index c71e0d6a..d1002975 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -31,6 +31,7 @@ func Init(path string) { } else { if conf.SlaveConfig.Aria2 { model.Init() + task.Init() aria2.Init(false) } } diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index a2513a38..641c4052 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -152,7 +152,15 @@ func SlaveInit(isReload bool) { Instance = client - // monitor + if !isReload { + // 从数据库中读取未完成任务,创建监控 + unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + + for i := 0; i < len(unfinished); i++ { + // 创建任务监控 + NewMonitor(&unfinished[i]) + } + } } // MasterInit 主机初始化 diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 8cdc2474..4c555df6 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -3,6 +3,7 @@ package task import ( "context" "encoding/json" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "os" "path" "path/filepath" @@ -84,6 +85,13 @@ func (job *TransferTask) GetError() *JobError { // Do 开始执行任务 func (job *TransferTask) Do() { defer job.Recycle() + if model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) && conf.SystemConfig.Mode != "slave" { + return + } + + if conf.SlaveConfig.SlaveId == 0 || model.GetIntSetting("aria2_remote_id", 0) != int(conf.SlaveConfig.SlaveId) { + return + } // 创建文件系统 fs, err := filesystem.NewFileSystem(job.User) From 6d04ef112f26f07b07d25e98a554c7405394459c Mon Sep 17 00:00:00 2001 From: Cian John Date: Sat, 20 Mar 2021 18:32:25 +0800 Subject: [PATCH 4/5] =?UTF-8?q?Feat(remotearia2):=20=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/aria2/remote_caller.go | 38 +++++++++++++++++++++--------------- routers/controllers/slave.go | 11 +++++++++++ routers/router.go | 1 + service/slave/aria2.go | 22 +++++++++++++++++++++ 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/pkg/aria2/remote_caller.go b/pkg/aria2/remote_caller.go index fd11eff0..19b92399 100644 --- a/pkg/aria2/remote_caller.go +++ b/pkg/aria2/remote_caller.go @@ -10,6 +10,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "net/url" "path" + "strconv" "strings" ) @@ -36,7 +37,6 @@ func (client *RemoteService) CreateTask(task *model.Download, options map[string return err } - // 发送列表请求 bodyReader := strings.NewReader(string(reqBodyEncoded)) signTTL := model.GetIntSetting("slave_api_timeout", 60) resp, err := client.Client.Request( @@ -49,31 +49,35 @@ func (client *RemoteService) CreateTask(task *model.Download, options map[string return err } - // 处理列取结果 if resp.Code != 0 { - return errors.New(resp.Error) - } - - if resStr, ok := resp.Data.(string); ok { - var res serializer.Response - err = json.Unmarshal([]byte(resStr), &res) - if err != nil { - return err - } - if res.Code != 0 { - return errors.New(res.Msg) - } + return errors.New(resp.Msg) } return nil } func (client *RemoteService) Status(task *model.Download) (rpc.StatusInfo, error) { - panic("implement me") + // 远程 Aria2 不会使用此方法 + return rpc.StatusInfo{}, nil } func (client *RemoteService) Cancel(task *model.Download) error { - panic("implement me") + signTTL := model.GetIntSetting("slave_api_timeout", 60) + resp, err := client.Client.Request( + "POST", + client.getAPIUrl("cancel", strconv.Itoa(int(task.ID))), + nil, + request.WithCredential(client.AuthInstance, int64(signTTL)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if resp.Code != 0 { + return errors.New(resp.Error) + } + + return nil } func (client *RemoteService) Select(task *model.Download, files []int) error { @@ -91,6 +95,8 @@ func (client *RemoteService) getAPIUrl(scope string, routes ...string) string { switch scope { case "add": controller, _ = url.Parse("/api/v3/slave/aria2/add") + case "cancel": + controller, _ = url.Parse("/api/v3/slave/aria2/cancel") default: controller = serverURL } diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 71652fe1..5cb15940 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -187,3 +187,14 @@ func SlaveAria2Add(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveAria2Cancel 从机删除远程下载任务 +func SlaveAria2Cancel(c *gin.Context) { + var service slave.Aria2CancelService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Cancel() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index cbffe306..c1e32d73 100644 --- a/routers/router.go +++ b/routers/router.go @@ -55,6 +55,7 @@ func InitSlaveRouter() *gin.Engine { aria2.Use(middleware.SignRequired()) { aria2.POST("add", controllers.SlaveAria2Add) + aria2.POST("cancel/:taskId", controllers.SlaveAria2Cancel) } return r } diff --git a/service/slave/aria2.go b/service/slave/aria2.go index 6e222085..145e82fa 100644 --- a/service/slave/aria2.go +++ b/service/slave/aria2.go @@ -12,6 +12,10 @@ type Aria2AddService struct { Options map[string]interface{} `json:"options"` } +type Aria2CancelService struct { + TaskId uint `uri:"taskId"` +} + func (service *Aria2AddService) Add() serializer.Response { task, err := model.GetDownloadById(service.TaskId) if err != nil { @@ -26,3 +30,21 @@ func (service *Aria2AddService) Add() serializer.Response { aria2.Lock.RUnlock() return serializer.Response{} } + +func (service *Aria2CancelService) Cancel() serializer.Response { + task, err := model.GetDownloadById(service.TaskId) + if err != nil { + util.Log().Warning("无法获取记录, %s", err) + return serializer.Err(serializer.CodeNotSet, "任务创建失败, 无法获取记录", err) + } + + // 取消任务 + aria2.Lock.RLock() + defer aria2.Lock.RUnlock() + if err := aria2.Instance.Cancel(task); err != nil { + util.Log().Debug("删除远程下载任务出错, %s", err) + return serializer.Err(serializer.CodeNotSet, "操作失败", err) + } + + return serializer.Response{} +} From e1ec5481dd4ffe53dc75a55d8f3ea8f87fe344e5 Mon Sep 17 00:00:00 2001 From: Cian John Date: Sat, 20 Mar 2021 18:55:23 +0800 Subject: [PATCH 5/5] Feat(remotearia2): tip --- service/admin/aria2.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/service/admin/aria2.go b/service/admin/aria2.go index 6c3308d9..921dc65d 100644 --- a/service/admin/aria2.go +++ b/service/admin/aria2.go @@ -43,7 +43,6 @@ func (service *Aria2TestService) Test() serializer.Response { return serializer.Response{Data: info.Version} } else { - // TODO - return serializer.Response{Data: "TODO"} + return serializer.Response{Data: "从机离线无法进行测试"} } }