From 8c7e3883eed0b1487b14ec8a58ba6d04025cbefb Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Wed, 5 Feb 2020 11:22:19 +0800 Subject: [PATCH] Feat: handle aria2 download complete --- models/download.go | 41 +++++--- models/migration.go | 3 +- pkg/aria2/Monitor.go | 194 ++++++++++++++++++++++++++++++++++++++ pkg/aria2/aria2.go | 50 ++++++++++ pkg/aria2/caller.go | 21 ++++- pkg/aria2/notification.go | 63 +++++++++++++ pkg/task/job.go | 4 + pkg/task/tranfer.go | 153 ++++++++++++++++++++++++++++++ pkg/util/io.go | 16 ++++ service/aria2/add.go | 16 ++-- 10 files changed, 536 insertions(+), 25 deletions(-) create mode 100644 pkg/aria2/Monitor.go create mode 100644 pkg/aria2/notification.go create mode 100644 pkg/task/tranfer.go diff --git a/models/download.go b/models/download.go index 89a62cd..462fc34 100644 --- a/models/download.go +++ b/models/download.go @@ -8,17 +8,20 @@ import ( // Download 离线下载队列模型 type Download struct { gorm.Model - Status int // 任务状态 - Type int // 任务类型 - Source string // 文件下载地址 - Name string // 任务文件名 - Size uint64 // 文件大小 - GID string // 任务ID - Path string `gorm:"type:text"` // 存储路径 - Attrs string `gorm:"type:text"` // 任务状态属性 - FolderID uint // 存储父目录ID - UserID uint // 发起者UID - TaskID uint // 对应的转存任务ID + Status int // 任务状态 + Type int // 任务类型 + Source string // 文件下载地址 + TotalSize uint64 // 文件大小 + DownloadedSize uint64 // 文件大小 + GID string // 任务ID + Speed int // 下载速度 + Path string `gorm:"type:text"` // 存储路径 + Parent string `gorm:"type:text"` // 存储目录 + Attrs string `gorm:"type:text"` // 任务状态属性 + Error string `gorm:"type:text"` // 错误描述 + Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 + UserID uint // 发起者UID + TaskID uint // 对应的转存任务ID } // Create 创建离线下载记录 @@ -29,3 +32,19 @@ func (task *Download) Create() (uint, error) { } return task.ID, nil } + +// Save 更新 +func (task *Download) Save() error { + if err := DB.Save(task).Error; err != nil { + util.Log().Warning("无法更新离线下载记录, %s", err) + return err + } + return nil +} + +// GetDownloadsByStatus 根据状态检索下载 +func GetDownloadsByStatus(status ...int) []Download { + var tasks []Download + DB.Where("status in (?)", status).Find(&tasks) + return tasks +} diff --git a/models/migration.go b/models/migration.go index af6d4a1..cb95bde 100644 --- a/models/migration.go +++ b/models/migration.go @@ -159,8 +159,9 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "header", Value: `X-Sendfile`, Type: "download"}, {Name: "aria2_token", Value: `your token`, Type: "aria2"}, {Name: "aria2_token", Value: `your token`, Type: "aria2"}, - {Name: "aria2_temp_path", Value: `F:\aria2-1.33.1-win-64bit-build1\temp`, Type: "aria2"}, + {Name: "aria2_temp_path", Value: ``, Type: "aria2"}, {Name: "aria2_options", Value: `{"max-tries":5}`, Type: "aria2"}, + {Name: "aria2_interval", Value: `10`, Type: "aria2"}, {Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "max_parallel_transfer", Value: `4`, Type: "task"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, diff --git a/pkg/aria2/Monitor.go b/pkg/aria2/Monitor.go new file mode 100644 index 0000000..d3ec66c --- /dev/null +++ b/pkg/aria2/Monitor.go @@ -0,0 +1,194 @@ +package aria2 + +import ( + "encoding/json" + "errors" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/task" + "github.com/HFO4/cloudreve/pkg/util" + "github.com/zyxar/argo/rpc" + "os" + "path" + "path/filepath" + "strconv" + "time" +) + +// Monitor 离线下载状态监控 +type Monitor struct { + Task *model.Download + Interval time.Duration + + notifier chan StatusEvent +} + +// StatusEvent 状态改变事件 +type StatusEvent struct { + GID string + Status int +} + +// NewMonitor 新建上传状态监控 +func NewMonitor(task *model.Download) { + monitor := &Monitor{ + Task: task, + Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second, + notifier: make(chan StatusEvent), + } + go monitor.Loop() + EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID) +} + +// Loop 开启监控循环 +func (monitor *Monitor) Loop() { + defer EventNotifier.Unsubscribe(monitor.Task.GID) + + // 首次循环立即更新 + interval := time.Duration(0) + + for { + select { + case <-monitor.notifier: + if monitor.Update() { + return + } + case <-time.After(interval): + interval = monitor.Interval + if monitor.Update() { + return + } + } + } +} + +// Update 更新状态,返回值表示是否退出监控 +func (monitor *Monitor) Update() bool { + status, err := Instance.Status(monitor.Task) + if err != nil { + util.Log().Warning("无法获取下载任务[%s]的状态,%s", monitor.Task.GID, err) + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 更新任务信息 + if err := monitor.UpdateTaskInfo(status); err != nil { + util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err) + return true + } + + util.Log().Debug(status.Status) + + switch status.Status { + case "complete": + return monitor.Complete(status) + case "error": + return monitor.Error(status) + case "active", "waiting", "paused": + return false + case "removed": + return true + default: + util.Log().Warning("下载任务[%s]返回未知状态信息[%s],", monitor.Task.GID, status.Status) + return true + } +} + +// UpdateTaskInfo 更新数据库中的任务信息 +func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { + monitor.Task.GID = status.Gid + monitor.Task.Status = getStatus(status.Status) + + // 文件大小、已下载大小 + total, err := strconv.ParseUint(status.TotalLength, 10, 64) + if err != nil { + total = 0 + } + downloaded, err := strconv.ParseUint(status.CompletedLength, 10, 64) + if err != nil { + downloaded = 0 + } + monitor.Task.TotalSize = total + monitor.Task.DownloadedSize = downloaded + monitor.Task.GID = status.Gid + monitor.Task.Parent = status.Dir + + // 下载速度 + speed, err := strconv.Atoi(status.DownloadSpeed) + if err != nil { + speed = 0 + } + + monitor.Task.Speed = speed + if len(status.Files) > 0 { + monitor.Task.Path = status.Files[0].Path + } + attrs, _ := json.Marshal(status) + monitor.Task.Attrs = string(attrs) + + return monitor.Task.Save() +} + +// Error 任务下载出错处理,返回是否中断监控 +func (monitor *Monitor) Error(status rpc.StatusInfo) bool { + monitor.setErrorStatus(errors.New(status.ErrorMessage)) + + // 清理临时文件 + monitor.RemoveTempFolder() + + return true +} + +// RemoveTempFile 清理下载临时文件 +func (monitor *Monitor) RemoveTempFile() { + err := os.Remove(monitor.Task.Path) + if err != nil { + util.Log().Warning("无法删除离线下载临时文件[%s], %s", monitor.Task.Path, err) + } + + if empty, _ := util.IsEmpty(monitor.Task.Parent); empty { + err := os.Remove(monitor.Task.Parent) + if err != nil { + util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err) + } + } +} + +// RemoveTempFolder 清理下载临时目录 +func (monitor *Monitor) RemoveTempFolder() { + err := os.RemoveAll(monitor.Task.Parent) + if err != nil { + util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err) + } + +} + +// Complete 完成下载,返回是否中断监控 +func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { + // 创建中转任务 + job, err := task.NewTransferTask( + monitor.Task.UserID, + path.Join(monitor.Task.Dst, filepath.Base(monitor.Task.Path)), + monitor.Task.Path, + monitor.Task.Parent, + ) + if err != nil { + monitor.setErrorStatus(err) + return true + } + + // 提交中转任务 + task.TaskPoll.Submit(job) + + // 更新任务ID + monitor.Task.TaskID = job.Model().ID + monitor.Task.Save() + + return true +} + +func (monitor *Monitor) setErrorStatus(err error) { + monitor.Task.Status = Error + monitor.Task.Error = err.Error() + monitor.Task.Save() +} diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 9372181..98af20a 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -4,16 +4,22 @@ import ( model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" + "github.com/zyxar/argo/rpc" "net/url" ) // Instance 默认使用的Aria2处理实例 var Instance Aria2 = &DummyAria2{} +// EventNotifier 任务状态更新通知处理器 +var EventNotifier = &Notifier{} + // Aria2 离线下载处理接口 type Aria2 interface { // CreateTask 创建新的任务 CreateTask(task *model.Download) error + // 返回状态信息 + Status(task *model.Download) (rpc.StatusInfo, error) } const ( @@ -26,6 +32,18 @@ const ( const ( // Ready 准备就绪 Ready = iota + // Downloading 下载中 + Downloading + // Paused 暂停中 + Paused + // Error 出错 + Error + // Complete 完成 + Complete + // Canceled 取消/停止 + Canceled + // Unknown 未知状态 + Unknown ) var ( @@ -42,6 +60,11 @@ func (instance *DummyAria2) CreateTask(task *model.Download) error { return ErrNotEnabled } +// Status 返回未开启错误 +func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) { + return rpc.StatusInfo{}, ErrNotEnabled +} + // Init 初始化 func Init() { options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") @@ -72,4 +95,31 @@ func Init() { } Instance = client + + // 从数据库中读取未完成任务,创建监控 + unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + for _, task := range unfinished { + // 创建任务监控 + NewMonitor(&task) + } +} + +// getStatus 将给定的状态字符串转换为状态标识数字 +func getStatus(status string) int { + switch status { + case "complete": + return Complete + case "active": + return Downloading + case "waiting": + return Ready + case "paused": + return Paused + case "error": + return Error + case "removed": + return Canceled + default: + return Unknown + } } diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 70404cb..57fe809 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -30,11 +30,16 @@ func (client *RPCService) Init(server, secret string, timeout int, options []int Options: options, } caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, - rpc.DummyNotifier{}) + EventNotifier) client.caller = caller return err } +// Status 查询下载状态 +func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) { + return client.caller.TellStatus(task.GID) +} + // CreateTask 创建新任务 func (client *RPCService) CreateTask(task *model.Download) error { // 生成存储路径 @@ -45,7 +50,11 @@ func (client *RPCService) CreateTask(task *model.Download) error { ) // 创建下载任务 - gid, err := client.caller.AddURI(task.Source, map[string]string{"dir": task.Path}) + options := []interface{}{map[string]string{"dir": task.Path}} + if len(client.options.Options) > 0 { + options = append(options, client.options.Options) + } + gid, err := client.caller.AddURI(task.Source, options...) if err != nil || gid == "" { return err } @@ -53,6 +62,12 @@ func (client *RPCService) CreateTask(task *model.Download) error { // 保存到数据库 task.GID = gid _, err = task.Create() + if err != nil { + return err + } - return err + // 创建任务监控 + NewMonitor(task) + + return nil } diff --git a/pkg/aria2/notification.go b/pkg/aria2/notification.go new file mode 100644 index 0000000..e265a60 --- /dev/null +++ b/pkg/aria2/notification.go @@ -0,0 +1,63 @@ +package aria2 + +import ( + "github.com/zyxar/argo/rpc" + "sync" +) + +// Notifier aria2实践通知处理 +type Notifier struct { + Subscribes sync.Map +} + +// Subscribe 订阅事件通知 +func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) { + notifier.Subscribes.Store(gid, target) +} + +// Unsubscribe 取消订阅事件通知 +func (notifier *Notifier) Unsubscribe(gid string) { + notifier.Subscribes.Delete(gid) +} + +// Notify 发送通知 +func (notifier *Notifier) Notify(events []rpc.Event, status int) { + for _, event := range events { + if target, ok := notifier.Subscribes.Load(event.Gid); ok { + target.(chan StatusEvent) <- StatusEvent{ + GID: event.Gid, + Status: status, + } + } + } +} + +// OnDownloadStart 下载开始 +func (notifier *Notifier) OnDownloadStart(events []rpc.Event) { + notifier.Notify(events, Downloading) +} + +// OnDownloadPause 下载暂停 +func (notifier *Notifier) OnDownloadPause(events []rpc.Event) { + notifier.Notify(events, Paused) +} + +// OnDownloadStop 下载停止 +func (notifier *Notifier) OnDownloadStop(events []rpc.Event) { + notifier.Notify(events, Canceled) +} + +// OnDownloadComplete 下载完成 +func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) { + notifier.Notify(events, Complete) +} + +// OnDownloadError 下载出错 +func (notifier *Notifier) OnDownloadError(events []rpc.Event) { + notifier.Notify(events, Error) +} + +// OnBtDownloadComplete BT下载完成 +func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) { + notifier.Notify(events, Complete) +} diff --git a/pkg/task/job.go b/pkg/task/job.go index ede0be7..2379ec3 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -11,6 +11,8 @@ const ( CompressTaskType = iota // DecompressTaskType 解压缩任务 DecompressTaskType + // TransferTaskType 中转任务 + TransferTaskType ) // 任务状态 @@ -99,6 +101,8 @@ func GetJobFromModel(task *model.Task) (Job, error) { return NewCompressTaskFromModel(task) case DecompressTaskType: return NewDecompressTaskFromModel(task) + case TransferTaskType: + return NewTransferTaskFromModel(task) default: return nil, ErrUnknownTaskType } diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go new file mode 100644 index 0000000..fa52416 --- /dev/null +++ b/pkg/task/tranfer.go @@ -0,0 +1,153 @@ +package task + +import ( + "context" + "encoding/json" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/filesystem" + "github.com/HFO4/cloudreve/pkg/util" + "os" +) + +// TransferTask 文件中转任务 +type TransferTask struct { + User *model.User + TaskModel *model.Task + TaskProps TransferProps + Err *JobError + + zipPath string +} + +// TransferProps 中转任务属性 +type TransferProps struct { + Src string `json:"src"` // 原始目录 + Parent string `json:"parent"` // 父目录 + Dst string `json:"dst"` // 目的目录ID +} + +// Props 获取任务属性 +func (job *TransferTask) Props() string { + res, _ := json.Marshal(job.TaskProps) + return string(res) +} + +// Type 获取任务状态 +func (job *TransferTask) Type() int { + return TransferTaskType +} + +// Creator 获取创建者ID +func (job *TransferTask) Creator() uint { + return job.User.ID +} + +// Model 获取任务的数据库模型 +func (job *TransferTask) Model() *model.Task { + return job.TaskModel +} + +// SetStatus 设定状态 +func (job *TransferTask) SetStatus(status int) { + job.TaskModel.SetStatus(status) +} + +// SetError 设定任务失败信息 +func (job *TransferTask) SetError(err *JobError) { + job.Err = err + res, _ := json.Marshal(job.Err) + job.TaskModel.SetError(string(res)) + +} + +// SetErrorMsg 设定任务失败信息 +func (job *TransferTask) SetErrorMsg(msg string, err error) { + jobErr := &JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + job.SetError(jobErr) +} + +// GetError 返回任务失败信息 +func (job *TransferTask) GetError() *JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *TransferTask) Do() { + defer job.Recycle() + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(job.User) + if err != nil { + job.SetErrorMsg(err.Error(), nil) + return + } + defer fs.Recycle() + + err = fs.UploadFromPath(context.Background(), job.TaskProps.Src, job.TaskProps.Dst) + if err != nil { + job.SetErrorMsg("文件转存失败", err) + return + } +} + +// Recycle 回收临时文件 +func (job *TransferTask) Recycle() { + err := os.Remove(job.TaskProps.Src) + if err != nil { + util.Log().Warning("无法删除中转临时文件[%s], %s", job.TaskProps.Src, err) + } + + if empty, _ := util.IsEmpty(job.TaskProps.Parent); empty { + err := os.Remove(job.TaskProps.Parent) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) + } + } +} + +// NewTransferTask 新建中转任务 +func NewTransferTask(user uint, dst, src, parent string) (Job, error) { + creator, err := model.GetUserByID(user) + if err != nil { + return nil, err + } + + newTask := &TransferTask{ + User: &creator, + TaskProps: TransferProps{ + Src: src, + Parent: parent, + Dst: dst, + }, + } + + record, err := Record(newTask) + if err != nil { + return nil, err + } + newTask.TaskModel = record + + return newTask, nil +} + +// NewTransferTaskFromModel 从数据库记录中恢复中转任务 +func NewTransferTaskFromModel(task *model.Task) (Job, error) { + user, err := model.GetUserByID(task.UserID) + if err != nil { + return nil, err + } + newTask := &TransferTask{ + User: &user, + TaskModel: task, + } + + err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) + if err != nil { + return nil, err + } + + return newTask, nil +} diff --git a/pkg/util/io.go b/pkg/util/io.go index a91db57..25b9dc9 100644 --- a/pkg/util/io.go +++ b/pkg/util/io.go @@ -1,6 +1,7 @@ package util import ( + "io" "os" "path/filepath" ) @@ -28,3 +29,18 @@ func CreatNestedFile(path string) (*os.File, error) { return os.Create(path) } + +// IsEmpty 返回给定目录是否为空目录 +func IsEmpty(name string) (bool, error) { + f, err := os.Open(name) + if err != nil { + return false, err + } + defer f.Close() + + _, err = f.Readdirnames(1) // Or f.Readdir(1) + if err == io.EOF { + return true, nil + } + return false, err // Either not empty or error, suits both cases +} diff --git a/service/aria2/add.go b/service/aria2/add.go index 4e91e96..763b9b2 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -29,21 +29,17 @@ func (service *AddURLService) Add(c *gin.Context) serializer.Response { } // 存放目录是否存在 - var ( - exist bool - parent *model.Folder - ) - if exist, parent = fs.IsPathExist(service.Dst); !exist { + if exist, _ := fs.IsPathExist(service.Dst); !exist { return serializer.Err(serializer.CodeNotFound, "存放路径不存在", nil) } // 创建任务 task := &model.Download{ - Status: aria2.Ready, - Type: aria2.URLTask, - FolderID: parent.ID, - UserID: fs.User.ID, - Source: service.URL, + Status: aria2.Ready, + Type: aria2.URLTask, + Dst: service.Dst, + UserID: fs.User.ID, + Source: service.URL, } if err := aria2.Instance.CreateTask(task); err != nil { return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)