From b1685d2863d84658be3e142b5a9952eb149ee246 Mon Sep 17 00:00:00 2001 From: XYenon Date: Thu, 29 Sep 2022 09:24:58 +0800 Subject: [PATCH] feat: seeding status for aria2 download tasks (#1422) * feat: add aria2 seeding * fix: move RecycleTaskType to the bottom * refactor: refactor recycle aria2 temp file --- assets | 2 +- pkg/aria2/aria2.go | 6 +- pkg/aria2/common/common.go | 9 +- pkg/aria2/common/common_test.go | 25 ++-- pkg/aria2/monitor/monitor.go | 48 ++++++- pkg/aria2/monitor/monitor_test.go | 11 +- pkg/aria2/rpc/resp.go | 60 ++++---- .../driver/shadow/slaveinmaster/handler.go | 7 +- pkg/serializer/slave.go | 1 + pkg/serializer/slave_test.go | 14 +- pkg/task/job.go | 4 + pkg/task/job_test.go | 14 +- pkg/task/recycle.go | 130 ++++++++++++++++++ pkg/task/recycle_test.go | 117 ++++++++++++++++ pkg/task/slavetask/transfer.go | 13 +- pkg/task/tranfer.go | 13 -- service/aria2/manage.go | 4 +- service/explorer/slave.go | 7 +- service/user/register.go | 5 +- 19 files changed, 403 insertions(+), 87 deletions(-) create mode 100644 pkg/task/recycle.go create mode 100644 pkg/task/recycle_test.go diff --git a/assets b/assets index a1028e7..02d9320 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit a1028e7e0ae96be4bb67d8c117cf39e07c207473 +Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603 diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 60d254e..f91766f 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -3,8 +3,6 @@ package aria2 import ( "context" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" "net/url" "sync" "time" @@ -14,6 +12,8 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/balancer" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" ) // Instance 默认使用的Aria2处理实例 @@ -40,7 +40,7 @@ func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) { if !isReload { // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading) + unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding) for i := 0; i < len(unfinished); i++ { // 创建任务监控 diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index 8f281d8..455c89f 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -46,6 +46,8 @@ const ( Canceled // Unknown 未知状态 Unknown + // Seeding 做种中 + Seeding ) var ( @@ -94,11 +96,14 @@ func (instance *DummyAria2) DeleteTempFile(src *model.Download) error { } // GetStatus 将给定的状态字符串转换为状态标识数字 -func GetStatus(status string) int { - switch status { +func GetStatus(status rpc.StatusInfo) int { + switch status.Status { case "complete": return Complete case "active": + if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength { + return Seeding + } return Downloading case "waiting": return Ready diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go index a93f5f8..7b0f237 100644 --- a/pkg/aria2/common/common_test.go +++ b/pkg/aria2/common/common_test.go @@ -1,9 +1,11 @@ package common import ( + "testing" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/stretchr/testify/assert" - "testing" ) func TestDummyAria2(t *testing.T) { @@ -35,11 +37,18 @@ func TestDummyAria2(t *testing.T) { func TestGetStatus(t *testing.T) { a := assert.New(t) - a.Equal(GetStatus("complete"), Complete) - a.Equal(GetStatus("active"), Downloading) - a.Equal(GetStatus("waiting"), Ready) - a.Equal(GetStatus("paused"), Paused) - a.Equal(GetStatus("error"), Error) - a.Equal(GetStatus("removed"), Canceled) - a.Equal(GetStatus("unknown"), Unknown) + a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: "single"}, + TotalLength: "100", CompletedLength: "50"}), Downloading) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: "multi"}, + TotalLength: "100", CompletedLength: "100"}), Seeding) + a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready) + a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused) + a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error) + a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled) + a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown) } diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index a515b66..531d6ed 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -109,14 +109,14 @@ func (monitor *Monitor) Update() bool { util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status) - switch status.Status { - case "complete": + switch common.GetStatus(status) { + case common.Complete, common.Seeding: return monitor.Complete(task.TaskPoll) - case "error": + case common.Error: return monitor.Error(status) - case "active", "waiting", "paused": + case common.Downloading, common.Ready, common.Paused: return false - case "removed": + case common.Canceled: monitor.Task.Status = common.Canceled monitor.Task.Save() monitor.RemoveTempFolder() @@ -132,7 +132,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { originSize := monitor.Task.TotalSize monitor.Task.GID = status.Gid - monitor.Task.Status = common.GetStatus(status.Status) + monitor.Task.Status = common.GetStatus(status) // 文件大小、已下载大小 total, err := strconv.ParseUint(status.TotalLength, 10, 64) @@ -235,6 +235,40 @@ func (monitor *Monitor) RemoveTempFolder() { // Complete 完成下载,返回是否中断监控 func (monitor *Monitor) Complete(pool task.Pool) bool { + // 未开始转存,提交转存任务 + if monitor.Task.TaskID == 0 { + return monitor.transfer(pool) + } + + // 做种完成 + if common.GetStatus(monitor.Task.StatusInfo) == common.Complete { + transferTask, err := model.GetTasksByID(monitor.Task.TaskID) + if err != nil { + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 转存完成,回收下载目录 + if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error { + job, err := task.NewRecycleTask(monitor.Task) + if err != nil { + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 提交回收任务 + pool.Submit(job) + + return true + } + } + + return false +} + +func (monitor *Monitor) transfer(pool task.Pool) bool { // 创建中转任务 file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) @@ -269,7 +303,7 @@ func (monitor *Monitor) Complete(pool task.Pool) bool { monitor.Task.TaskID = job.Model().ID monitor.Task.Save() - return true + return false } func (monitor *Monitor) setErrorStatus(err error) { diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go index 885484a..a6be586 100644 --- a/pkg/aria2/monitor/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -3,6 +3,8 @@ package monitor import ( "database/sql" "errors" + "testing" + "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" @@ -13,7 +15,6 @@ import ( "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" - "testing" ) var mock sqlmock.Sqlmock @@ -431,6 +432,14 @@ func TestMonitor_Complete(t *testing.T) { mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() + mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4)) + mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + + a.False(m.Complete(mockPool)) + m.Task.StatusInfo.Status = "complete" a.True(m.Complete(mockPool)) a.NoError(mock.ExpectationsWereMet()) mockNode.AssertExpectations(t) diff --git a/pkg/aria2/rpc/resp.go b/pkg/aria2/rpc/resp.go index e685ce6..3614228 100644 --- a/pkg/aria2/rpc/resp.go +++ b/pkg/aria2/rpc/resp.go @@ -4,35 +4,27 @@ package rpc // StatusInfo represents response of aria2.tellStatus type StatusInfo struct { - Gid string `json:"gid"` // GID of the download. - Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. - TotalLength string `json:"totalLength"` // Total length of the download in bytes. - CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. - UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. - BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. - DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. - UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec. - InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. - NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. - Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only. - PieceLength string `json:"pieceLength"` // Piece length in bytes. - NumPieces string `json:"numPieces"` // The number of pieces. - Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. - ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. - ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode. - FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. - BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. - Dir string `json:"dir"` // Directory to save files. - Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. - BitTorrent struct { - AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. - Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. - CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. - Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. - Info struct { - Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. - } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. - } `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. + Gid string `json:"gid"` // GID of the download. + Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. + TotalLength string `json:"totalLength"` // Total length of the download in bytes. + CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. + UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. + BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. + DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. + UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec. + InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. + NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. + Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only. + PieceLength string `json:"pieceLength"` // Piece length in bytes. + NumPieces string `json:"numPieces"` // The number of pieces. + Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. + ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. + ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode. + FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. + BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. + Dir string `json:"dir"` // Directory to save files. + Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. + BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. } // URIInfo represents an element of response of aria2.getUris @@ -100,3 +92,13 @@ type Method struct { Name string `json:"methodName"` // Method name to call Params []interface{} `json:"params"` // Array containing parameters to the method call } + +type BitTorrentInfo struct { + AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. + Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. + CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. + Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. + Info struct { + Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. + } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. +} diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index 7fc7b09..4dd9da8 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -5,6 +5,9 @@ import ( "context" "encoding/json" "errors" + "net/url" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" @@ -13,8 +16,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "net/url" - "time" ) // Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果 @@ -118,6 +119,6 @@ func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]respo } // 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { +func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { return nil } diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index 245767a..04d56d3 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/gob" "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" ) diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go index 6471542..add3a63 100644 --- a/pkg/serializer/slave_test.go +++ b/pkg/serializer/slave_test.go @@ -1,9 +1,10 @@ package serializer import ( + "testing" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/stretchr/testify/assert" - "testing" ) func TestSlaveTransferReq_Hash(t *testing.T) { @@ -18,3 +19,14 @@ func TestSlaveTransferReq_Hash(t *testing.T) { } a.NotEqual(s1.Hash("1"), s2.Hash("1")) } + +func TestSlaveRecycleReq_Hash(t *testing.T) { + a := assert.New(t) + s1 := &SlaveRecycleReq{ + Path: "1", + } + s2 := &SlaveRecycleReq{ + Path: "2", + } + a.NotEqual(s1.Hash("1"), s2.Hash("1")) +} diff --git a/pkg/task/job.go b/pkg/task/job.go index 781c460..e9d54d8 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -15,6 +15,8 @@ const ( TransferTaskType // ImportTaskType 导入任务 ImportTaskType + // RecycleTaskType 回收任务 + RecycleTaskType ) // 任务状态 @@ -113,6 +115,8 @@ func GetJobFromModel(task *model.Task) (Job, error) { return NewTransferTaskFromModel(task) case ImportTaskType: return NewImportTaskFromModel(task) + case RecycleTaskType: + return NewRecycleTaskFromModel(task) default: return nil, ErrUnknownTaskType } diff --git a/pkg/task/job_test.go b/pkg/task/job_test.go index 81793ee..737f5b7 100644 --- a/pkg/task/job_test.go +++ b/pkg/task/job_test.go @@ -2,12 +2,12 @@ package task import ( "errors" - testMock "github.com/stretchr/testify/mock" "testing" "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" ) func TestRecord(t *testing.T) { @@ -103,4 +103,16 @@ func TestGetJobFromModel(t *testing.T) { asserts.Nil(job) asserts.Error(err) } + // RecycleTaskType + { + task := &model.Task{ + Status: 0, + Type: RecycleTaskType, + } + mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) + job, err := GetJobFromModel(task) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(job) + asserts.Error(err) + } } diff --git a/pkg/task/recycle.go b/pkg/task/recycle.go new file mode 100644 index 0000000..17eaf3c --- /dev/null +++ b/pkg/task/recycle.go @@ -0,0 +1,130 @@ +package task + +import ( + "encoding/json" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +// RecycleTask 文件回收任务 +type RecycleTask struct { + User *model.User + TaskModel *model.Task + TaskProps RecycleProps + Err *JobError +} + +// RecycleProps 回收任务属性 +type RecycleProps struct { + // 下载任务 GID + DownloadGID string `json:"download_gid"` +} + +// Props 获取任务属性 +func (job *RecycleTask) Props() string { + res, _ := json.Marshal(job.TaskProps) + return string(res) +} + +// Type 获取任务状态 +func (job *RecycleTask) Type() int { + return RecycleTaskType +} + +// Creator 获取创建者ID +func (job *RecycleTask) Creator() uint { + return job.User.ID +} + +// Model 获取任务的数据库模型 +func (job *RecycleTask) Model() *model.Task { + return job.TaskModel +} + +// SetStatus 设定状态 +func (job *RecycleTask) SetStatus(status int) { + job.TaskModel.SetStatus(status) +} + +// SetError 设定任务失败信息 +func (job *RecycleTask) SetError(err *JobError) { + job.Err = err + res, _ := json.Marshal(job.Err) + job.TaskModel.SetError(string(res)) +} + +// SetErrorMsg 设定任务失败信息 +func (job *RecycleTask) SetErrorMsg(msg string, err error) { + jobErr := &JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + job.SetError(jobErr) +} + +// GetError 返回任务失败信息 +func (job *RecycleTask) GetError() *JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *RecycleTask) Do() { + download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID) + if err != nil { + util.Log().Warning("回收任务 %d 找不到下载记录", job.TaskModel.ID) + job.SetErrorMsg("无法找到下载任务", err) + return + } + nodeID := download.GetNodeID() + node := cluster.Default.GetNodeByID(nodeID) + if node == nil { + util.Log().Warning("回收任务 %d 找不到节点", job.TaskModel.ID) + job.SetErrorMsg("从机节点不可用", nil) + return + } + err = node.GetAria2Instance().DeleteTempFile(download) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", download.Parent, err) + job.SetErrorMsg("文件回收失败", err) + return + } +} + +// NewRecycleTask 新建回收任务 +func NewRecycleTask(download *model.Download) (Job, error) { + newTask := &RecycleTask{ + User: download.GetOwner(), + TaskProps: RecycleProps{ + DownloadGID: download.GID, + }, + } + + record, err := Record(newTask) + if err != nil { + return nil, err + } + newTask.TaskModel = record + + return newTask, nil +} + +// NewRecycleTaskFromModel 从数据库记录中恢复回收任务 +func NewRecycleTaskFromModel(task *model.Task) (Job, error) { + user, err := model.GetActiveUserByID(task.UserID) + if err != nil { + return nil, err + } + newTask := &RecycleTask{ + User: &user, + TaskModel: task, + } + + err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) + if err != nil { + return nil, err + } + + return newTask, nil +} diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go new file mode 100644 index 0000000..0092a30 --- /dev/null +++ b/pkg/task/recycle_test.go @@ -0,0 +1,117 @@ +package task + +import ( + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" +) + +func TestRecycleTask_Props(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + } + asserts.NotEmpty(task.Props()) + asserts.Equal(RecycleTaskType, task.Type()) + asserts.EqualValues(0, task.Creator()) + asserts.Nil(task.Model()) +} + +func TestRecycleTask_SetStatus(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + task.SetStatus(3) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestRecycleTask_SetError(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + task.SetErrorMsg("error", nil) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("error", task.GetError().Msg) +} + +func TestNewRecycleTask(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + job, err := NewRecycleTask(&model.Download{ + Model: gorm.Model{ID: 1}, + GID: "test_g_id", + Parent: "/", + UserID: 1, + NodeID: 1, + }) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotNil(job) + asserts.NoError(err) + } + + // 失败 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + job, err := NewRecycleTask(&model.Download{ + Model: gorm.Model{ID: 1}, + GID: "test_g_id", + Parent: "test/not_exist", + UserID: 1, + NodeID: 1, + }) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(job) + asserts.Error(err) + } +} + +func TestNewRecycleTaskFromModel(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.NotNil(job) + } + + // JSON解析失败 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Nil(job) + } +} diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 20c5fcc..818028e 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -2,6 +2,8 @@ package slavetask import ( "context" + "os" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" @@ -10,7 +12,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" - "os" ) // TransferTask 文件中转任务 @@ -79,8 +80,6 @@ func (job *TransferTask) GetError() *task.JobError { // Do 开始执行任务 func (job *TransferTask) Do() { - defer job.Recycle() - fs, err := filesystem.NewAnonymousFileSystem() if err != nil { job.SetErrorMsg("无法初始化匿名文件系统", err) @@ -137,11 +136,3 @@ func (job *TransferTask) Do() { util.Log().Warning("无法发送转存成功通知到从机, %s", err) } } - -// Recycle 回收临时文件 -func (job *TransferTask) Recycle() { - err := os.Remove(job.Req.Src) - if err != nil { - util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Src, err) - } -} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 5f9aa58..f115e80 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -3,7 +3,6 @@ package task import ( "context" "encoding/json" - "os" "path" "path/filepath" "strings" @@ -87,8 +86,6 @@ func (job *TransferTask) GetError() *JobError { // Do 开始执行任务 func (job *TransferTask) Do() { - defer job.Recycle() - // 创建文件系统 fs, err := filesystem.NewFileSystem(job.User) if err != nil { @@ -139,16 +136,6 @@ func (job *TransferTask) Do() { } -// Recycle 回收临时文件 -func (job *TransferTask) Recycle() { - if job.TaskProps.NodeID == 1 { - err := os.RemoveAll(job.TaskProps.Parent) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) - } - } -} - // NewTransferTask 新建中转任务 func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) { creator, err := model.GetActiveUserByID(user) diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 6344ddd..115a440 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -33,7 +33,7 @@ func (service *DownloadListService) Finished(c *gin.Context, user *model.User) s // Downloading 获取正在下载中的任务 func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready) intervals := make(map[uint]int) for _, download := range downloads { if _, ok := intervals[download.ID]; !ok { @@ -57,7 +57,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeNotFound, "Download record not found", err) } - if download.Status >= common.Error { + if download.Status >= common.Error && download.Status <= common.Unknown { // 如果任务已完成,则删除任务记录 if err := download.Delete(); err != nil { return serializer.DBErr("Failed to delete task record", err) diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 1435640..afb61af 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -5,6 +5,10 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" + "net/url" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" @@ -16,9 +20,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" - "net/http" - "net/url" - "time" ) // SlaveDownloadService 从机文件下載服务 diff --git a/service/user/register.go b/service/user/register.go index d3c81b5..35e8253 100644 --- a/service/user/register.go +++ b/service/user/register.go @@ -1,14 +1,15 @@ package user import ( + "net/url" + "strings" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" - "net/url" - "strings" ) // UserRegisterService 管理用户注册的服务