Test: aria2 task monitor

Fix: tmp file not deleted after transfer task failed to create
pull/1048/head
HFO4 3 years ago
parent eeee43d569
commit 4d7b8685b9

@ -6,12 +6,10 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"net/http/httptest"
"testing"
)
@ -79,46 +77,12 @@ func TestSlaveRPCSignRequired(t *testing.T) {
}
}
type SlaveControllerMock struct {
testMock.Mock
}
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
args := s.Called(pingReq)
return args.Get(0).(serializer.NodePingResp), args.Error(1)
}
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
args := s.Called(s2)
return args.Get(0).(common.Aria2), args.Error(1)
}
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
args := s.Called(s3, s2, message)
return args.Error(0)
}
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
args := s.Called(s3, i, s2, f)
return args.Error(0)
}
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
args := s.Called(s2)
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
}
func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) {
args := s.Called(s2, u)
return args.String(0), args.Error(1)
}
func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)
// MasterSiteID not set
{
testController := &SlaveControllerMock{}
testController := &mocks.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
@ -128,7 +92,7 @@ func TestUseSlaveAria2Instance(t *testing.T) {
// Cannot get aria2 instances
{
testController := &SlaveControllerMock{}
testController := &mocks.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
@ -141,7 +105,7 @@ func TestUseSlaveAria2Instance(t *testing.T) {
// Success
{
testController := &SlaveControllerMock{}
testController := &mocks.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)

@ -3,6 +3,8 @@ package aria2
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"net/url"
"sync"
"time"
@ -42,7 +44,7 @@ func Init(isReload bool) {
for i := 0; i < len(unfinished); i++ {
// 创建任务监控
monitor.NewMonitor(&unfinished[i])
monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ)
}
}
}

@ -0,0 +1,45 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDummyAria2(t *testing.T) {
a := assert.New(t)
d := &DummyAria2{}
a.NoError(d.Init())
res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
a.Empty(res)
a.Error(err)
_, err = d.Status(&model.Download{})
a.Error(err)
err = d.Cancel(&model.Download{})
a.Error(err)
err = d.Select(&model.Download{}, []int{})
a.Error(err)
configRes := d.GetConfig()
a.NotEmpty(configRes)
err = d.DeleteTempFile(&model.Download{})
a.Error(err)
}
func TestGetStatus(t *testing.T) {
a := assert.New(t)
a.Equal(GetStatus("complete"), Complete)
a.Equal(GetStatus("active"), Downloading)
a.Equal(GetStatus("waiting"), Ready)
a.Equal(GetStatus("paused"), Paused)
a.Equal(GetStatus("error"), Error)
a.Equal(GetStatus("removed"), Canceled)
a.Equal(GetStatus("unknown"), Unknown)
}

@ -33,29 +33,29 @@ type Monitor struct {
var MAX_RETRY = 10
// NewMonitor 新建离线下载状态监控
func NewMonitor(task *model.Download) {
func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
monitor := &Monitor{
Task: task,
notifier: make(chan mq.Message),
node: cluster.Default.GetNodeByID(task.GetNodeID()),
node: pool.GetNodeByID(task.GetNodeID()),
}
if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
go monitor.Loop()
go monitor.Loop(mqClient)
monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0)
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
} else {
monitor.setErrorStatus(errors.New("节点不可用"))
}
}
// Loop 开启监控循环
func (monitor *Monitor) Loop() {
defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier)
func (monitor *Monitor) Loop(mqClient mq.MQ) {
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
// 首次循环立即更新
interval := time.Duration(0)
interval := 50 * time.Millisecond
for {
select {
@ -259,6 +259,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}

@ -1,326 +1,252 @@
package monitor
import (
"database/sql"
"errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
)
type InstanceMock struct {
testMock.Mock
}
var mock sqlmock.Sqlmock
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
args := m.Called(task, options)
return args.Error(0)
}
func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) {
args := m.Called(task)
return args.Get(0).(rpc.StatusInfo), args.Error(1)
}
func (m InstanceMock) Cancel(task *model.Download) error {
args := m.Called(task)
return args.Error(0)
}
func (m InstanceMock) Select(task *model.Download, files []int) error {
args := m.Called(task, files)
return args.Error(0)
}
func TestNewMonitor(t *testing.T) {
asserts := assert.New(t)
NewMonitor(&model.Download{GID: "gid"})
_, ok := common.EventNotifier.Subscribes.Load("gid")
asserts.True(ok)
}
func TestMonitor_Loop(t *testing.T) {
asserts := assert.New(t)
notifier := make(chan common.StatusEvent)
MAX_RETRY = 0
monitor := &Monitor{
Task: &model.Download{GID: "gid"},
Interval: time.Duration(1) * time.Second,
notifier: notifier,
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
asserts.NotPanics(func() {
monitor.Loop()
})
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestMonitor_Update(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
GID: "gid",
Parent: "TestMonitor_Update",
},
Interval: time.Duration(1) * time.Second,
}
// 无法获取状态
{
MAX_RETRY = 1
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
file.Close()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.True(monitor.Update())
testInstance.AssertExpectations(t)
asserts.False(util.Exists("TestMonitor_Update"))
}
func TestNewMonitor(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
// 磁力链下载重定向
// node not available
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"1"},
}, nil)
monitor.Task.ID = 1
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
asserts.EqualValues("1", monitor.Task.GID)
}
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", uint(1)).Return(nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
// 无法更新任务信息
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil)
monitor.Task.ID = 1
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
aria2.mock.ExpectRollback()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
task := &model.Download{
Model: gorm.Model{ID: 1},
}
// 返回未知状态
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
NewMonitor(task, mockPool, mockMQ)
mockPool.AssertExpectations(t)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(task.Error)
}
// 返回被取消状态
// success
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
// 返回活跃状态
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
task := &model.Download{
Model: gorm.Model{ID: 1},
}
// 返回错误状态
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
NewMonitor(task, mockPool, mockMQ)
mockNode.AssertExpectations(t)
mockPool.AssertExpectations(t)
}
// 返回完成
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
}
func TestMonitor_UpdateTaskInfo(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
GID: "gid",
Parent: "TestMonitor_UpdateTaskInfo",
},
func TestMonitor_Loop(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
retried: MAX_RETRY,
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
notifier: mockMQ.Subscribe("test", 1),
}
// into interval loop
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
m.Loop(mockMQ)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
// into notifier loop
{
m.Task.Error = ""
mockMQ.Publish("test", mq.Message{})
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
m.Loop(mockMQ)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
}
// 失败
{
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
aria2.mock.ExpectRollback()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err)
func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
a := assert.New(t)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
// 更新成功,无需校验
{
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.NoError(err)
for i := 0; i < MAX_RETRY; i++ {
a.False(m.Update())
}
// 更新成功,大小改变,需要校验,校验失败
{
testInstance := new(InstanceMock)
testInstance.On("SlaveCancel", testMock.Anything).Return(nil)
aria2.Instance = testInstance
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err)
testInstance.AssertExpectations(t)
}
mockNode.AssertExpectations(t)
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
func TestMonitor_ValidateFile(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
GID: "gid",
Parent: "TestMonitor_ValidateFile",
},
}
// 无法创建文件系统
{
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "unknown",
},
}
asserts.Error(monitor.ValidateFile())
}
// 文件大小超出容量配额
{
cache.Set("pack_size_0", uint64(0), 0)
monitor.Task.TotalSize = 11
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "mock",
},
Group: model.Group{
MaxStorage: 10,
},
}
asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile())
}
func TestMonitor_UpdateMagentoFollow(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"next"},
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
a.Equal("next", m.Task.GID)
mockAria2.AssertExpectations(t)
}
// 单文件大小超出容量配额
{
cache.Set("pack_size_0", uint64(0), 0)
monitor.Task.TotalSize = 10
monitor.Task.StatusInfo.Files = []rpc.FileInfo{
{
Selected: "true",
Length: "6",
},
}
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "mock",
MaxSize: 5,
},
Group: model.Group{
MaxStorage: 10,
},
}
asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile())
}
func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_Complete(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
GID: "gid",
Parent: "TestMonitor_Complete",
StatusInfo: rpc.StatusInfo{
Files: []rpc.FileInfo{
{
Selected: "true",
Path: "TestMonitor_Complete",
},
},
},
},
}
func TestMonitor_UpdateCompleted(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "complete",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
mockNode.On("ID").Return(uint(1))
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
cache.Set("setting_max_worker_num", "1", 0)
aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
task.Init()
aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
func TestMonitor_UpdateError(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "error",
ErrorMessage: "error",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
asserts.True(monitor.Complete(rpc.StatusInfo{}))
asserts.NoError(aria2.mock.ExpectationsWereMet())
func TestMonitor_UpdateActive(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "active",
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}

@ -0,0 +1,173 @@
package mocks
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
testMock "github.com/stretchr/testify/mock"
)
type SlaveControllerMock struct {
testMock.Mock
}
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
args := s.Called(pingReq)
return args.Get(0).(serializer.NodePingResp), args.Error(1)
}
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
args := s.Called(s2)
return args.Get(0).(common.Aria2), args.Error(1)
}
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
args := s.Called(s3, s2, message)
return args.Error(0)
}
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
args := s.Called(s3, i, s2, f)
return args.Error(0)
}
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
args := s.Called(s2)
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
}
func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) {
args := s.Called(s2, u)
return args.String(0), args.Error(1)
}
type NodePoolMock struct {
testMock.Mock
}
func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) {
args := n.Called(feature, lb)
return args.Error(0), args.Get(1).(cluster.Node)
}
func (n NodePoolMock) GetNodeByID(id uint) cluster.Node {
args := n.Called(id)
if res, ok := args.Get(0).(cluster.Node); ok {
return res
}
return nil
}
func (n NodePoolMock) Add(node *model.Node) {
n.Called(node)
}
func (n NodePoolMock) Delete(id uint) {
n.Called(id)
}
type NodeMock struct {
testMock.Mock
}
func (n NodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n NodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n NodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n NodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n NodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n NodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n NodeMock) Kill() {
n.Called()
}
func (n NodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n NodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n NodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n NodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
type Aria2Mock struct {
testMock.Mock
}
func (a Aria2Mock) Init() error {
args := a.Called()
return args.Error(0)
}
func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
args := a.Called(task, options)
return args.String(0), args.Error(1)
}
func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) {
args := a.Called(task)
return args.Get(0).(rpc.StatusInfo), args.Error(1)
}
func (a Aria2Mock) Cancel(task *model.Download) error {
args := a.Called(task)
return args.Error(0)
}
func (a Aria2Mock) Select(task *model.Download, files []int) error {
args := a.Called(task, files)
return args.Error(0)
}
func (a Aria2Mock) GetConfig() model.Aria2Option {
args := a.Called()
return args.Get(0).(model.Aria2Option)
}
func (a Aria2Mock) DeleteTempFile(download *model.Download) error {
args := a.Called(download)
return args.Error(0)
}

@ -72,7 +72,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
}
// 创建任务监控
monitor.NewMonitor(task)
monitor.NewMonitor(task, cluster.Default, mq.GlobalMQ)
return serializer.Response{}
}

Loading…
Cancel
Save