From 491e4de9deb24c7d204c4a9fd09f9755cc185501 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Wed, 5 Feb 2020 15:11:34 +0800 Subject: [PATCH] Feat: download torrent / multiple file / select file --- go.mod | 1 + go.sum | 2 + models/download.go | 32 +++++++++++++-- pkg/aria2/Monitor.go | 56 ++++++++++++-------------- pkg/aria2/aria2.go | 7 ++++ pkg/aria2/caller.go | 18 ++++++++- pkg/filesystem/driver/onedrive/api.go | 1 + pkg/task/tranfer.go | 30 +++++++------- routers/controllers/aria2.go | 57 ++++++++++++++++++++++++++- routers/router.go | 5 +++ service/aria2/add.go | 6 +-- service/aria2/manage.go | 37 +++++++++++++++++ 12 files changed, 197 insertions(+), 55 deletions(-) create mode 100644 service/aria2/manage.go diff --git a/go.mod b/go.mod index 83f415c..50b546f 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/go-ini/ini v1.50.0 github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-querystring v1.0.0 + github.com/gorilla/websocket v1.4.1 github.com/jinzhu/gorm v1.9.11 github.com/juju/ratelimit v1.0.1 github.com/mattn/go-colorable v0.1.4 // indirect diff --git a/go.sum b/go.sum index 58e655b..1cd8134 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,8 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jinzhu/gorm v1.9.11 h1:gaHGvE+UnWGlbWG4Y3FUwY1EcZ5n6S9WtqBA/uySMLE= diff --git a/models/download.go b/models/download.go index 599ff2c..4cd1597 100644 --- a/models/download.go +++ b/models/download.go @@ -1,8 +1,10 @@ package model import ( + "encoding/json" "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" + "github.com/zyxar/argo/rpc" ) // Download 离线下载队列模型 @@ -10,12 +12,11 @@ type Download struct { gorm.Model Status int // 任务状态 Type int // 任务类型 - Source string // 文件下载地址 + Source string `gorm:"type:text"` // 文件下载地址 TotalSize uint64 // 文件大小 DownloadedSize uint64 // 文件大小 - GID string // 任务ID + GID string `gorm:"size:32,index:gid"` // 任务ID Speed int // 下载速度 - Path string `gorm:"type:text"` // 存储路径 Parent string `gorm:"type:text"` // 存储目录 Attrs string `gorm:"type:text"` // 任务状态属性 Error string `gorm:"type:text"` // 错误描述 @@ -25,6 +26,24 @@ type Download struct { // 关联模型 User *User `gorm:"PRELOAD:false,association_autoupdate:false"` + + // 数据库忽略字段 + StatusInfo rpc.StatusInfo `gorm:"-"` +} + +// AfterFind 找到下载任务后的钩子,处理Status结构 +func (task *Download) AfterFind() (err error) { + // 解析状态 + if task.Attrs != "" { + err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo) + } + + return err +} + +// BeforeSave Save下载任务前的钩子 +func (task *Download) BeforeSave() (err error) { + return task.AfterFind() } // Create 创建离线下载记录 @@ -52,6 +71,13 @@ func GetDownloadsByStatus(status ...int) []Download { return tasks } +// GetDownloadByGid 根据GID和用户ID查找下载 +func GetDownloadByGid(gid string, uid uint) (*Download, error) { + download := &Download{} + result := DB.Where("user_id = ? and g_id = ?", uid, gid).Find(download) + return download, result.Error +} + // GetOwner 获取下载任务所属用户 func (task *Download) GetOwner() *User { if task.User == nil { diff --git a/pkg/aria2/Monitor.go b/pkg/aria2/Monitor.go index 7cc4eb0..765c094 100644 --- a/pkg/aria2/Monitor.go +++ b/pkg/aria2/Monitor.go @@ -12,7 +12,6 @@ import ( "github.com/HFO4/cloudreve/pkg/util" "github.com/zyxar/argo/rpc" "os" - "path" "path/filepath" "strconv" "time" @@ -110,7 +109,6 @@ func (monitor *Monitor) Update() bool { // UpdateTaskInfo 更新数据库中的任务信息 func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { originSize := monitor.Task.TotalSize - originPath := monitor.Task.Path monitor.Task.GID = status.Gid monitor.Task.Status = getStatus(status.Status) @@ -136,9 +134,6 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { } monitor.Task.Speed = speed - if len(status.Files) > 0 { - monitor.Task.Path = status.Files[0].Path - } attrs, _ := json.Marshal(status) monitor.Task.Attrs = string(attrs) @@ -146,8 +141,8 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { return nil } - if originSize != monitor.Task.TotalSize || originPath != monitor.Task.Path { - // 大小、文件名更新后,对文件限制等进行校验 + if originSize != monitor.Task.TotalSize { + // 文件大小更新后,对文件限制等进行校验 if err := monitor.ValidateFile(); err != nil { // 验证失败时取消任务 monitor.Cancel() @@ -190,19 +185,29 @@ func (monitor *Monitor) ValidateFile() error { // 创建上下文环境 ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ Size: monitor.Task.TotalSize, - Name: filepath.Base(monitor.Task.Path), }) - // 验证文件 - if err := filesystem.HookValidateFile(ctx, fs); err != nil { - return err - } - // 验证用户容量 if err := filesystem.HookValidateCapacityWithoutIncrease(ctx, fs); err != nil { return err } + // 验证每个文件 + for _, fileInfo := range monitor.Task.StatusInfo.Files { + if fileInfo.Selected == "true" { + // 创建上下文环境 + fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64) + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + Size: fileSize, + Name: filepath.Base(fileInfo.Path), + }) + if err := filesystem.HookValidateFile(ctx, fs); err != nil { + return err + } + } + + } + return nil } @@ -216,21 +221,6 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool { return true } -// RemoveTempFile 清理下载临时文件 -func (monitor *Monitor) RemoveTempFile() { - err := os.Remove(monitor.Task.Path) - if err != nil { - util.Log().Warning("无法删除离线下载临时文件[%s], %s", monitor.Task.Path, err) - } - - if empty, _ := util.IsEmpty(monitor.Task.Parent); empty { - err := os.Remove(monitor.Task.Parent) - if err != nil { - util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err) - } - } -} - // RemoveTempFolder 清理下载临时目录 func (monitor *Monitor) RemoveTempFolder() { err := os.RemoveAll(monitor.Task.Parent) @@ -243,10 +233,16 @@ func (monitor *Monitor) RemoveTempFolder() { // Complete 完成下载,返回是否中断监控 func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { // 创建中转任务 + file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) + for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ { + if monitor.Task.StatusInfo.Files[i].Selected == "true" { + file = append(file, monitor.Task.StatusInfo.Files[i].Path) + } + } job, err := task.NewTransferTask( monitor.Task.UserID, - path.Join(monitor.Task.Dst, filepath.Base(monitor.Task.Path)), - monitor.Task.Path, + file, + monitor.Task.Dst, monitor.Task.Parent, ) if err != nil { diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index bef0b7d..07c0e2d 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -22,6 +22,8 @@ type Aria2 interface { Status(task *model.Download) (rpc.StatusInfo, error) // 取消任务 Cancel(task *model.Download) error + // 选择要下载的文件 + Select(task *model.Download, files []int) error } const ( @@ -73,6 +75,11 @@ func (instance *DummyAria2) Cancel(task *model.Download) error { return ErrNotEnabled } +// Select 返回未开启错误 +func (instance *DummyAria2) Select(task *model.Download, files []int) error { + return ErrNotEnabled +} + // Init 初始化 func Init() { options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 4a7a19d..2ec6396 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -3,9 +3,11 @@ package aria2 import ( "context" model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/util" "github.com/zyxar/argo/rpc" "path/filepath" "strconv" + "strings" "time" ) @@ -46,20 +48,32 @@ func (client *RPCService) Cancel(task *model.Download) error { return err } +// Select 选取要下载的文件 +func (client *RPCService) Select(task *model.Download, files []int) error { + var selected = make([]string, len(files)) + for i := 0; i < len(files); i++ { + selected[i] = strconv.Itoa(files[i]) + } + ok, err := client.caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) + util.Log().Debug(ok) + return err +} + // CreateTask 创建新任务 func (client *RPCService) CreateTask(task *model.Download) error { // 生成存储路径 - task.Path = filepath.Join( + path := filepath.Join( model.GetSettingByName("aria2_temp_path"), "aria2", strconv.FormatInt(time.Now().UnixNano(), 10), ) // 创建下载任务 - options := []interface{}{map[string]string{"dir": task.Path}} + options := []interface{}{map[string]string{"dir": path}} if len(client.options.Options) > 0 { options = append(options, client.options.Options) } + gid, err := client.caller.AddURI(task.Source, options...) if err != nil || gid == "" { return err diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index d8e2122..7498b10 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -501,6 +501,7 @@ func (client *Client) request(ctx context.Context, method string, url string, bo ) if res.Err != nil { + // TODO 重试 return "", sysError(res.Err) } diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index fa52416..4f90c97 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -7,6 +7,8 @@ import ( "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/util" "os" + "path" + "path/filepath" ) // TransferTask 文件中转任务 @@ -21,9 +23,9 @@ type TransferTask struct { // TransferProps 中转任务属性 type TransferProps struct { - Src string `json:"src"` // 原始目录 - Parent string `json:"parent"` // 父目录 - Dst string `json:"dst"` // 目的目录ID + Src []string `json:"src"` // 原始目录 + Parent string `json:"parent"` // 父目录 + Dst string `json:"dst"` // 目的目录ID } // Props 获取任务属性 @@ -86,30 +88,26 @@ func (job *TransferTask) Do() { } defer fs.Recycle() - err = fs.UploadFromPath(context.Background(), job.TaskProps.Src, job.TaskProps.Dst) - if err != nil { - job.SetErrorMsg("文件转存失败", err) - return + for _, file := range job.TaskProps.Src { + err = fs.UploadFromPath(context.Background(), file, path.Join(job.TaskProps.Dst, filepath.Base(file))) + if err != nil { + job.SetErrorMsg("文件转存失败", err) + } } + } // Recycle 回收临时文件 func (job *TransferTask) Recycle() { - err := os.Remove(job.TaskProps.Src) + err := os.RemoveAll(job.TaskProps.Parent) if err != nil { - util.Log().Warning("无法删除中转临时文件[%s], %s", job.TaskProps.Src, err) + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) } - if empty, _ := util.IsEmpty(job.TaskProps.Parent); empty { - err := os.Remove(job.TaskProps.Parent) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) - } - } } // NewTransferTask 新建中转任务 -func NewTransferTask(user uint, dst, src, parent string) (Job, error) { +func NewTransferTask(user uint, src []string, dst, parent string) (Job, error) { creator, err := model.GetUserByID(user) if err != nil { return nil, err diff --git a/routers/controllers/aria2.go b/routers/controllers/aria2.go index 6b25877..3680134 100644 --- a/routers/controllers/aria2.go +++ b/routers/controllers/aria2.go @@ -1,17 +1,72 @@ package controllers import ( + "context" + ariaCall "github.com/HFO4/cloudreve/pkg/aria2" + "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/service/aria2" + "github.com/HFO4/cloudreve/service/explorer" "github.com/gin-gonic/gin" + "strings" ) // AddAria2URL 添加离线下载URL func AddAria2URL(c *gin.Context) { var addService aria2.AddURLService if err := c.ShouldBindJSON(&addService); err == nil { - res := addService.Add(c) + res := addService.Add(c, ariaCall.URLTask) c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) } } + +// SelectAria2File 选择多文件离线下载中要下载的文件 +func SelectAria2File(c *gin.Context) { + var selectService aria2.SelectFileService + if err := c.ShouldBindJSON(&selectService); err == nil { + res := selectService.Select(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AddAria2Torrent 添加离线下载种子 +func AddAria2Torrent(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var service explorer.SingleFileService + if err := c.ShouldBindUri(&service); err == nil { + // 验证必须是种子文件 + filePath := c.Param("path") + if !strings.HasSuffix(filePath, ".torrent") { + c.JSON(200, serializer.ParamErr("只能下载 .torrent 文件", nil)) + return + } + + // 获取种子内容的下载地址 + res := service.CreateDownloadSession(ctx, c) + if res.Code != 0 { + c.JSON(200, res) + return + } + + // 创建下载任务 + var addService aria2.AddURLService + addService.URL = res.Data.(string) + + if err := c.ShouldBindJSON(&addService); err == nil { + addService.URL = res.Data.(string) + res := addService.Add(c, ariaCall.URLTask) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } + + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index b2c9e49..c317c57 100644 --- a/routers/router.go +++ b/routers/router.go @@ -276,7 +276,12 @@ func InitMasterRouter() *gin.Engine { // 离线下载任务 aria2 := auth.Group("aria2") { + // 创建URL下载任务 aria2.POST("url", controllers.AddAria2URL) + // 创建种子下载任务 + aria2.POST("torrent/*path", controllers.AddAria2Torrent) + // 重新选择要下载的文件 + aria2.PUT("select/:gid", controllers.SelectAria2File) } // 目录 diff --git a/service/aria2/add.go b/service/aria2/add.go index 763b9b2..9edbd29 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -11,11 +11,11 @@ import ( // AddURLService 添加URL离线下载服务 type AddURLService struct { URL string `json:"url" binding:"required"` - Dst string `json:"dst" binding:"required,min=1,max=65535"` + Dst string `json:"dst" binding:"required,min=1"` } // Add 创建新的链接离线下载任务 -func (service *AddURLService) Add(c *gin.Context) serializer.Response { +func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response { // 创建文件系统 fs, err := filesystem.NewFileSystemFromContext(c) if err != nil { @@ -36,7 +36,7 @@ func (service *AddURLService) Add(c *gin.Context) serializer.Response { // 创建任务 task := &model.Download{ Status: aria2.Ready, - Type: aria2.URLTask, + Type: taskType, Dst: service.Dst, UserID: fs.User.ID, Source: service.URL, diff --git a/service/aria2/manage.go b/service/aria2/manage.go new file mode 100644 index 0000000..a965753 --- /dev/null +++ b/service/aria2/manage.go @@ -0,0 +1,37 @@ +package aria2 + +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/aria2" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/gin-gonic/gin" +) + +// SelectFileService 选择要下载的文件服务 +type SelectFileService struct { + Indexes []int `json:"indexes" binding:"required"` +} + +// Select 选取要下载的文件 +func (service *SelectFileService) Select(c *gin.Context) serializer.Response { + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) + + // 查找下载记录 + download, err := model.GetDownloadByGid(c.Param("gid"), user.ID) + if err != nil { + return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) + } + + if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != aria2.Downloading && download.Status != aria2.Paused) { + return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err) + } + + // 选取下载 + if err := aria2.Instance.Select(download, service.Indexes); err != nil { + return serializer.Err(serializer.CodeNotSet, "操作失败", err) + } + + return serializer.Response{} + +}