diff --git a/assets b/assets index a1028e7e..02d93206 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit a1028e7e0ae96be4bb67d8c117cf39e07c207473 +Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603 diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 60d254e5..f91766fa 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -3,8 +3,6 @@ package aria2 import ( "context" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" "net/url" "sync" "time" @@ -14,6 +12,8 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/balancer" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" ) // Instance 默认使用的Aria2处理实例 @@ -40,7 +40,7 @@ func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) { if !isReload { // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading) + unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding) for i := 0; i < len(unfinished); i++ { // 创建任务监控 diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index 8f281d81..d4a8313d 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -38,6 +38,8 @@ const ( Downloading // Paused 暂停中 Paused + // Seeding 做种中 + Seeding // Error 出错 Error // Complete 完成 @@ -94,11 +96,14 @@ func (instance *DummyAria2) DeleteTempFile(src *model.Download) error { } // GetStatus 将给定的状态字符串转换为状态标识数字 -func GetStatus(status string) int { - switch status { +func GetStatus(status rpc.StatusInfo) int { + switch status.Status { case "complete": return Complete case "active": + if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength { + return Seeding + } return Downloading case "waiting": return Ready diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go index a93f5f80..7b0f2378 100644 --- a/pkg/aria2/common/common_test.go +++ b/pkg/aria2/common/common_test.go @@ -1,9 +1,11 @@ package common import ( + "testing" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/stretchr/testify/assert" - "testing" ) func TestDummyAria2(t *testing.T) { @@ -35,11 +37,18 @@ func TestDummyAria2(t *testing.T) { func TestGetStatus(t *testing.T) { a := assert.New(t) - a.Equal(GetStatus("complete"), Complete) - a.Equal(GetStatus("active"), Downloading) - a.Equal(GetStatus("waiting"), Ready) - a.Equal(GetStatus("paused"), Paused) - a.Equal(GetStatus("error"), Error) - a.Equal(GetStatus("removed"), Canceled) - a.Equal(GetStatus("unknown"), Unknown) + a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: "single"}, + TotalLength: "100", CompletedLength: "50"}), Downloading) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: "multi"}, + TotalLength: "100", CompletedLength: "100"}), Seeding) + a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready) + a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused) + a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error) + a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled) + a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown) } diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index a515b66f..6f6de7e9 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -109,14 +109,14 @@ func (monitor *Monitor) Update() bool { util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status) - switch status.Status { - case "complete": + switch common.GetStatus(status) { + case common.Complete, common.Seeding: return monitor.Complete(task.TaskPoll) - case "error": + case common.Error: return monitor.Error(status) - case "active", "waiting", "paused": + case common.Downloading, common.Ready, common.Paused: return false - case "removed": + case common.Canceled: monitor.Task.Status = common.Canceled monitor.Task.Save() monitor.RemoveTempFolder() @@ -132,7 +132,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { originSize := monitor.Task.TotalSize monitor.Task.GID = status.Gid - monitor.Task.Status = common.GetStatus(status.Status) + monitor.Task.Status = common.GetStatus(status) // 文件大小、已下载大小 total, err := strconv.ParseUint(status.TotalLength, 10, 64) @@ -235,6 +235,40 @@ func (monitor *Monitor) RemoveTempFolder() { // Complete 完成下载,返回是否中断监控 func (monitor *Monitor) Complete(pool task.Pool) bool { + // 未开始转存,提交转存任务 + if monitor.Task.TaskID == 0 { + return monitor.transfer(pool) + } + + // 做种完成 + if common.GetStatus(monitor.Task.StatusInfo) == common.Complete { + transferTask, err := model.GetTasksByID(monitor.Task.TaskID) + if err != nil { + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 转存完成,回收下载目录 + if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error { + job, err := task.NewRecycleTask(monitor.Task.UserID, monitor.Task.Parent, monitor.node.ID()) + if err != nil { + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 提交回收任务 + pool.Submit(job) + + return true + } + } + + return false +} + +func (monitor *Monitor) transfer(pool task.Pool) bool { // 创建中转任务 file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) @@ -269,7 +303,7 @@ func (monitor *Monitor) Complete(pool task.Pool) bool { monitor.Task.TaskID = job.Model().ID monitor.Task.Save() - return true + return false } func (monitor *Monitor) setErrorStatus(err error) { diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go index 885484a3..a6be586a 100644 --- a/pkg/aria2/monitor/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -3,6 +3,8 @@ package monitor import ( "database/sql" "errors" + "testing" + "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" @@ -13,7 +15,6 @@ import ( "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" - "testing" ) var mock sqlmock.Sqlmock @@ -431,6 +432,14 @@ func TestMonitor_Complete(t *testing.T) { mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() + mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4)) + mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + + a.False(m.Complete(mockPool)) + m.Task.StatusInfo.Status = "complete" a.True(m.Complete(mockPool)) a.NoError(mock.ExpectationsWereMet()) mockNode.AssertExpectations(t) diff --git a/pkg/aria2/rpc/resp.go b/pkg/aria2/rpc/resp.go index e685ce66..3614228f 100644 --- a/pkg/aria2/rpc/resp.go +++ b/pkg/aria2/rpc/resp.go @@ -4,35 +4,27 @@ package rpc // StatusInfo represents response of aria2.tellStatus type StatusInfo struct { - Gid string `json:"gid"` // GID of the download. - Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. - TotalLength string `json:"totalLength"` // Total length of the download in bytes. - CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. - UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. - BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. - DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. - UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec. - InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. - NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. - Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only. - PieceLength string `json:"pieceLength"` // Piece length in bytes. - NumPieces string `json:"numPieces"` // The number of pieces. - Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. - ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. - ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode. - FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. - BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. - Dir string `json:"dir"` // Directory to save files. - Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. - BitTorrent struct { - AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. - Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. - CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. - Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. - Info struct { - Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. - } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. - } `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. + Gid string `json:"gid"` // GID of the download. + Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. + TotalLength string `json:"totalLength"` // Total length of the download in bytes. + CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. + UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. + BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. + DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. + UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec. + InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. + NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. + Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only. + PieceLength string `json:"pieceLength"` // Piece length in bytes. + NumPieces string `json:"numPieces"` // The number of pieces. + Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. + ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. + ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode. + FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. + BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. + Dir string `json:"dir"` // Directory to save files. + Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. + BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. } // URIInfo represents an element of response of aria2.getUris @@ -100,3 +92,13 @@ type Method struct { Name string `json:"methodName"` // Method name to call Params []interface{} `json:"params"` // Array containing parameters to the method call } + +type BitTorrentInfo struct { + AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. + Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. + CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. + Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. + Info struct { + Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. + } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. +} diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index 7fc7b098..84116394 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -5,6 +5,9 @@ import ( "context" "encoding/json" "errors" + "net/url" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" @@ -13,8 +16,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "net/url" - "time" ) // Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果 @@ -118,6 +119,45 @@ func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]respo } // 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { +func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { + return nil +} + +func (d *Driver) Recycle(ctx context.Context, path string) error { + req := serializer.SlaveRecycleReq{ + Path: path, + } + + body, err := json.Marshal(req) + if err != nil { + return err + } + + // 订阅回收结果 + resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0) + defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan) + + res, err := d.client.Request("PUT", "task/recycle", bytes.NewReader(body)). + CheckHTTPResponse(200). + DecodeResponse() + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + // 等待回收结果或者超时 + waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800) + select { + case <-time.After(time.Duration(waitTimeout) * time.Second): + return ErrWaitResultTimeout + case msg := <-resChan: + if msg.Event != serializer.SlaveRecycleSuccess { + return errors.New(msg.Content.(serializer.SlaveRecycleResult).Error) + } + } + return nil } diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index 245767a9..4179d455 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/gob" "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" ) @@ -53,15 +54,35 @@ func (s *SlaveTransferReq) Hash(id string) string { return fmt.Sprintf("%x", bs) } +// SlaveRecycleReq 从机回收任务创建请求 +type SlaveRecycleReq struct { + Path string `json:"path"` +} + +// Hash 返回创建请求的唯一标识,保持创建请求幂等 +func (s *SlaveRecycleReq) Hash(id string) string { + h := sha1.New() + h.Write([]byte(fmt.Sprintf("transfer-%s-%s", id, s.Path))) + bs := h.Sum(nil) + return fmt.Sprintf("%x", bs) +} + const ( SlaveTransferSuccess = "success" SlaveTransferFailed = "failed" + SlaveRecycleSuccess = "success" + SlaveRecycleFailed = "failed" ) type SlaveTransferResult struct { Error string } +type SlaveRecycleResult struct { + Error string +} + func init() { gob.Register(SlaveTransferResult{}) + gob.Register(SlaveRecycleResult{}) } diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go index 64715421..add3a634 100644 --- a/pkg/serializer/slave_test.go +++ b/pkg/serializer/slave_test.go @@ -1,9 +1,10 @@ package serializer import ( + "testing" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/stretchr/testify/assert" - "testing" ) func TestSlaveTransferReq_Hash(t *testing.T) { @@ -18,3 +19,14 @@ func TestSlaveTransferReq_Hash(t *testing.T) { } a.NotEqual(s1.Hash("1"), s2.Hash("1")) } + +func TestSlaveRecycleReq_Hash(t *testing.T) { + a := assert.New(t) + s1 := &SlaveRecycleReq{ + Path: "1", + } + s2 := &SlaveRecycleReq{ + Path: "2", + } + a.NotEqual(s1.Hash("1"), s2.Hash("1")) +} diff --git a/pkg/task/job.go b/pkg/task/job.go index 781c4608..9bf52d74 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -13,6 +13,8 @@ const ( DecompressTaskType // TransferTaskType 中转任务 TransferTaskType + // RecycleTaskType 回收任务 + RecycleTaskType // ImportTaskType 导入任务 ImportTaskType ) @@ -113,6 +115,8 @@ func GetJobFromModel(task *model.Task) (Job, error) { return NewTransferTaskFromModel(task) case ImportTaskType: return NewImportTaskFromModel(task) + case RecycleTaskType: + return NewRecycleTaskFromModel(task) default: return nil, ErrUnknownTaskType } diff --git a/pkg/task/job_test.go b/pkg/task/job_test.go index 81793ee6..737f5b76 100644 --- a/pkg/task/job_test.go +++ b/pkg/task/job_test.go @@ -2,12 +2,12 @@ package task import ( "errors" - testMock "github.com/stretchr/testify/mock" "testing" "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" ) func TestRecord(t *testing.T) { @@ -103,4 +103,16 @@ func TestGetJobFromModel(t *testing.T) { asserts.Nil(job) asserts.Error(err) } + // RecycleTaskType + { + task := &model.Task{ + Status: 0, + Type: RecycleTaskType, + } + mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) + job, err := GetJobFromModel(task) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(job) + asserts.Error(err) + } } diff --git a/pkg/task/recycle.go b/pkg/task/recycle.go new file mode 100644 index 00000000..23abd96b --- /dev/null +++ b/pkg/task/recycle.go @@ -0,0 +1,155 @@ +package task + +import ( + "context" + "encoding/json" + "os" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +// RecycleTask 文件回收任务 +type RecycleTask struct { + User *model.User + TaskModel *model.Task + TaskProps RecycleProps + Err *JobError +} + +// RecycleProps 回收任务属性 +type RecycleProps struct { + // 回收目录 + Path string `json:"path"` + // 负责处理回收任务的节点ID + NodeID uint `json:"node_id"` +} + +// Props 获取任务属性 +func (job *RecycleTask) Props() string { + res, _ := json.Marshal(job.TaskProps) + return string(res) +} + +// Type 获取任务状态 +func (job *RecycleTask) Type() int { + return RecycleTaskType +} + +// Creator 获取创建者ID +func (job *RecycleTask) Creator() uint { + return job.User.ID +} + +// Model 获取任务的数据库模型 +func (job *RecycleTask) Model() *model.Task { + return job.TaskModel +} + +// SetStatus 设定状态 +func (job *RecycleTask) SetStatus(status int) { + job.TaskModel.SetStatus(status) +} + +// SetError 设定任务失败信息 +func (job *RecycleTask) SetError(err *JobError) { + job.Err = err + res, _ := json.Marshal(job.Err) + job.TaskModel.SetError(string(res)) + +} + +// SetErrorMsg 设定任务失败信息 +func (job *RecycleTask) SetErrorMsg(msg string, err error) { + jobErr := &JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + job.SetError(jobErr) +} + +// GetError 返回任务失败信息 +func (job *RecycleTask) GetError() *JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *RecycleTask) Do() { + if job.TaskProps.NodeID == 1 { + err := os.RemoveAll(job.TaskProps.Path) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err) + job.SetErrorMsg("文件回收失败", err) + } + } else { + // 指定为从机回收 + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(job.User) + if err != nil { + job.SetErrorMsg(err.Error(), nil) + return + } + + // 获取从机节点 + node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) + if node == nil { + job.SetErrorMsg("从机节点不可用", nil) + } + + // 切换为从机节点处理回收 + fs.SwitchToSlaveHandler(node) + handler := fs.Handler.(*slaveinmaster.Driver) + err = handler.Recycle(context.Background(), job.TaskProps.Path) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err) + job.SetErrorMsg("文件回收失败", err) + } + } +} + +// NewRecycleTask 新建回收任务 +func NewRecycleTask(user uint, path string, node uint) (Job, error) { + creator, err := model.GetActiveUserByID(user) + if err != nil { + return nil, err + } + + newTask := &RecycleTask{ + User: &creator, + TaskProps: RecycleProps{ + Path: path, + NodeID: node, + }, + } + + record, err := Record(newTask) + if err != nil { + return nil, err + } + newTask.TaskModel = record + + return newTask, nil +} + +// NewRecycleTaskFromModel 从数据库记录中恢复回收任务 +func NewRecycleTaskFromModel(task *model.Task) (Job, error) { + user, err := model.GetActiveUserByID(task.UserID) + if err != nil { + return nil, err + } + newTask := &RecycleTask{ + User: &user, + TaskModel: task, + } + + err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) + if err != nil { + return nil, err + } + + return newTask, nil +} diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go new file mode 100644 index 00000000..3fad4778 --- /dev/null +++ b/pkg/task/recycle_test.go @@ -0,0 +1,131 @@ +package task + +import ( + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" +) + +func TestRecycleTask_Props(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + } + asserts.NotEmpty(task.Props()) + asserts.Equal(RecycleTaskType, task.Type()) + asserts.EqualValues(0, task.Creator()) + asserts.Nil(task.Model()) +} + +func TestRecycleTask_SetStatus(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + task.SetStatus(3) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestRecycleTask_SetError(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + task.SetErrorMsg("error", nil) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("error", task.GetError().Msg) +} + +func TestRecycleTask_Do(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + + // 目录不存在 + { + task.TaskProps.Path = "test/not_exist" + task.User = &model.User{ + Policy: model.Policy{ + Type: "unknown", + }, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, + 1)) + mock.ExpectCommit() + task.Do() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotEmpty(task.GetError().Msg) + } +} + +func TestNewRecycleTask(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + job, err := NewRecycleTask(1, "/", 0) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotNil(job) + asserts.NoError(err) + } + + // 失败 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + job, err := NewRecycleTask(1, "test/not_exist", 0) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(job) + asserts.Error(err) + } +} + +func TestNewRecycleTaskFromModel(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.NotNil(job) + } + + // JSON解析失败 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Nil(job) + } +} diff --git a/pkg/task/slavetask/recycle.go b/pkg/task/slavetask/recycle.go new file mode 100644 index 00000000..d8c7bc8d --- /dev/null +++ b/pkg/task/slavetask/recycle.go @@ -0,0 +1,95 @@ +package slavetask + +import ( + "os" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/task" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +// RecycleTask 文件回收任务 +type RecycleTask struct { + Err *task.JobError + Req *serializer.SlaveRecycleReq + MasterID string +} + +// Props 获取任务属性 +func (job *RecycleTask) Props() string { + return "" +} + +// Type 获取任务类型 +func (job *RecycleTask) Type() int { + return 0 +} + +// Creator 获取创建者ID +func (job *RecycleTask) Creator() uint { + return 0 +} + +// Model 获取任务的数据库模型 +func (job *RecycleTask) Model() *model.Task { + return nil +} + +// SetStatus 设定状态 +func (job *RecycleTask) SetStatus(status int) { +} + +// SetError 设定任务失败信息 +func (job *RecycleTask) SetError(err *task.JobError) { + job.Err = err +} + +// SetErrorMsg 设定任务失败信息 +func (job *RecycleTask) SetErrorMsg(msg string, err error) { + jobErr := &task.JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + + job.SetError(jobErr) + + notifyMsg := mq.Message{ + TriggeredBy: job.MasterID, + Event: serializer.SlaveRecycleFailed, + Content: serializer.SlaveRecycleResult{ + Error: err.Error(), + }, + } + + if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { + util.Log().Warning("无法发送回收失败通知到从机, %s", err) + } +} + +// GetError 返回任务失败信息 +func (job *RecycleTask) GetError() *task.JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *RecycleTask) Do() { + err := os.RemoveAll(job.Req.Path) + if err != nil { + util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Path, err) + job.SetErrorMsg("文件回收失败", err) + return + } + + msg := mq.Message{ + TriggeredBy: job.MasterID, + Event: serializer.SlaveRecycleSuccess, + Content: serializer.SlaveRecycleResult{}, + } + + if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { + util.Log().Warning("无法发送回收成功通知到从机, %s", err) + } +} diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 20c5fcc9..818028eb 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -2,6 +2,8 @@ package slavetask import ( "context" + "os" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" @@ -10,7 +12,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" - "os" ) // TransferTask 文件中转任务 @@ -79,8 +80,6 @@ func (job *TransferTask) GetError() *task.JobError { // Do 开始执行任务 func (job *TransferTask) Do() { - defer job.Recycle() - fs, err := filesystem.NewAnonymousFileSystem() if err != nil { job.SetErrorMsg("无法初始化匿名文件系统", err) @@ -137,11 +136,3 @@ func (job *TransferTask) Do() { util.Log().Warning("无法发送转存成功通知到从机, %s", err) } } - -// Recycle 回收临时文件 -func (job *TransferTask) Recycle() { - err := os.Remove(job.Req.Src) - if err != nil { - util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Src, err) - } -} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 5f9aa58e..f115e803 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -3,7 +3,6 @@ package task import ( "context" "encoding/json" - "os" "path" "path/filepath" "strings" @@ -87,8 +86,6 @@ func (job *TransferTask) GetError() *JobError { // Do 开始执行任务 func (job *TransferTask) Do() { - defer job.Recycle() - // 创建文件系统 fs, err := filesystem.NewFileSystem(job.User) if err != nil { @@ -139,16 +136,6 @@ func (job *TransferTask) Do() { } -// Recycle 回收临时文件 -func (job *TransferTask) Recycle() { - if job.TaskProps.NodeID == 1 { - err := os.RemoveAll(job.TaskProps.Parent) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) - } - } -} - // NewTransferTask 新建中转任务 func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) { creator, err := model.GetActiveUserByID(user) diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index e1e7de22..2b5b15ce 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -212,6 +212,17 @@ func SlaveCreateTransferTask(c *gin.Context) { } } +// SlaveCreateRecycleTask 从机创建回收任务 +func SlaveCreateRecycleTask(c *gin.Context) { + var service serializer.SlaveRecycleReq + if err := c.ShouldBindJSON(&service); err == nil { + res := explorer.CreateRecycleTask(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // SlaveNotificationPush 处理从机发送的消息推送 func SlaveNotificationPush(c *gin.Context) { var service node.SlaveNotificationService diff --git a/routers/router.go b/routers/router.go index 0727fe6f..f7586b3a 100644 --- a/routers/router.go +++ b/routers/router.go @@ -88,6 +88,7 @@ func InitSlaveRouter() *gin.Engine { task := v3.Group("task") { task.PUT("transfer", controllers.SlaveCreateTransferTask) + task.PUT("recycle", controllers.SlaveCreateRecycleTask) } } return r diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 6344ddd6..6f55a8f9 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -33,7 +33,7 @@ func (service *DownloadListService) Finished(c *gin.Context, user *model.User) s // Downloading 获取正在下载中的任务 func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready) intervals := make(map[uint]int) for _, download := range downloads { if _, ok := intervals[download.ID]; !ok { diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 1435640d..253fcb87 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -5,6 +5,10 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" + "net/url" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" @@ -16,9 +20,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" - "net/http" - "net/url" - "time" ) // SlaveDownloadService 从机文件下載服务 @@ -165,6 +166,26 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial return serializer.ParamErr("未知的主机节点ID", nil) } +// CreateRecycleTask 创建从机文件回收任务 +func CreateRecycleTask(c *gin.Context, req *serializer.SlaveRecycleReq) serializer.Response { + if id, ok := c.Get("MasterSiteID"); ok { + job := &slavetask.RecycleTask{ + Req: req, + MasterID: id.(string), + } + + if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { + task.TaskPoll.Submit(job.(task.Job)) + }); err != nil { + return serializer.Err(serializer.CodeCreateTaskError, "", err) + } + + return serializer.Response{} + } + + return serializer.ParamErr("未知的主机节点ID", nil) +} + // SlaveListService 从机上传会话服务 type SlaveCreateUploadSessionService struct { Session serializer.UploadSession `json:"session" binding:"required"` diff --git a/service/user/register.go b/service/user/register.go index d3c81b5e..35e8253d 100644 --- a/service/user/register.go +++ b/service/user/register.go @@ -1,14 +1,15 @@ package user import ( + "net/url" + "strings" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" - "net/url" - "strings" ) // UserRegisterService 管理用户注册的服务