feat: add aria2 seeding

pull/1422/head
XYenon 3 years ago
parent 96daed26b4
commit 8f955f130f

@ -1 +1 @@
Subproject commit a1028e7e0ae96be4bb67d8c117cf39e07c207473 Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603

@ -3,8 +3,6 @@ package aria2
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@ -14,6 +12,8 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
) )
// Instance 默认使用的Aria2处理实例 // Instance 默认使用的Aria2处理实例
@ -40,7 +40,7 @@ func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
if !isReload { if !isReload {
// 从数据库中读取未完成任务,创建监控 // 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading) unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding)
for i := 0; i < len(unfinished); i++ { for i := 0; i < len(unfinished); i++ {
// 创建任务监控 // 创建任务监控

@ -38,6 +38,8 @@ const (
Downloading Downloading
// Paused 暂停中 // Paused 暂停中
Paused Paused
// Seeding 做种中
Seeding
// Error 出错 // Error 出错
Error Error
// Complete 完成 // Complete 完成
@ -94,11 +96,14 @@ func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
} }
// GetStatus 将给定的状态字符串转换为状态标识数字 // GetStatus 将给定的状态字符串转换为状态标识数字
func GetStatus(status string) int { func GetStatus(status rpc.StatusInfo) int {
switch status { switch status.Status {
case "complete": case "complete":
return Complete return Complete
case "active": case "active":
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
return Seeding
}
return Downloading return Downloading
case "waiting": case "waiting":
return Ready return Ready

@ -1,9 +1,11 @@
package common package common
import ( import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing"
) )
func TestDummyAria2(t *testing.T) { func TestDummyAria2(t *testing.T) {
@ -35,11 +37,18 @@ func TestDummyAria2(t *testing.T) {
func TestGetStatus(t *testing.T) { func TestGetStatus(t *testing.T) {
a := assert.New(t) a := assert.New(t)
a.Equal(GetStatus("complete"), Complete) a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete)
a.Equal(GetStatus("active"), Downloading) a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
a.Equal(GetStatus("waiting"), Ready) BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading)
a.Equal(GetStatus("paused"), Paused) a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
a.Equal(GetStatus("error"), Error) BitTorrent: rpc.BitTorrentInfo{Mode: "single"},
a.Equal(GetStatus("removed"), Canceled) TotalLength: "100", CompletedLength: "50"}), Downloading)
a.Equal(GetStatus("unknown"), Unknown) a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "multi"},
TotalLength: "100", CompletedLength: "100"}), Seeding)
a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready)
a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused)
a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error)
a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled)
a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown)
} }

