From 3ed84ad5ec679609a5b6e55a726231fb668b8641 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Wed, 5 Feb 2020 12:58:26 +0800 Subject: [PATCH] Feat: validate / cancel task while downloading file in aria2 --- models/download.go | 13 +++++ models/task.go | 2 +- pkg/aria2/Monitor.go | 79 ++++++++++++++++++++++++++- pkg/aria2/aria2.go | 10 +++- pkg/aria2/caller.go | 6 ++ pkg/filesystem/driver/onedrive/api.go | 3 +- pkg/filesystem/hooks.go | 10 ++++ 7 files changed, 119 insertions(+), 4 deletions(-) diff --git a/models/download.go b/models/download.go index 462fc34..599ff2c 100644 --- a/models/download.go +++ b/models/download.go @@ -22,6 +22,9 @@ type Download struct { Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 UserID uint // 发起者UID TaskID uint // 对应的转存任务ID + + // 关联模型 + User *User `gorm:"PRELOAD:false,association_autoupdate:false"` } // Create 创建离线下载记录 @@ -48,3 +51,13 @@ func GetDownloadsByStatus(status ...int) []Download { DB.Where("status in (?)", status).Find(&tasks) return tasks } + +// GetOwner 获取下载任务所属用户 +func (task *Download) GetOwner() *User { + if task.User == nil { + if user, err := GetUserByID(task.UserID); err == nil { + return &user + } + } + return task.User +} diff --git a/models/task.go b/models/task.go index 400fc91..1efbada 100644 --- a/models/task.go +++ b/models/task.go @@ -12,7 +12,7 @@ type Task struct { Type int // 任务类型 UserID uint // 发起者UID,0表示为系统发起 Progress int // 进度 - Error string // 错误信息 + Error string `gorm:"type:text"` // 错误信息 Props string `gorm:"type:text"` // 任务属性 } diff --git a/pkg/aria2/Monitor.go b/pkg/aria2/Monitor.go index d3ec66c..7cc4eb0 100644 --- a/pkg/aria2/Monitor.go +++ b/pkg/aria2/Monitor.go @@ -1,9 +1,13 @@ package aria2 import ( + "context" "encoding/json" "errors" model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/filesystem" + "github.com/HFO4/cloudreve/pkg/filesystem/driver/local" + "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/task" "github.com/HFO4/cloudreve/pkg/util" "github.com/zyxar/argo/rpc" @@ -71,9 +75,18 @@ func (monitor *Monitor) Update() bool { return true } + // 磁力链下载需要跟随 + if len(status.FollowedBy) > 0 { + util.Log().Debug("离线下载[%s]重定向至[%s]", monitor.Task.GID, status.FollowedBy[0]) + monitor.Task.GID = status.FollowedBy[0] + monitor.Task.Save() + return false + } + // 更新任务信息 if err := monitor.UpdateTaskInfo(status); err != nil { util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err) + monitor.setErrorStatus(err) return true } @@ -96,6 +109,9 @@ func (monitor *Monitor) Update() bool { // UpdateTaskInfo 更新数据库中的任务信息 func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { + originSize := monitor.Task.TotalSize + originPath := monitor.Task.Path + monitor.Task.GID = status.Gid monitor.Task.Status = getStatus(status.Status) @@ -126,7 +142,68 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { attrs, _ := json.Marshal(status) monitor.Task.Attrs = string(attrs) - return monitor.Task.Save() + if err := monitor.Task.Save(); err != nil { + return nil + } + + if originSize != monitor.Task.TotalSize || originPath != monitor.Task.Path { + // 大小、文件名更新后,对文件限制等进行校验 + if err := monitor.ValidateFile(); err != nil { + // 验证失败时取消任务 + monitor.Cancel() + return err + } + } + + return nil +} + +// Cancel 取消上传并尝试删除临时文件 +func (monitor *Monitor) Cancel() { + if err := Instance.Cancel(monitor.Task); err != nil { + util.Log().Warning("无法取消离线下载任务[%s], %s", monitor.Task.GID, err) + } + util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", monitor.Task.GID) + go func(monitor *Monitor) { + select { + case <-time.After(time.Duration(60) * time.Second): + monitor.RemoveTempFolder() + } + }(monitor) +} + +// ValidateFile 上传过程中校验文件大小、文件名 +func (monitor *Monitor) ValidateFile() error { + // 找到任务创建者 + user := monitor.Task.GetOwner() + if user == nil { + return ErrUserNotFound + } + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(user) + if err != nil { + return err + } + defer fs.Recycle() + + // 创建上下文环境 + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + Size: monitor.Task.TotalSize, + Name: filepath.Base(monitor.Task.Path), + }) + + // 验证文件 + if err := filesystem.HookValidateFile(ctx, fs); err != nil { + return err + } + + // 验证用户容量 + if err := filesystem.HookValidateCapacityWithoutIncrease(ctx, fs); err != nil { + return err + } + + return nil } // Error 任务下载出错处理,返回是否中断监控 diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 98af20a..bef0b7d 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -20,6 +20,8 @@ type Aria2 interface { CreateTask(task *model.Download) error // 返回状态信息 Status(task *model.Download) (rpc.StatusInfo, error) + // 取消任务 + Cancel(task *model.Download) error } const ( @@ -48,7 +50,8 @@ const ( var ( // ErrNotEnabled 功能未开启错误 - ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil) + ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil) + ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil) ) // DummyAria2 未开启Aria2功能时使用的默认处理器 @@ -65,6 +68,11 @@ func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) return rpc.StatusInfo{}, ErrNotEnabled } +// Cancel 返回未开启错误 +func (instance *DummyAria2) Cancel(task *model.Download) error { + return ErrNotEnabled +} + // Init 初始化 func Init() { options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 57fe809..4a7a19d 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -40,6 +40,12 @@ func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) { return client.caller.TellStatus(task.GID) } +// Cancel 取消下载 +func (client *RPCService) Cancel(task *model.Download) error { + _, err := client.caller.Remove(task.GID) + return err +} + // CreateTask 创建新任务 func (client *RPCService) CreateTask(task *model.Download) error { // 生成存储路径 diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 739054e..d8e2122 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -153,7 +153,7 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, chunk * // 如果重试次数小于限制,5秒后重试 if chunk.Retried < model.GetIntSetting("onedrive_chunk_retries", 1) { chunk.Retried++ - util.Log().Debug("分片偏移%d上传失败,5秒钟后重试", chunk.Offset) + util.Log().Debug("分片偏移%d上传失败[%s],5秒钟后重试", chunk.Offset, err) time.Sleep(time.Duration(5) * time.Second) return client.UploadChunk(ctx, uploadURL, chunk) } @@ -518,6 +518,7 @@ func (client *Client) request(ctx context.Context, method string, url string, bo if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 { decodeErr = json.Unmarshal([]byte(respBody), &errResp) if decodeErr != nil { + util.Log().Debug("Onedrive返回未知响应[%s]", respBody) return "", sysError(decodeErr) } return "", &errResp diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index 725d689..20d3cd4 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -128,6 +128,16 @@ func HookValidateCapacity(ctx context.Context, fs *FileSystem) error { return nil } +// HookValidateCapacityWithoutIncrease 验证用户容量,不扣除 +func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem) error { + file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) + // 验证并扣除容量 + if fs.User.GetRemainingCapacity() < file.GetSize() { + return ErrInsufficientCapacity + } + return nil +} + // HookChangeCapacity 根据原有文件和新文件的大小更新用户容量 func HookChangeCapacity(ctx context.Context, fs *FileSystem) error { newFile := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)