From f0089045d7874ecfa8682359cadc53ea41a21c5b Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 11 Nov 2021 19:49:02 +0800 Subject: [PATCH] Test: aria2 task monitor 100% cover --- pkg/aria2/monitor/monitor.go | 6 +- pkg/aria2/monitor/monitor_test.go | 186 ++++++++++++++++++++++++++++++ pkg/mocks/mocks.go | 13 +++ pkg/task/pool.go | 25 ++-- pkg/task/pool_test.go | 2 +- 5 files changed, 218 insertions(+), 14 deletions(-) diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index aec1f25..b989826 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -112,7 +112,7 @@ func (monitor *Monitor) Update() bool { switch status.Status { case "complete": - return monitor.Complete(status) + return monitor.Complete(task.TaskPoll) case "error": return monitor.Error(status) case "active", "waiting", "paused": @@ -235,7 +235,7 @@ func (monitor *Monitor) RemoveTempFolder() { } // Complete 完成下载,返回是否中断监控 -func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { +func (monitor *Monitor) Complete(pool task.Pool) bool { // 创建中转任务 file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) @@ -264,7 +264,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool { } // 提交中转任务 - task.TaskPoll.Submit(job) + pool.Submit(job) // 更新任务ID monitor.Task.TaskID = job.Model().ID diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go index d0760ca..885484a 100644 --- a/pkg/aria2/monitor/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -7,6 +7,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/mocks" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/jinzhu/gorm" @@ -250,3 +251,188 @@ func TestMonitor_UpdateActive(t *testing.T) { mockAria2.AssertExpectations(t) mockNode.AssertExpectations(t) } + +func TestMonitor_UpdateRemoved(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + Status: "removed", + }, nil) + mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.True(m.Update()) + a.Equal(common.Canceled, m.Task.Status) + a.NoError(mock.ExpectationsWereMet()) + mockAria2.AssertExpectations(t) + mockNode.AssertExpectations(t) +} + +func TestMonitor_UpdateUnknown(t *testing.T) { + a := assert.New(t) + mockAria2 := &mocks.Aria2Mock{} + mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + Status: "unknown", + }, nil) + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(mockAria2) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.True(m.Update()) + a.NoError(mock.ExpectationsWereMet()) + mockAria2.AssertExpectations(t) + mockNode.AssertExpectations(t) +} + +func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) { + a := assert.New(t) + status := rpc.StatusInfo{ + Status: "completed", + TotalLength: "100", + CompletedLength: "50", + DownloadSpeed: "20", + } + mockNode := &mocks.NodeMock{} + mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) + m := &Monitor{ + node: mockNode, + Task: &model.Download{Model: gorm.Model{ID: 1}}, + } + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := m.UpdateTaskInfo(status) + a.Error(err) + a.NoError(mock.ExpectationsWereMet()) + mockNode.AssertExpectations(t) +} + +func TestMonitor_ValidateFile(t *testing.T) { + a := assert.New(t) + m := &Monitor{ + Task: &model.Download{ + Model: gorm.Model{ID: 1}, + TotalSize: 100, + }, + } + + // failed to create filesystem + { + m.Task.User = &model.User{ + Policy: model.Policy{ + Type: "random", + }, + } + a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile()) + } + + // User capacity not enough + { + m.Task.User = &model.User{ + Group: model.Group{ + MaxStorage: 99, + }, + Policy: model.Policy{ + Type: "local", + }, + } + a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile()) + } + + // single file too big + { + m.Task.StatusInfo.Files = []rpc.FileInfo{ + { + Length: "100", + Selected: "true", + }, + } + m.Task.User = &model.User{ + Group: model.Group{ + MaxStorage: 100, + }, + Policy: model.Policy{ + Type: "local", + MaxSize: 99, + }, + } + a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile()) + } + + // all pass + { + m.Task.StatusInfo.Files = []rpc.FileInfo{ + { + Length: "100", + Selected: "true", + }, + } + m.Task.User = &model.User{ + Group: model.Group{ + MaxStorage: 100, + }, + Policy: model.Policy{ + Type: "local", + MaxSize: 100, + }, + } + a.NoError(m.ValidateFile()) + } +} + +func TestMonitor_Complete(t *testing.T) { + a := assert.New(t) + mockNode := &mocks.NodeMock{} + mockNode.On("ID").Return(uint(1)) + mockPool := &mocks.TaskPoolMock{} + mockPool.On("Submit", testMock.Anything) + m := &Monitor{ + node: mockNode, + Task: &model.Download{ + Model: gorm.Model{ID: 1}, + TotalSize: 100, + UserID: 9414, + }, + } + m.Task.StatusInfo.Files = []rpc.FileInfo{ + { + Length: "100", + Selected: "true", + }, + } + + mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) + + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + a.True(m.Complete(mockPool)) + a.NoError(mock.ExpectationsWereMet()) + mockNode.AssertExpectations(t) + mockPool.AssertExpectations(t) +} diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go index 64d4425..2134e86 100644 --- a/pkg/mocks/mocks.go +++ b/pkg/mocks/mocks.go @@ -9,6 +9,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/task" testMock "github.com/stretchr/testify/mock" ) @@ -171,3 +172,15 @@ func (a Aria2Mock) DeleteTempFile(download *model.Download) error { args := a.Called(download) return args.Error(0) } + +type TaskPoolMock struct { + testMock.Mock +} + +func (t TaskPoolMock) Add(num int) { + t.Called(num) +} + +func (t TaskPoolMock) Submit(job task.Job) { + t.Called(job) +} diff --git a/pkg/task/pool.go b/pkg/task/pool.go index 4fe550f..3188d7e 100644 --- a/pkg/task/pool.go +++ b/pkg/task/pool.go @@ -7,23 +7,28 @@ import ( ) // TaskPoll 要使用的任务池 -var TaskPoll *Pool +var TaskPoll Pool -// Pool 带有最大配额的任务池 -type Pool struct { +type Pool interface { + Add(num int) + Submit(job Job) +} + +// AsyncPool 带有最大配额的任务池 +type AsyncPool struct { // 容量 idleWorker chan int } // Add 增加可用Worker数量 -func (pool *Pool) Add(num int) { +func (pool *AsyncPool) Add(num int) { for i := 0; i < num; i++ { pool.idleWorker <- 1 } } // ObtainWorker 阻塞直到获取新的Worker -func (pool *Pool) ObtainWorker() Worker { +func (pool *AsyncPool) obtainWorker() Worker { select { case <-pool.idleWorker: // 有空闲Worker名额时,返回新Worker @@ -32,26 +37,26 @@ func (pool *Pool) ObtainWorker() Worker { } // FreeWorker 添加空闲Worker -func (pool *Pool) FreeWorker() { +func (pool *AsyncPool) freeWorker() { pool.Add(1) } // Submit 开始提交任务 -func (pool *Pool) Submit(job Job) { +func (pool *AsyncPool) Submit(job Job) { go func() { util.Log().Debug("等待获取Worker") - worker := pool.ObtainWorker() + worker := pool.obtainWorker() util.Log().Debug("获取到Worker") worker.Do(job) util.Log().Debug("释放Worker") - pool.FreeWorker() + pool.freeWorker() }() } // Init 初始化任务池 func Init() { maxWorker := model.GetIntSetting("max_worker_num", 10) - TaskPoll = &Pool{ + TaskPoll = &AsyncPool{ idleWorker: make(chan int, maxWorker), } TaskPoll.Add(maxWorker) diff --git a/pkg/task/pool_test.go b/pkg/task/pool_test.go index 0ed9641..5b7f74e 100644 --- a/pkg/task/pool_test.go +++ b/pkg/task/pool_test.go @@ -37,7 +37,7 @@ func TestInit(t *testing.T) { func TestPool_Submit(t *testing.T) { asserts := assert.New(t) - pool := &Pool{ + pool := &AsyncPool{ idleWorker: make(chan int, 1), } pool.Add(1)