diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go index e1e61e3..0ccdcb6 100644 --- a/middleware/cluster_test.go +++ b/middleware/cluster_test.go @@ -6,12 +6,10 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/mocks" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" "net/http/httptest" "testing" ) @@ -79,46 +77,12 @@ func TestSlaveRPCSignRequired(t *testing.T) { } } -type SlaveControllerMock struct { - testMock.Mock -} - -func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) { - args := s.Called(pingReq) - return args.Get(0).(serializer.NodePingResp), args.Error(1) -} - -func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) { - args := s.Called(s2) - return args.Get(0).(common.Aria2), args.Error(1) -} - -func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error { - args := s.Called(s3, s2, message) - return args.Error(0) -} - -func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error { - args := s.Called(s3, i, s2, f) - return args.Error(0) -} - -func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) { - args := s.Called(s2) - return args.Get(0).(*cluster.MasterInfo), args.Error(1) -} - -func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) { - args := s.Called(s2, u) - return args.String(0), args.Error(1) -} - func TestUseSlaveAria2Instance(t *testing.T) { a := assert.New(t) // MasterSiteID not set { - testController := &SlaveControllerMock{} + testController := &mocks.SlaveControllerMock{} useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) @@ -128,7 +92,7 @@ func TestUseSlaveAria2Instance(t *testing.T) { // Cannot get aria2 instances { - testController := &SlaveControllerMock{} + testController := &mocks.SlaveControllerMock{} useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) @@ -141,7 +105,7 @@ func TestUseSlaveAria2Instance(t *testing.T) { // Success { - testController := &SlaveControllerMock{} + testController := &mocks.SlaveControllerMock{} useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index d7f9abe..ef2f0df 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -3,6 +3,8 @@ package aria2 import ( "context" "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "net/url" "sync" "time" @@ -42,7 +44,7 @@ func Init(isReload bool) { for i := 0; i < len(unfinished); i++ { // 创建任务监控 - monitor.NewMonitor(&unfinished[i]) + monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ) } } } diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go new file mode 100644 index 0000000..52c555f --- /dev/null +++ b/pkg/aria2/common/common_test.go @@ -0,0 +1,45 @@ +package common + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDummyAria2(t *testing.T) { + a := assert.New(t) + d := &DummyAria2{} + + a.NoError(d.Init()) + + res, err := d.CreateTask(&model.Download{}, map[string]interface{}{}) + a.Empty(res) + a.Error(err) + + _, err = d.Status(&model.Download{}) + a.Error(err) + + err = d.Cancel(&model.Download{}) + a.Error(err) + + err = d.Select(&model.Download{}, []int{}) + a.Error(err) + + configRes := d.GetConfig() + a.NotEmpty(configRes) + + err = d.DeleteTempFile(&model.Download{}) + a.Error(err) +} + +func TestGetStatus(t *testing.T) { + a := assert.New(t) + + a.Equal(GetStatus("complete"), Complete) + a.Equal(GetStatus("active"), Downloading) + a.Equal(GetStatus("waiting"), Ready) + a.Equal(GetStatus("paused"), Paused) + a.Equal(GetStatus("error"), Error) + a.Equal(GetStatus("removed"), Canceled) + a.Equal(GetStatus("unknown"), Unknown) +} diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index 7a04411..aec1f25 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -33,29 +33,29 @@ type Monitor struct { var MAX_RETRY = 10 // NewMonitor 新建离线下载状态监控 -func NewMonitor(task *model.Download) { +func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) { monitor := &Monitor{ Task: task, notifier: make(chan mq.Message), - node: cluster.Default.GetNodeByID(task.GetNodeID()), + node: pool.GetNodeByID(task.GetNodeID()), } if monitor.node != nil { monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second - go monitor.Loop() + go monitor.Loop(mqClient) - monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0) + monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0) } else { monitor.setErrorStatus(errors.New("节点不可用")) } } // Loop 开启监控循环 -func (monitor *Monitor) Loop() { - defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier) +func (monitor *Monitor) Loop(mqClient mq.MQ) { + defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier) // 首次循环立即更新 - interval := time.Duration(0) + interval := 50 * time.Millisecond for { select { @@ -259,6 +259,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { ) if err != nil { monitor.setErrorStatus(err) + monitor.RemoveTempFolder() return true } diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go index 9d45026..d0760ca 100644 --- a/pkg/aria2/monitor/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -1,326 +1,252 @@ package monitor import ( + "database/sql" "errors" - "testing" - "time" - "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v3/pkg/mocks" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" + "testing" ) -type InstanceMock struct { - testMock.Mock -} - -func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error { - args := m.Called(task, options) - return args.Error(0) -} - -func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) { - args := m.Called(task) - return args.Get(0).(rpc.StatusInfo), args.Error(1) -} - -func (m InstanceMock) Cancel(task *model.Download) error { - args := m.Called(task) - return args.Error(0) -} - -func (m InstanceMock) Select(task *model.Download, files []int) error { - args := m.Called(task, files) - return args.Error(0) -} - -func TestNewMonitor(t *testing.T) { - asserts := assert.New(t) - NewMonitor(&model.Download{GID: "gid"}) - _, ok := common.EventNotifier.Subscribes.Load("gid") - asserts.True(ok) -} +var mock sqlmock.Sqlmock -func TestMonitor_Loop(t *testing.T) { - asserts := assert.New(t) - notifier := make(chan common.StatusEvent) - MAX_RETRY = 0 - monitor := &Monitor{ - Task: &model.Download{GID: "gid"}, - Interval: time.Duration(1) * time.Second, - notifier: notifier, +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") } - asserts.NotPanics(func() { - monitor.Loop() - }) + model.DB, _ = gorm.Open("mysql", db) + defer db.Close() + m.Run() } -func TestMonitor_Update(t *testing.T) { - asserts := assert.New(t) - monitor := &Monitor{ - Task: &model.Download{ - GID: "gid", - Parent: "TestMonitor_Update", - }, - Interval: time.Duration(1) * time.Second, - } - - // 无法获取状态 - { - MAX_RETRY = 1 - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) - file, _ := util.CreatNestedFile("TestMonitor_Update/1") - file.Close() - aria2.Instance = testInstance - asserts.False(monitor.Update()) - asserts.True(monitor.Update()) - testInstance.AssertExpectations(t) - asserts.False(util.Exists("TestMonitor_Update")) - } - - // 磁力链下载重定向 - { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{ - FollowedBy: []string{"1"}, - }, nil) - monitor.Task.ID = 1 - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.Instance = testInstance - asserts.False(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) - asserts.EqualValues("1", monitor.Task.GID) - } +func TestNewMonitor(t *testing.T) { + a := assert.New(t) + mockMQ := mq.NewMQ() - // 无法更新任务信息 + // node not available { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil) - monitor.Task.ID = 1 - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - aria2.mock.ExpectRollback() - aria2.Instance = testInstance - asserts.True(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) + mockPool := &mocks.NodePoolMock{} + mockPool.On("GetNodeByID", uint(1)).Return(nil) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + task := &model.Download{ + Model: gorm.Model{ID: 1}, + } + NewMonitor(task, mockPool, mockMQ) + mockPool.AssertExpectations(t) + a.NoError(mock.ExpectationsWereMet()) + a.NotEmpty(task.Error) } - // 返回未知状态 + // success { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.Instance = testInstance - asserts.True(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) - } + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) + mockPool := &mocks.NodePoolMock{} + mockPool.On("GetNodeByID", uint(1)).Return(mockNode) - // 返回被取消状态 - { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.Instance = testInstance - asserts.True(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) + task := &model.Download{ + Model: gorm.Model{ID: 1}, + } + NewMonitor(task, mockPool, mockMQ) + mockNode.AssertExpectations(t) + mockPool.AssertExpectations(t) } - // 返回活跃状态 - { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil) - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.Instance = testInstance - asserts.False(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) - } +} - // 返回错误状态 +func TestMonitor_Loop(t *testing.T) { + a := assert.New(t) + mockMQ := mq.NewMQ() + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) + m := &Monitor{ + retried: MAX_RETRY, + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + notifier: mockMQ.Subscribe("test", 1), + } + + // into interval loop { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.Instance = testInstance - asserts.True(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + m.Loop(mockMQ) + a.NoError(mock.ExpectationsWereMet()) + a.NotEmpty(m.Task.Error) } - // 返回完成 + // into notifier loop { - testInstance := new(InstanceMock) - testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - aria2.Instance = testInstance - asserts.True(monitor.Update()) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - testInstance.AssertExpectations(t) + m.Task.Error = "" + mockMQ.Publish("test", mq.Message{}) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + m.Loop(mockMQ) + a.NoError(mock.ExpectationsWereMet()) + a.NotEmpty(m.Task.Error) } } -func TestMonitor_UpdateTaskInfo(t *testing.T) { - asserts := assert.New(t) - monitor := &Monitor{ - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - GID: "gid", - Parent: "TestMonitor_UpdateTaskInfo", - }, - } - - // 失败 - { - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - aria2.mock.ExpectRollback() - err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - asserts.Error(err) +func TestMonitor_UpdateFailedAfterRetry(t *testing.T) { + a := assert.New(t) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() - // 更新成功,无需校验 - { - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - asserts.NoError(err) + for i := 0; i < MAX_RETRY; i++ { + a.False(m.Update()) } - // 更新成功,大小改变,需要校验,校验失败 - { - testInstance := new(InstanceMock) - testInstance.On("SlaveCancel", testMock.Anything).Return(nil) - aria2.Instance = testInstance - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"}) - asserts.NoError(aria2.mock.ExpectationsWereMet()) - asserts.Error(err) - testInstance.AssertExpectations(t) - } + mockNode.AssertExpectations(t) + a.True(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + a.NotEmpty(m.Task.Error) } -func TestMonitor_ValidateFile(t *testing.T) { - asserts := assert.New(t) - monitor := &Monitor{ - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - GID: "gid", - Parent: "TestMonitor_ValidateFile", - }, - } - - // 无法创建文件系统 - { - monitor.Task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - asserts.Error(monitor.ValidateFile()) - } - - // 文件大小超出容量配额 - { - cache.Set("pack_size_0", uint64(0), 0) - monitor.Task.TotalSize = 11 - monitor.Task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - Group: model.Group{ - MaxStorage: 10, - }, - } - asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile()) - } +func TestMonitor_UpdateMagentoFollow(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + FollowedBy: []string{"next"}, + }, nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.False(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + a.Equal("next", m.Task.GID) + mockAria2.AssertExpectations(t) +} - // 单文件大小超出容量配额 - { - cache.Set("pack_size_0", uint64(0), 0) - monitor.Task.TotalSize = 10 - monitor.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Selected: "true", - Length: "6", - }, - } - monitor.Task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - MaxSize: 5, - }, - Group: model.Group{ - MaxStorage: 10, - }, - } - asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile()) - } +func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) + mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.True(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + mockAria2.AssertExpectations(t) + mockNode.AssertExpectations(t) + a.NotEmpty(m.Task.Error) } -func TestMonitor_Complete(t *testing.T) { - asserts := assert.New(t) - monitor := &Monitor{ - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - GID: "gid", - Parent: "TestMonitor_Complete", - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Selected: "true", - Path: "TestMonitor_Complete", - }, - }, - }, - }, - } +func TestMonitor_UpdateCompleted(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + Status: "complete", + }, nil) + mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + mockNode.On("ID").Return(uint(1)) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.True(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + mockAria2.AssertExpectations(t) + mockNode.AssertExpectations(t) + a.NotEmpty(m.Task.Error) +} - cache.Set("setting_max_worker_num", "1", 0) - aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) - task.Init() - aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() +func TestMonitor_UpdateError(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + Status: "error", + ErrorMessage: "error", + }, nil) + mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.True(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + mockAria2.AssertExpectations(t) + mockNode.AssertExpectations(t) + a.NotEmpty(m.Task.Error) +} - aria2.mock.ExpectBegin() - aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) - aria2.mock.ExpectCommit() - asserts.True(monitor.Complete(rpc.StatusInfo{})) - asserts.NoError(aria2.mock.ExpectationsWereMet()) +func TestMonitor_UpdateActive(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + Status: "active", + }, nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.False(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + mockAria2.AssertExpectations(t) + mockNode.AssertExpectations(t) } diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go new file mode 100644 index 0000000..64d4425 --- /dev/null +++ b/pkg/mocks/mocks.go @@ -0,0 +1,173 @@ +package mocks + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/balancer" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + testMock "github.com/stretchr/testify/mock" +) + +type SlaveControllerMock struct { + testMock.Mock +} + +func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) { + args := s.Called(pingReq) + return args.Get(0).(serializer.NodePingResp), args.Error(1) +} + +func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) { + args := s.Called(s2) + return args.Get(0).(common.Aria2), args.Error(1) +} + +func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error { + args := s.Called(s3, s2, message) + return args.Error(0) +} + +func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error { + args := s.Called(s3, i, s2, f) + return args.Error(0) +} + +func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) { + args := s.Called(s2) + return args.Get(0).(*cluster.MasterInfo), args.Error(1) +} + +func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) { + args := s.Called(s2, u) + return args.String(0), args.Error(1) +} + +type NodePoolMock struct { + testMock.Mock +} + +func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) { + args := n.Called(feature, lb) + return args.Error(0), args.Get(1).(cluster.Node) +} + +func (n NodePoolMock) GetNodeByID(id uint) cluster.Node { + args := n.Called(id) + if res, ok := args.Get(0).(cluster.Node); ok { + return res + } + + return nil +} + +func (n NodePoolMock) Add(node *model.Node) { + n.Called(node) +} + +func (n NodePoolMock) Delete(id uint) { + n.Called(id) +} + +type NodeMock struct { + testMock.Mock +} + +func (n NodeMock) Init(node *model.Node) { + n.Called(node) +} + +func (n NodeMock) IsFeatureEnabled(feature string) bool { + args := n.Called(feature) + return args.Bool(0) +} + +func (n NodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) { + n.Called(callback) +} + +func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { + args := n.Called(req) + return args.Get(0).(*serializer.NodePingResp), args.Error(1) +} + +func (n NodeMock) IsActive() bool { + args := n.Called() + return args.Bool(0) +} + +func (n NodeMock) GetAria2Instance() common.Aria2 { + args := n.Called() + return args.Get(0).(common.Aria2) +} + +func (n NodeMock) ID() uint { + args := n.Called() + return args.Get(0).(uint) +} + +func (n NodeMock) Kill() { + n.Called() +} + +func (n NodeMock) IsMater() bool { + args := n.Called() + return args.Bool(0) +} + +func (n NodeMock) MasterAuthInstance() auth.Auth { + args := n.Called() + return args.Get(0).(auth.Auth) +} + +func (n NodeMock) SlaveAuthInstance() auth.Auth { + args := n.Called() + return args.Get(0).(auth.Auth) +} + +func (n NodeMock) DBModel() *model.Node { + args := n.Called() + return args.Get(0).(*model.Node) +} + +type Aria2Mock struct { + testMock.Mock +} + +func (a Aria2Mock) Init() error { + args := a.Called() + return args.Error(0) +} + +func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { + args := a.Called(task, options) + return args.String(0), args.Error(1) +} + +func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) { + args := a.Called(task) + return args.Get(0).(rpc.StatusInfo), args.Error(1) +} + +func (a Aria2Mock) Cancel(task *model.Download) error { + args := a.Called(task) + return args.Error(0) +} + +func (a Aria2Mock) Select(task *model.Download, files []int) error { + args := a.Called(task, files) + return args.Error(0) +} + +func (a Aria2Mock) GetConfig() model.Aria2Option { + args := a.Called() + return args.Get(0).(model.Aria2Option) +} + +func (a Aria2Mock) DeleteTempFile(download *model.Download) error { + args := a.Called(download) + return args.Error(0) +} diff --git a/service/aria2/add.go b/service/aria2/add.go index 73446b4..8443c14 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -72,7 +72,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo } // 创建任务监控 - monitor.NewMonitor(task) + monitor.NewMonitor(task, cluster.Default, mq.GlobalMQ) return serializer.Response{} }