diff --git a/models/download.go b/models/download.go index 0989b14e..ab28c2f9 100644 --- a/models/download.go +++ b/models/download.go @@ -24,6 +24,7 @@ type Download struct { Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 UserID uint // 发起者UID TaskID uint // 对应的转存任务ID + NodeID uint // 处理任务的节点ID // 关联模型 User *User `gorm:"PRELOAD:false,association_autoupdate:false"` diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 9aef4953..fe0cc721 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -4,13 +4,13 @@ import ( "sync" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/balancer" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) // Instance 默认使用的Aria2处理实例 -var Instance Aria2 = &DummyAria2{} +var Instance common.Aria2 = &common.DummyAria2{} // LB 获取 Aria2 节点的负载均衡器 var LB balancer.Balancer @@ -18,82 +18,6 @@ var LB balancer.Balancer // Lock Instance的读写锁 var Lock sync.RWMutex -// EventNotifier 任务状态更新通知处理器 -var EventNotifier = &Notifier{} - -// Aria2 离线下载处理接口 -type Aria2 interface { - // Init 初始化客户端连接 - Init() error - // CreateTask 创建新的任务 - CreateTask(task *model.Download, options map[string]interface{}) (string, error) - // 返回状态信息 - Status(task *model.Download) (rpc.StatusInfo, error) - // 取消任务 - Cancel(task *model.Download) error - // 选择要下载的文件 - Select(task *model.Download, files []int) error -} - -const ( - // URLTask 从URL添加的任务 - URLTask = iota - // TorrentTask 种子任务 - TorrentTask -) - -const ( - // Ready 准备就绪 - Ready = iota - // Downloading 下载中 - Downloading - // Paused 暂停中 - Paused - // Error 出错 - Error - // Complete 完成 - Complete - // Canceled 取消/停止 - Canceled - // Unknown 未知状态 - Unknown -) - -var ( - // ErrNotEnabled 功能未开启错误 - ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil) - // ErrUserNotFound 未找到下载任务创建者 - ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil) -) - -// DummyAria2 未开启Aria2功能时使用的默认处理器 -type DummyAria2 struct { -} - -func (instance *DummyAria2) Init() error { - return nil -} - -// CreateTask 创建新任务,此处直接返回未开启错误 -func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) { - return "", ErrNotEnabled -} - -// Status 返回未开启错误 -func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) { - return rpc.StatusInfo{}, ErrNotEnabled -} - -// Cancel 返回未开启错误 -func (instance *DummyAria2) Cancel(task *model.Download) error { - return ErrNotEnabled -} - -// Select 返回未开启错误 -func (instance *DummyAria2) Select(task *model.Download, files []int) error { - return ErrNotEnabled -} - // Init 初始化 func Init(isReload bool) { Lock.Lock() @@ -102,31 +26,11 @@ func Init(isReload bool) { if !isReload { // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading) for i := 0; i < len(unfinished); i++ { // 创建任务监控 - NewMonitor(&unfinished[i]) + monitor.NewMonitor(&unfinished[i]) } } } - -// getStatus 将给定的状态字符串转换为状态标识数字 -func getStatus(status string) int { - switch status { - case "complete": - return Complete - case "active": - return Downloading - case "waiting": - return Ready - case "paused": - return Paused - case "error": - return Error - case "removed": - return Canceled - default: - return Unknown - } -} diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go index 51605a1a..dfd71a39 100644 --- a/pkg/aria2/aria2_test.go +++ b/pkg/aria2/aria2_test.go @@ -6,6 +6,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" @@ -37,7 +38,7 @@ func TestDummyAria2(t *testing.T) { } func TestInit(t *testing.T) { - MAX_RETRY = 0 + monitor.MAX_RETRY = 0 asserts := assert.New(t) cache.Set("setting_aria2_token", "1", 0) cache.Set("setting_aria2_call_timeout", "5", 0) @@ -81,11 +82,11 @@ func TestInit(t *testing.T) { func TestGetStatus(t *testing.T) { asserts := assert.New(t) - asserts.Equal(4, getStatus("complete")) - asserts.Equal(1, getStatus("active")) - asserts.Equal(0, getStatus("waiting")) - asserts.Equal(2, getStatus("paused")) - asserts.Equal(3, getStatus("error")) - asserts.Equal(5, getStatus("removed")) - asserts.Equal(6, getStatus("?")) + asserts.Equal(4, GetStatus("complete")) + asserts.Equal(1, GetStatus("active")) + asserts.Equal(0, GetStatus("waiting")) + asserts.Equal(2, GetStatus("paused")) + asserts.Equal(3, GetStatus("error")) + asserts.Equal(5, GetStatus("removed")) + asserts.Equal(6, GetStatus("?")) } diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 9f15c878..f3d58e38 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -8,6 +8,7 @@ import ( "time" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -33,7 +34,7 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s Options: options, } caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, - EventNotifier) + common.EventNotifier) client.Caller = caller return err } diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go new file mode 100644 index 00000000..3eaa746b --- /dev/null +++ b/pkg/aria2/common/common.go @@ -0,0 +1,110 @@ +package common + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) + +// Aria2 离线下载处理接口 +type Aria2 interface { + // Init 初始化客户端连接 + Init() error + // CreateTask 创建新的任务 + CreateTask(task *model.Download, options map[string]interface{}) (string, error) + // 返回状态信息 + Status(task *model.Download) (rpc.StatusInfo, error) + // 取消任务 + Cancel(task *model.Download) error + // 选择要下载的文件 + Select(task *model.Download, files []int) error + // GetConfig 获取离线下载配置 + GetConfig() model.Aria2Option +} + +const ( + // URLTask 从URL添加的任务 + URLTask = iota + // TorrentTask 种子任务 + TorrentTask +) + +const ( + // Ready 准备就绪 + Ready = iota + // Downloading 下载中 + Downloading + // Paused 暂停中 + Paused + // Error 出错 + Error + // Complete 完成 + Complete + // Canceled 取消/停止 + Canceled + // Unknown 未知状态 + Unknown +) + +var ( + // ErrNotEnabled 功能未开启错误 + ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil) + // ErrUserNotFound 未找到下载任务创建者 + ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil) +) + +// DummyAria2 未开启Aria2功能时使用的默认处理器 +type DummyAria2 struct { +} + +func (instance *DummyAria2) Init() error { + return nil +} + +// CreateTask 创建新任务,此处直接返回未开启错误 +func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) { + return "", ErrNotEnabled +} + +// Status 返回未开启错误 +func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) { + return rpc.StatusInfo{}, ErrNotEnabled +} + +// Cancel 返回未开启错误 +func (instance *DummyAria2) Cancel(task *model.Download) error { + return ErrNotEnabled +} + +// Select 返回未开启错误 +func (instance *DummyAria2) Select(task *model.Download, files []int) error { + return ErrNotEnabled +} + +// GetConfig 返回空的 +func (instance *DummyAria2) GetConfig() model.Aria2Option { + return model.Aria2Option{} +} + +// GetStatus 将给定的状态字符串转换为状态标识数字 +func GetStatus(status string) int { + switch status { + case "complete": + return Complete + case "active": + return Downloading + case "waiting": + return Ready + case "paused": + return Paused + case "error": + return Error + case "removed": + return Canceled + default: + return Unknown + } +} + +// EventNotifier 任务状态更新通知处理器 +var EventNotifier = &Notifier{} diff --git a/pkg/aria2/notification.go b/pkg/aria2/common/notification.go similarity index 91% rename from pkg/aria2/notification.go rename to pkg/aria2/common/notification.go index e2ead914..0a52c7f5 100644 --- a/pkg/aria2/notification.go +++ b/pkg/aria2/common/notification.go @@ -1,4 +1,4 @@ -package aria2 +package common import ( "sync" @@ -6,7 +6,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" ) -// Notifier aria2实践通知处理 +// Notifier aria2事件通知处理 type Notifier struct { Subscribes sync.Map } @@ -62,3 +62,9 @@ func (notifier *Notifier) OnDownloadError(events []rpc.Event) { func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) { notifier.Notify(events, Complete) } + +// StatusEvent 状态改变事件 +type StatusEvent struct { + GID string + Status int +} diff --git a/pkg/aria2/monitor.go b/pkg/aria2/monitor/monitor.go similarity index 86% rename from pkg/aria2/monitor.go rename to pkg/aria2/monitor/monitor.go index 1667f470..12abfbf4 100644 --- a/pkg/aria2/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -1,4 +1,4 @@ -package aria2 +package monitor import ( "context" @@ -10,7 +10,9 @@ import ( "time" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" @@ -23,32 +25,31 @@ type Monitor struct { Task *model.Download Interval time.Duration - notifier chan StatusEvent + notifier chan common.StatusEvent + node cluster.Node retried int } -// StatusEvent 状态改变事件 -type StatusEvent struct { - GID string - Status int -} - var MAX_RETRY = 10 -// NewMonitor 新建上传状态监控 +// NewMonitor 新建离线下载状态监控 func NewMonitor(task *model.Download) { monitor := &Monitor{ Task: task, - Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second, - notifier: make(chan StatusEvent), + notifier: make(chan common.StatusEvent), + node: cluster.Default.GetNodeByID(task.NodeID), + } + if monitor.node != nil { + monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second + go monitor.Loop() + common.EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID) } - go monitor.Loop() - EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID) } // Loop 开启监控循环 func (monitor *Monitor) Loop() { - defer EventNotifier.Unsubscribe(monitor.Task.GID) + defer common.EventNotifier.Unsubscribe(monitor.Task.GID) + fmt.Println(cluster.Default) // 首次循环立即更新 interval := time.Duration(0) @@ -70,9 +71,7 @@ func (monitor *Monitor) Loop() { // Update 更新状态,返回值表示是否退出监控 func (monitor *Monitor) Update() bool { - Lock.RLock() - status, err := Instance.Status(monitor.Task) - Lock.RUnlock() + status, err := monitor.node.GetAria2Instance().Status(monitor.Task) if err != nil { monitor.retried++ @@ -115,7 +114,7 @@ func (monitor *Monitor) Update() bool { case "active", "waiting", "paused": return false case "removed": - monitor.Task.Status = Canceled + monitor.Task.Status = common.Canceled monitor.Task.Save() monitor.RemoveTempFolder() return true @@ -130,7 +129,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { originSize := monitor.Task.TotalSize monitor.Task.GID = status.Gid - monitor.Task.Status = getStatus(status.Status) + monitor.Task.Status = common.GetStatus(status.Status) // 文件大小、已下载大小 total, err := strconv.ParseUint(status.TotalLength, 10, 64) @@ -164,9 +163,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { // 文件大小更新后,对文件限制等进行校验 if err := monitor.ValidateFile(); err != nil { // 验证失败时取消任务 - Lock.RLock() - Instance.Cancel(monitor.Task) - Lock.RUnlock() + monitor.node.GetAria2Instance().Cancel(monitor.Task) return err } } @@ -179,7 +176,7 @@ func (monitor *Monitor) ValidateFile() error { // 找到任务创建者 user := monitor.Task.GetOwner() if user == nil { - return ErrUserNotFound + return common.ErrUserNotFound } // 创建文件系统 @@ -269,7 +266,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { } func (monitor *Monitor) setErrorStatus(err error) { - monitor.Task.Status = Error + monitor.Task.Status = common.Error monitor.Task.Error = err.Error() monitor.Task.Save() } diff --git a/pkg/aria2/monitor_test.go b/pkg/aria2/monitor/monitor_test.go similarity index 66% rename from pkg/aria2/monitor_test.go rename to pkg/aria2/monitor/monitor_test.go index 91728942..06d5cd4e 100644 --- a/pkg/aria2/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -1,4 +1,4 @@ -package aria2 +package monitor import ( "errors" @@ -7,6 +7,8 @@ import ( "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" @@ -44,13 +46,13 @@ func (m InstanceMock) Select(task *model.Download, files []int) error { func TestNewMonitor(t *testing.T) { asserts := assert.New(t) NewMonitor(&model.Download{GID: "gid"}) - _, ok := EventNotifier.Subscribes.Load("gid") + _, ok := common.EventNotifier.Subscribes.Load("gid") asserts.True(ok) } func TestMonitor_Loop(t *testing.T) { asserts := assert.New(t) - notifier := make(chan StatusEvent) + notifier := make(chan common.StatusEvent) MAX_RETRY = 0 monitor := &Monitor{ Task: &model.Download{GID: "gid"}, @@ -79,7 +81,7 @@ func TestMonitor_Update(t *testing.T) { testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) file, _ := util.CreatNestedFile("TestMonitor_Update/1") file.Close() - Instance = testInstance + aria2.Instance = testInstance asserts.False(monitor.Update()) asserts.True(monitor.Update()) testInstance.AssertExpectations(t) @@ -93,12 +95,12 @@ func TestMonitor_Update(t *testing.T) { FollowedBy: []string{"1"}, }, nil) monitor.Task.ID = 1 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.False(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) asserts.EqualValues("1", monitor.Task.GID) } @@ -108,12 +110,12 @@ func TestMonitor_Update(t *testing.T) { testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) monitor.Task.ID = 1 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + aria2.mock.ExpectRollback() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } @@ -121,12 +123,12 @@ func TestMonitor_Update(t *testing.T) { { testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } @@ -134,15 +136,15 @@ func TestMonitor_Update(t *testing.T) { { testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } @@ -150,12 +152,12 @@ func TestMonitor_Update(t *testing.T) { { testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.False(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } @@ -163,12 +165,12 @@ func TestMonitor_Update(t *testing.T) { { testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } @@ -176,12 +178,12 @@ func TestMonitor_Update(t *testing.T) { { testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() + aria2.Instance = testInstance asserts.True(monitor.Update()) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) testInstance.AssertExpectations(t) } } @@ -198,21 +200,21 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) { // 失败 { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + aria2.mock.ExpectRollback() err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) asserts.Error(err) } // 更新成功,无需校验 { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) asserts.NoError(err) } @@ -220,12 +222,12 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) { { testInstance := new(InstanceMock) testInstance.On("Cancel", testMock.Anything).Return(nil) - Instance = testInstance - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.Instance = testInstance + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"}) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) asserts.Error(err) testInstance.AssertExpectations(t) } @@ -308,17 +310,17 @@ func TestMonitor_Complete(t *testing.T) { } cache.Set("setting_max_worker_num", "1", 0) - mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) + aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) task.Init() - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + aria2.mock.ExpectBegin() + aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) + aria2.mock.ExpectCommit() asserts.True(monitor.Complete(rpc.StatusInfo{})) - asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(aria2.mock.ExpectationsWereMet()) } diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index 0b6ce3fc..b0074c62 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -3,13 +3,15 @@ package cluster import ( "context" "encoding/json" - "fmt" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "net/url" + "path/filepath" + "strconv" + "strings" "sync" "time" ) @@ -49,6 +51,13 @@ func (node *MasterNode) Init(nodeModel *model.Node) { node.lock.RUnlock() } +func (node *MasterNode) ID() uint { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Model.ID +} + func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { return &serializer.NodePingResp{}, nil } @@ -76,15 +85,15 @@ func (node *MasterNode) IsActive() bool { } // GetAria2Instance 获取主机Aria2实例 -func (node *MasterNode) GetAria2Instance() aria2.Aria2 { +func (node *MasterNode) GetAria2Instance() common.Aria2 { if !node.Model.Aria2Enabled { - return &aria2.DummyAria2{} + return &common.DummyAria2{} } node.lock.RLock() defer node.lock.RUnlock() if !node.aria2RPC.Initialized { - return &aria2.DummyAria2{} + return &common.DummyAria2{} } return &node.aria2RPC @@ -122,25 +131,76 @@ func (r *rpcService) Init() error { Options: globalOptions, } timeout := r.parent.Model.Aria2OptionsSerialized.Timeout - caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, aria2.EventNotifier) + caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, common.EventNotifier) r.Caller = caller r.Initialized = true return err } -func (r *rpcService) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { - return "", fmt.Errorf("some error #%d", r.parent.Model.ID) +func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) { + r.parent.lock.RLock() + // 生成存储路径 + path := filepath.Join( + r.parent.Model.Aria2OptionsSerialized.TempPath, + "aria2", + strconv.FormatInt(time.Now().UnixNano(), 10), + ) + r.parent.lock.RUnlock() + + // 创建下载任务 + options := map[string]interface{}{ + "dir": path, + } + for k, v := range r.options.Options { + options[k] = v + } + for k, v := range groupOptions { + options[k] = v + } + + gid, err := r.Caller.AddURI(task.Source, options) + if err != nil || gid == "" { + return "", err + } + + return gid, nil } func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { - panic("implement me") + res, err := r.Caller.TellStatus(task.GID) + if err != nil { + // 失败后重试 + util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err) + time.Sleep(time.Duration(10) * time.Second) + res, err = r.Caller.TellStatus(task.GID) + } + + return res, err } func (r *rpcService) Cancel(task *model.Download) error { - panic("implement me") + // 取消下载任务 + _, err := r.Caller.Remove(task.GID) + if err != nil { + util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err) + } + + return err } func (r *rpcService) Select(task *model.Download, files []int) error { - panic("implement me") + var selected = make([]string, len(files)) + for i := 0; i < len(files); i++ { + selected[i] = strconv.Itoa(files[i]) + } + _, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) + return err +} + +func (r *rpcService) GetConfig() model.Aria2Option { + r.parent.lock.RLock() + defer r.parent.lock.RUnlock() + + return r.parent.Model.Aria2OptionsSerialized } diff --git a/pkg/cluster/node.go b/pkg/cluster/node.go index 041c11ae..021d133a 100644 --- a/pkg/cluster/node.go +++ b/pkg/cluster/node.go @@ -2,17 +2,25 @@ package cluster import ( model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) type Node interface { + // Init a node from database model Init(node *model.Node) + // Check if given feature is enabled IsFeatureEnabled(feature string) bool + // Subscribe node status change to a callback function SubscribeStatusChange(callback func(isActive bool, id uint)) + // Ping the node Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) + // Returns if the node is active IsActive() bool - GetAria2Instance() aria2.Aria2 + // Get instances for aria2 calls + GetAria2Instance() common.Aria2 + // Returns unique id of this node + ID() uint } func getNodeFromDBModel(node *model.Node) Node { diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index de364de0..3caf3c2c 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -14,7 +14,11 @@ var featureGroup = []string{"aria2"} // Pool 节点池 type Pool interface { + // Returns active node selected by given feature and load balancer BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) + + // Returns node by ID + GetNodeByID(id uint) Node } // NodePool 通用节点池 @@ -53,6 +57,17 @@ func (pool *NodePool) buildIndexMap() { pool.lock.Unlock() } +func (pool *NodePool) GetNodeByID(id uint) Node { + pool.lock.RLock() + defer pool.lock.RUnlock() + + if node, ok := pool.active[id]; ok { + return node + } + + return pool.inactive[id] +} + func (pool *NodePool) nodeStatusChange(isActive bool, id uint) { util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive) pool.lock.Lock() diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 247942a1..343ebc6d 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -185,6 +185,13 @@ loop: } // GetAria2Instance 获取从机Aria2实例 -func (node *SlaveNode) GetAria2Instance() aria2.Aria2 { +func (node *SlaveNode) GetAria2Instance() common.Aria2 { return nil } + +func (node *SlaveNode) ID() uint { + node.lock.RLock() + defer node.lock.RUnlock() + + return node.Model.ID +} diff --git a/routers/controllers/aria2.go b/routers/controllers/aria2.go index b2bc6d6d..25a8fb09 100644 --- a/routers/controllers/aria2.go +++ b/routers/controllers/aria2.go @@ -3,7 +3,7 @@ package controllers import ( "context" - ariaCall "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/service/aria2" "github.com/cloudreve/Cloudreve/v3/service/explorer" "github.com/gin-gonic/gin" @@ -13,7 +13,7 @@ import ( func AddAria2URL(c *gin.Context) { var addService aria2.AddURLService if err := c.ShouldBindJSON(&addService); err == nil { - res := addService.Add(c, ariaCall.URLTask) + res := addService.Add(c, common.URLTask) c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) @@ -52,7 +52,7 @@ func AddAria2Torrent(c *gin.Context) { if err := c.ShouldBindJSON(&addService); err == nil { addService.URL = res.Data.(string) - res := addService.Add(c, ariaCall.URLTask) + res := addService.Add(c, common.URLTask) c.JSON(200, res) } else { c.JSON(200, ErrorResponse(err)) diff --git a/service/aria2/add.go b/service/aria2/add.go index 07865bfa..50ad4de0 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -3,6 +3,8 @@ package aria2 import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -36,7 +38,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo // 创建任务 task := &model.Download{ - Status: aria2.Ready, + Status: common.Ready, Type: taskType, Dst: service.Dst, UserID: fs.User.ID, @@ -61,14 +63,14 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo } task.GID = gid + task.NodeID = node.ID() _, err = task.Create() if err != nil { return serializer.DBErr("任务创建失败", err) } // 创建任务监控 - aria2.NewMonitor(task) + monitor.NewMonitor(task) - aria2.Lock.RUnlock() return serializer.Response{} } diff --git a/service/aria2/manage.go b/service/aria2/manage.go index d93bbca4..b606c71a 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -3,6 +3,7 @@ package aria2 import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" ) @@ -25,14 +26,14 @@ type DownloadListService struct { // Finished 获取已完成的任务 func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Error, aria2.Complete, aria2.Canceled, aria2.Unknown) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown) return serializer.BuildFinishedListResponse(downloads) } // Downloading 获取正在下载中的任务 func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Downloading, aria2.Paused, aria2.Ready) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready) return serializer.BuildDownloadingResponse(downloads) } @@ -47,7 +48,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) } - if download.Status >= aria2.Error { + if download.Status >= common.Error { // 如果任务已完成,则删除任务记录 if err := download.Delete(); err != nil { return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err) @@ -76,7 +77,7 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) } - if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != aria2.Downloading && download.Status != aria2.Paused) { + if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != common.Downloading && download.Status != common.Paused) { return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err) }