@ -109,14 +109,14 @@ func (monitor *Monitor) Update() bool {
util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status) util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status)
switch status.Status { switch common.GetStatus(status) {
case "complete": case common.Complete, common.Seeding:
return monitor.Complete(task.TaskPoll) return monitor.Complete(task.TaskPoll)
case "error": case common.Error:
return monitor.Error(status) return monitor.Error(status)
case "active", "waiting", "paused": case common.Downloading, common.Ready, common.Paused:
return false return false
case "removed": case common.Canceled:
monitor.Task.Status = common.Canceled monitor.Task.Status = common.Canceled
monitor.Task.Save() monitor.Task.Save()
monitor.RemoveTempFolder() monitor.RemoveTempFolder()
@ -132,7 +132,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid monitor.Task.GID = status.Gid
monitor.Task.Status = common.GetStatus(status.Status) monitor.Task.Status = common.GetStatus(status)
// 文件大小、已下载大小 // 文件大小、已下载大小
total, err := strconv.ParseUint(status.TotalLength, 10, 64) total, err := strconv.ParseUint(status.TotalLength, 10, 64)
@ -235,6 +235,40 @@ func (monitor *Monitor) RemoveTempFolder() {
// Complete 完成下载,返回是否中断监控 // Complete 完成下载,返回是否中断监控
func (monitor *Monitor) Complete(pool task.Pool) bool { func (monitor *Monitor) Complete(pool task.Pool) bool {
// 未开始转存,提交转存任务
if monitor.Task.TaskID == 0 {
return monitor.transfer(pool)
}
// 做种完成
if common.GetStatus(monitor.Task.StatusInfo) == common.Complete {
transferTask, err := model.GetTasksByID(monitor.Task.TaskID)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 转存完成,回收下载目录
if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error {
job, err := task.NewRecycleTask(monitor.Task.UserID, monitor.Task.Parent, monitor.node.ID())
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 提交回收任务
pool.Submit(job)
return true
}
}
return false
}
func (monitor *Monitor) transfer(pool task.Pool) bool {
// 创建中转任务 // 创建中转任务
file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
@ -269,7 +303,7 @@ func (monitor *Monitor) Complete(pool task.Pool) bool {
monitor.Task.TaskID = job.Model().ID monitor.Task.TaskID = job.Model().ID
monitor.Task.Save() monitor.Task.Save()
return true return false
} }
func (monitor *Monitor) setErrorStatus(err error) { func (monitor *Monitor) setErrorStatus(err error) {

@ -3,6 +3,8 @@ package monitor
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"testing"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
@ -13,7 +15,6 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock" testMock "github.com/stretchr/testify/mock"
"testing"
) )
var mock sqlmock.Sqlmock var mock sqlmock.Sqlmock
@ -431,6 +432,14 @@ func TestMonitor_Complete(t *testing.T) {
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4))
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
a.False(m.Complete(mockPool))
m.Task.StatusInfo.Status = "complete"
a.True(m.Complete(mockPool)) a.True(m.Complete(mockPool))
a.NoError(mock.ExpectationsWereMet()) a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t) mockNode.AssertExpectations(t)

@ -24,15 +24,7 @@ type StatusInfo struct {
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
Dir string `json:"dir"` // Directory to save files. Dir string `json:"dir"` // Directory to save files.
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
BitTorrent struct { BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
} `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
} }
// URIInfo represents an element of response of aria2.getUris // URIInfo represents an element of response of aria2.getUris
@ -100,3 +92,13 @@ type Method struct {
Name string `json:"methodName"` // Method name to call Name string `json:"methodName"` // Method name to call
Params []interface{} `json:"params"` // Array containing parameters to the method call Params []interface{} `json:"params"` // Array containing parameters to the method call
} }
type BitTorrentInfo struct {
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
}

@ -5,6 +5,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"net/url"
"time"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
@ -13,8 +16,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
"time"
) )
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果 // Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
@ -118,6 +119,45 @@ func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]respo
} }
// 取消上传凭证 // 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}
func (d *Driver) Recycle(ctx context.Context, path string) error {
req := serializer.SlaveRecycleReq{
Path: path,
}
body, err := json.Marshal(req)
if err != nil {
return err
}
// 订阅回收结果
resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0)
defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan)
res, err := d.client.Request("PUT", "task/recycle", bytes.NewReader(body)).
CheckHTTPResponse(200).
DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
// 等待回收结果或者超时
waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800)
select {
case <-time.After(time.Duration(waitTimeout) * time.Second):
return ErrWaitResultTimeout
case msg := <-resChan:
if msg.Event != serializer.SlaveRecycleSuccess {
return errors.New(msg.Content.(serializer.SlaveRecycleResult).Error)
}
}
return nil return nil
} }

@ -4,6 +4,7 @@ import (
"crypto/sha1" "crypto/sha1"
"encoding/gob" "encoding/gob"
"fmt" "fmt"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
) )
@ -53,15 +54,35 @@ func (s *SlaveTransferReq) Hash(id string) string {
return fmt.Sprintf("%x", bs) return fmt.Sprintf("%x", bs)
} }
// SlaveRecycleReq 从机回收任务创建请求
type SlaveRecycleReq struct {
Path string `json:"path"`
}
// Hash 返回创建请求的唯一标识,保持创建请求幂等
func (s *SlaveRecycleReq) Hash(id string) string {
h := sha1.New()
h.Write([]byte(fmt.Sprintf("transfer-%s-%s", id, s.Path)))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
}
const ( const (
SlaveTransferSuccess = "success" SlaveTransferSuccess = "success"
SlaveTransferFailed = "failed" SlaveTransferFailed = "failed"
SlaveRecycleSuccess = "success"
SlaveRecycleFailed = "failed"
) )
type SlaveTransferResult struct { type SlaveTransferResult struct {
Error string Error string
} }
type SlaveRecycleResult struct {
Error string
}
func init() { func init() {
gob.Register(SlaveTransferResult{}) gob.Register(SlaveTransferResult{})
gob.Register(SlaveRecycleResult{})
} }

@ -1,9 +1,10 @@
package serializer package serializer
import ( import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing"
) )
func TestSlaveTransferReq_Hash(t *testing.T) { func TestSlaveTransferReq_Hash(t *testing.T) {
@ -18,3 +19,14 @@ func TestSlaveTransferReq_Hash(t *testing.T) {
} }
a.NotEqual(s1.Hash("1"), s2.Hash("1")) a.NotEqual(s1.Hash("1"), s2.Hash("1"))
} }
func TestSlaveRecycleReq_Hash(t *testing.T) {
a := assert.New(t)
s1 := &SlaveRecycleReq{
Path: "1",
}
s2 := &SlaveRecycleReq{
Path: "2",
}
a.NotEqual(s1.Hash("1"), s2.Hash("1"))
}

@ -13,6 +13,8 @@ const (
DecompressTaskType DecompressTaskType
// TransferTaskType 中转任务 // TransferTaskType 中转任务
TransferTaskType TransferTaskType
// RecycleTaskType 回收任务
RecycleTaskType
// ImportTaskType 导入任务 // ImportTaskType 导入任务
ImportTaskType ImportTaskType
) )
@ -113,6 +115,8 @@ func GetJobFromModel(task *model.Task) (Job, error) {
return NewTransferTaskFromModel(task) return NewTransferTaskFromModel(task)
case ImportTaskType: case ImportTaskType:
return NewImportTaskFromModel(task) return NewImportTaskFromModel(task)
case RecycleTaskType:
return NewRecycleTaskFromModel(task)
default: default:
return nil, ErrUnknownTaskType return nil, ErrUnknownTaskType
} }

@ -2,12 +2,12 @@ package task
import ( import (
"errors" "errors"
testMock "github.com/stretchr/testify/mock"
"testing" "testing"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
) )
func TestRecord(t *testing.T) { func TestRecord(t *testing.T) {
@ -103,4 +103,16 @@ func TestGetJobFromModel(t *testing.T) {
asserts.Nil(job) asserts.Nil(job)
asserts.Error(err) asserts.Error(err)
} }
// RecycleTaskType
{
task := &model.Task{
Status: 0,
Type: RecycleTaskType,
}
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
job, err := GetJobFromModel(task)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Nil(job)
asserts.Error(err)
}
} }

@ -0,0 +1,155 @@
package task
import (
"context"
"encoding/json"
"os"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// RecycleTask 文件回收任务
type RecycleTask struct {
User *model.User
TaskModel *model.Task
TaskProps RecycleProps
Err *JobError
}
// RecycleProps 回收任务属性
type RecycleProps struct {
// 回收目录
Path string `json:"path"`
// 负责处理回收任务的节点ID
NodeID uint `json:"node_id"`
}
// Props 获取任务属性
func (job *RecycleTask) Props() string {
res, _ := json.Marshal(job.TaskProps)
return string(res)
}
// Type 获取任务状态
func (job *RecycleTask) Type() int {
return RecycleTaskType
}
// Creator 获取创建者ID
func (job *RecycleTask) Creator() uint {
return job.User.ID
}
// Model 获取任务的数据库模型
func (job *RecycleTask) Model() *model.Task {
return job.TaskModel
}
// SetStatus 设定状态
func (job *RecycleTask) SetStatus(status int) {
job.TaskModel.SetStatus(status)
}
// SetError 设定任务失败信息
func (job *RecycleTask) SetError(err *JobError) {
job.Err = err
res, _ := json.Marshal(job.Err)
job.TaskModel.SetError(string(res))
}
// SetErrorMsg 设定任务失败信息
func (job *RecycleTask) SetErrorMsg(msg string, err error) {
jobErr := &JobError{Msg: msg}
if err != nil {
jobErr.Error = err.Error()
}
job.SetError(jobErr)
}
// GetError 返回任务失败信息
func (job *RecycleTask) GetError() *JobError {
return job.Err
}
// Do 开始执行任务
func (job *RecycleTask) Do() {
if job.TaskProps.NodeID == 1 {
err := os.RemoveAll(job.TaskProps.Path)
if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err)
job.SetErrorMsg("文件回收失败", err)
}
} else {
// 指定为从机回收
// 创建文件系统
fs, err := filesystem.NewFileSystem(job.User)
if err != nil {
job.SetErrorMsg(err.Error(), nil)
return
}
// 获取从机节点
node := cluster.Default.GetNodeByID(job.TaskProps.NodeID)
if node == nil {
job.SetErrorMsg("从机节点不可用", nil)
}
// 切换为从机节点处理回收
fs.SwitchToSlaveHandler(node)
handler := fs.Handler.(*slaveinmaster.Driver)
err = handler.Recycle(context.Background(), job.TaskProps.Path)
if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err)
job.SetErrorMsg("文件回收失败", err)
}
}
}
// NewRecycleTask 新建回收任务
func NewRecycleTask(user uint, path string, node uint) (Job, error) {
creator, err := model.GetActiveUserByID(user)
if err != nil {
return nil, err
}
newTask := &RecycleTask{
User: &creator,
TaskProps: RecycleProps{
Path: path,
NodeID: node,
},
}
record, err := Record(newTask)
if err != nil {
return nil, err
}
newTask.TaskModel = record
return newTask, nil
}
// NewRecycleTaskFromModel 从数据库记录中恢复回收任务
func NewRecycleTaskFromModel(task *model.Task) (Job, error) {
user, err := model.GetActiveUserByID(task.UserID)
if err != nil {
return nil, err
}
newTask := &RecycleTask{
User: &user,
TaskModel: task,
}
err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps)
if err != nil {
return nil, err
}
return newTask, nil
}

@ -0,0 +1,131 @@
package task
import (
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestRecycleTask_Props(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
User: &model.User{},
}
asserts.NotEmpty(task.Props())
asserts.Equal(RecycleTaskType, task.Type())
asserts.EqualValues(0, task.Creator())
asserts.Nil(task.Model())
}
func TestRecycleTask_SetStatus(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
User: &model.User{},
TaskModel: &model.Task{
Model: gorm.Model{ID: 1},
},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task.SetStatus(3)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestRecycleTask_SetError(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
User: &model.User{},
TaskModel: &model.Task{
Model: gorm.Model{ID: 1},
},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task.SetErrorMsg("error", nil)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal("error", task.GetError().Msg)
}
func TestRecycleTask_Do(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
TaskModel: &model.Task{
Model: gorm.Model{ID: 1},
},
}
// 目录不存在
{
task.TaskProps.Path = "test/not_exist"
task.User = &model.User{
Policy: model.Policy{
Type: "unknown",
},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
1))
mock.ExpectCommit()
task.Do()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotEmpty(task.GetError().Msg)
}
}
func TestNewRecycleTask(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
job, err := NewRecycleTask(1, "/", 0)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotNil(job)
asserts.NoError(err)
}
// 失败
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
job, err := NewRecycleTask(1, "test/not_exist", 0)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Nil(job)
asserts.Error(err)
}
}
func TestNewRecycleTaskFromModel(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.NotNil(job)
}
// JSON解析失败
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Nil(job)
}
}

@ -0,0 +1,95 @@
package slavetask
import (
"os"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// RecycleTask 文件回收任务
type RecycleTask struct {
Err *task.JobError
Req *serializer.SlaveRecycleReq
MasterID string
}
// Props 获取任务属性
func (job *RecycleTask) Props() string {
return ""
}
// Type 获取任务类型
func (job *RecycleTask) Type() int {
return 0
}
// Creator 获取创建者ID
func (job *RecycleTask) Creator() uint {
return 0
}
// Model 获取任务的数据库模型
func (job *RecycleTask) Model() *model.Task {
return nil
}
// SetStatus 设定状态
func (job *RecycleTask) SetStatus(status int) {
}
// SetError 设定任务失败信息
func (job *RecycleTask) SetError(err *task.JobError) {
job.Err = err
}
// SetErrorMsg 设定任务失败信息
func (job *RecycleTask) SetErrorMsg(msg string, err error) {
jobErr := &task.JobError{Msg: msg}
if err != nil {
jobErr.Error = err.Error()
}
job.SetError(jobErr)
notifyMsg := mq.Message{
TriggeredBy: job.MasterID,
Event: serializer.SlaveRecycleFailed,
Content: serializer.SlaveRecycleResult{
Error: err.Error(),
},
}
if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
util.Log().Warning("无法发送回收失败通知到从机, %s", err)
}
}
// GetError 返回任务失败信息
func (job *RecycleTask) GetError() *task.JobError {
return job.Err
}
// Do 开始执行任务
func (job *RecycleTask) Do() {
err := os.RemoveAll(job.Req.Path)
if err != nil {
util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Path, err)
job.SetErrorMsg("文件回收失败", err)
return
}
msg := mq.Message{
TriggeredBy: job.MasterID,
Event: serializer.SlaveRecycleSuccess,
Content: serializer.SlaveRecycleResult{},
}
if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
util.Log().Warning("无法发送回收成功通知到从机, %s", err)
}
}

@ -2,6 +2,8 @@ package slavetask
import ( import (
"context" "context"
"os"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
@ -10,7 +12,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"os"
) )
// TransferTask 文件中转任务 // TransferTask 文件中转任务
@ -79,8 +80,6 @@ func (job *TransferTask) GetError() *task.JobError {
// Do 开始执行任务 // Do 开始执行任务
func (job *TransferTask) Do() { func (job *TransferTask) Do() {
defer job.Recycle()
fs, err := filesystem.NewAnonymousFileSystem() fs, err := filesystem.NewAnonymousFileSystem()
if err != nil { if err != nil {
job.SetErrorMsg("无法初始化匿名文件系统", err) job.SetErrorMsg("无法初始化匿名文件系统", err)
@ -137,11 +136,3 @@ func (job *TransferTask) Do() {
util.Log().Warning("无法发送转存成功通知到从机, %s", err) util.Log().Warning("无法发送转存成功通知到从机, %s", err)
} }
} }
// Recycle 回收临时文件
func (job *TransferTask) Recycle() {
err := os.Remove(job.Req.Src)
if err != nil {
util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Src, err)
}
}

@ -3,7 +3,6 @@ package task
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
@ -87,8 +86,6 @@ func (job *TransferTask) GetError() *JobError {
// Do 开始执行任务 // Do 开始执行任务
func (job *TransferTask) Do() { func (job *TransferTask) Do() {
defer job.Recycle()
// 创建文件系统 // 创建文件系统
fs, err := filesystem.NewFileSystem(job.User) fs, err := filesystem.NewFileSystem(job.User)
if err != nil { if err != nil {
@ -139,16 +136,6 @@ func (job *TransferTask) Do() {
} }
// Recycle 回收临时文件
func (job *TransferTask) Recycle() {
if job.TaskProps.NodeID == 1 {
err := os.RemoveAll(job.TaskProps.Parent)
if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err)
}
}
}
// NewTransferTask 新建中转任务 // NewTransferTask 新建中转任务
func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) { func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) {
creator, err := model.GetActiveUserByID(user) creator, err := model.GetActiveUserByID(user)

@ -212,6 +212,17 @@ func SlaveCreateTransferTask(c *gin.Context) {
} }
} }
// SlaveCreateRecycleTask 从机创建回收任务
func SlaveCreateRecycleTask(c *gin.Context) {
var service serializer.SlaveRecycleReq
if err := c.ShouldBindJSON(&service); err == nil {
res := explorer.CreateRecycleTask(c, &service)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// SlaveNotificationPush 处理从机发送的消息推送 // SlaveNotificationPush 处理从机发送的消息推送
func SlaveNotificationPush(c *gin.Context) { func SlaveNotificationPush(c *gin.Context) {
var service node.SlaveNotificationService var service node.SlaveNotificationService

@ -88,6 +88,7 @@ func InitSlaveRouter() *gin.Engine {
task := v3.Group("task") task := v3.Group("task")
{ {
task.PUT("transfer", controllers.SlaveCreateTransferTask) task.PUT("transfer", controllers.SlaveCreateTransferTask)
task.PUT("recycle", controllers.SlaveCreateRecycleTask)
} }
} }
return r return r

@ -33,7 +33,7 @@ func (service *DownloadListService) Finished(c *gin.Context, user *model.User) s
// Downloading 获取正在下载中的任务 // Downloading 获取正在下载中的任务
func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response {
// 查找下载记录 // 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready) downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready)
intervals := make(map[uint]int) intervals := make(map[uint]int)
for _, download := range downloads { for _, download := range downloads {
if _, ok := intervals[download.ID]; !ok { if _, ok := intervals[download.ID]; !ok {

@ -5,6 +5,10 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"net/url"
"time"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
@ -16,9 +20,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"net/http"
"net/url"
"time"
) )
// SlaveDownloadService 从机文件下載服务 // SlaveDownloadService 从机文件下載服务
@ -165,6 +166,26 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial
return serializer.ParamErr("未知的主机节点ID", nil) return serializer.ParamErr("未知的主机节点ID", nil)
} }
// CreateRecycleTask 创建从机文件回收任务
func CreateRecycleTask(c *gin.Context, req *serializer.SlaveRecycleReq) serializer.Response {
if id, ok := c.Get("MasterSiteID"); ok {
job := &slavetask.RecycleTask{
Req: req,
MasterID: id.(string),
}
if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) {
task.TaskPoll.Submit(job.(task.Job))
}); err != nil {
return serializer.Err(serializer.CodeCreateTaskError, "", err)
}
return serializer.Response{}
}
return serializer.ParamErr("未知的主机节点ID", nil)
}
// SlaveListService 从机上传会话服务 // SlaveListService 从机上传会话服务
type SlaveCreateUploadSessionService struct { type SlaveCreateUploadSessionService struct {
Session serializer.UploadSession `json:"session" binding:"required"` Session serializer.UploadSession `json:"session" binding:"required"`

@ -1,14 +1,15 @@
package user package user
import ( import (
"net/url"
"strings"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/url"
"strings"
) )
// UserRegisterService 管理用户注册的服务 // UserRegisterService 管理用户注册的服务

Loading…
Cancel
Save