Test: aria2 task monitor

Fix: tmp file not deleted after transfer task failed to create
pull/1048/head
HFO4 3 years ago
parent eeee43d569
commit 4d7b8685b9

@ -6,12 +6,10 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
) )
@ -79,46 +77,12 @@ func TestSlaveRPCSignRequired(t *testing.T) {
} }
} }
type SlaveControllerMock struct {
testMock.Mock
}
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
args := s.Called(pingReq)
return args.Get(0).(serializer.NodePingResp), args.Error(1)
}
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
args := s.Called(s2)
return args.Get(0).(common.Aria2), args.Error(1)
}
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
args := s.Called(s3, s2, message)
return args.Error(0)
}
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
args := s.Called(s3, i, s2, f)
return args.Error(0)
}
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
args := s.Called(s2)
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
}
func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) {
args := s.Called(s2, u)
return args.String(0), args.Error(1)
}
func TestUseSlaveAria2Instance(t *testing.T) { func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t) a := assert.New(t)
// MasterSiteID not set // MasterSiteID not set
{ {
testController := &SlaveControllerMock{} testController := &mocks.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder()) c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil) c.Request = httptest.NewRequest("GET", "/", nil)
@ -128,7 +92,7 @@ func TestUseSlaveAria2Instance(t *testing.T) {
// Cannot get aria2 instances // Cannot get aria2 instances
{ {
testController := &SlaveControllerMock{} testController := &mocks.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder()) c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil) c.Request = httptest.NewRequest("GET", "/", nil)
@ -141,7 +105,7 @@ func TestUseSlaveAria2Instance(t *testing.T) {
// Success // Success
{ {
testController := &SlaveControllerMock{} testController := &mocks.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder()) c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil) c.Request = httptest.NewRequest("GET", "/", nil)

@ -3,6 +3,8 @@ package aria2
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@ -42,7 +44,7 @@ func Init(isReload bool) {
for i := 0; i < len(unfinished); i++ { for i := 0; i < len(unfinished); i++ {
// 创建任务监控 // 创建任务监控
monitor.NewMonitor(&unfinished[i]) monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ)
} }
} }
} }

@ -0,0 +1,45 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDummyAria2(t *testing.T) {
a := assert.New(t)
d := &DummyAria2{}
a.NoError(d.Init())
res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
a.Empty(res)
a.Error(err)
_, err = d.Status(&model.Download{})
a.Error(err)
err = d.Cancel(&model.Download{})
a.Error(err)
err = d.Select(&model.Download{}, []int{})
a.Error(err)
configRes := d.GetConfig()
a.NotEmpty(configRes)
err = d.DeleteTempFile(&model.Download{})
a.Error(err)
}
func TestGetStatus(t *testing.T) {
a := assert.New(t)
a.Equal(GetStatus("complete"), Complete)
a.Equal(GetStatus("active"), Downloading)
a.Equal(GetStatus("waiting"), Ready)
a.Equal(GetStatus("paused"), Paused)
a.Equal(GetStatus("error"), Error)
a.Equal(GetStatus("removed"), Canceled)
a.Equal(GetStatus("unknown"), Unknown)
}

@ -33,29 +33,29 @@ type Monitor struct {
var MAX_RETRY = 10 var MAX_RETRY = 10
// NewMonitor 新建离线下载状态监控 // NewMonitor 新建离线下载状态监控
func NewMonitor(task *model.Download) { func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
monitor := &Monitor{ monitor := &Monitor{
Task: task, Task: task,
notifier: make(chan mq.Message), notifier: make(chan mq.Message),
node: cluster.Default.GetNodeByID(task.GetNodeID()), node: pool.GetNodeByID(task.GetNodeID()),
} }
if monitor.node != nil { if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
go monitor.Loop() go monitor.Loop(mqClient)
monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0) monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
} else { } else {
monitor.setErrorStatus(errors.New("节点不可用")) monitor.setErrorStatus(errors.New("节点不可用"))
} }
} }
// Loop 开启监控循环 // Loop 开启监控循环
func (monitor *Monitor) Loop() { func (monitor *Monitor) Loop(mqClient mq.MQ) {
defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier) defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
// 首次循环立即更新 // 首次循环立即更新
interval := time.Duration(0) interval := 50 * time.Millisecond
for { for {
select { select {
@ -259,6 +259,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
) )
if err != nil { if err != nil {
monitor.setErrorStatus(err) monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true return true
} }

@ -1,326 +1,252 @@
package monitor package monitor
import ( import (
"database/sql"
"errors" "errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock" testMock "github.com/stretchr/testify/mock"
"testing"
) )
type InstanceMock struct { var mock sqlmock.Sqlmock
testMock.Mock
}
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
args := m.Called(task, options)
return args.Error(0)
}
func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) {
args := m.Called(task)
return args.Get(0).(rpc.StatusInfo), args.Error(1)
}
func (m InstanceMock) Cancel(task *model.Download) error {
args := m.Called(task)
return args.Error(0)
}
func (m InstanceMock) Select(task *model.Download, files []int) error {
args := m.Called(task, files)
return args.Error(0)
}
func TestNewMonitor(t *testing.T) {
asserts := assert.New(t)
NewMonitor(&model.Download{GID: "gid"})
_, ok := common.EventNotifier.Subscribes.Load("gid")
asserts.True(ok)
}
func TestMonitor_Loop(t *testing.T) { // TestMain 初始化数据库Mock
asserts := assert.New(t) func TestMain(m *testing.M) {
notifier := make(chan common.StatusEvent) var db *sql.DB
MAX_RETRY = 0 var err error
monitor := &Monitor{ db, mock, err = sqlmock.New()
Task: &model.Download{GID: "gid"}, if err != nil {
Interval: time.Duration(1) * time.Second, panic("An error was not expected when opening a stub database connection")
notifier: notifier,
} }
asserts.NotPanics(func() { model.DB, _ = gorm.Open("mysql", db)
monitor.Loop() defer db.Close()
}) m.Run()
} }
func TestMonitor_Update(t *testing.T) { func TestNewMonitor(t *testing.T) {
asserts := assert.New(t) a := assert.New(t)
monitor := &Monitor{ mockMQ := mq.NewMQ()
Task: &model.Download{
GID: "gid",
Parent: "TestMonitor_Update",
},
Interval: time.Duration(1) * time.Second,
}
// 无法获取状态
{
MAX_RETRY = 1
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
file.Close()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.True(monitor.Update())
testInstance.AssertExpectations(t)
asserts.False(util.Exists("TestMonitor_Update"))
}
// 磁力链下载重定向
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"1"},
}, nil)
monitor.Task.ID = 1
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
asserts.EqualValues("1", monitor.Task.GID)
}
// 无法更新任务信息 // node not available
{ {
testInstance := new(InstanceMock) mockPool := &mocks.NodePoolMock{}
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil) mockPool.On("GetNodeByID", uint(1)).Return(nil)
monitor.Task.ID = 1 mock.ExpectBegin()
aria2.mock.ExpectBegin() mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) mock.ExpectCommit()
aria2.mock.ExpectRollback()
aria2.Instance = testInstance task := &model.Download{
asserts.True(monitor.Update()) Model: gorm.Model{ID: 1},
asserts.NoError(aria2.mock.ExpectationsWereMet()) }
testInstance.AssertExpectations(t) NewMonitor(task, mockPool, mockMQ)
mockPool.AssertExpectations(t)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(task.Error)
} }
// 返回未知状态 // success
{ {
testInstance := new(InstanceMock) mockNode := &mocks.NodeMock{}
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
aria2.mock.ExpectBegin() mockPool := &mocks.NodePoolMock{}
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回被取消状态 task := &model.Download{
{ Model: gorm.Model{ID: 1},
testInstance := new(InstanceMock) }
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) NewMonitor(task, mockPool, mockMQ)
aria2.mock.ExpectBegin() mockNode.AssertExpectations(t)
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) mockPool.AssertExpectations(t)
aria2.mock.ExpectCommit()
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
} }
// 返回活跃状态 }
{
testInstance := new(InstanceMock)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回错误状态 func TestMonitor_Loop(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
retried: MAX_RETRY,
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
notifier: mockMQ.Subscribe("test", 1),
}
// into interval loop
{ {
testInstance := new(InstanceMock) mock.ExpectBegin()
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectBegin() mock.ExpectCommit()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) m.Loop(mockMQ)
aria2.mock.ExpectCommit() a.NoError(mock.ExpectationsWereMet())
aria2.Instance = testInstance a.NotEmpty(m.Task.Error)
asserts.True(monitor.Update())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
} }
// 返回完成 // into notifier loop
{ {
testInstance := new(InstanceMock) m.Task.Error = ""
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) mockMQ.Publish("test", mq.Message{})
aria2.mock.ExpectBegin() mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit() mock.ExpectCommit()
aria2.Instance = testInstance m.Loop(mockMQ)
asserts.True(monitor.Update()) a.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet()) a.NotEmpty(m.Task.Error)
testInstance.AssertExpectations(t)
} }
} }
func TestMonitor_UpdateTaskInfo(t *testing.T) { func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
asserts := assert.New(t) a := assert.New(t)
monitor := &Monitor{ mockNode := &mocks.NodeMock{}
Task: &model.Download{ mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
Model: gorm.Model{ID: 1}, m := &Monitor{
GID: "gid", node: mockNode,
Parent: "TestMonitor_UpdateTaskInfo", Task: &model.Download{Model: gorm.Model{ID: 1}},
},
}
// 失败
{
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
aria2.mock.ExpectRollback()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err)
} }
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
// 更新成功,无需校验 for i := 0; i < MAX_RETRY; i++ {
{ a.False(m.Update())
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.NoError(err)
} }
// 更新成功,大小改变,需要校验,校验失败 mockNode.AssertExpectations(t)
{ a.True(m.Update())
testInstance := new(InstanceMock) a.NoError(mock.ExpectationsWereMet())
testInstance.On("SlaveCancel", testMock.Anything).Return(nil) a.NotEmpty(m.Task.Error)
aria2.Instance = testInstance
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err)
testInstance.AssertExpectations(t)
}
} }
func TestMonitor_ValidateFile(t *testing.T) { func TestMonitor_UpdateMagentoFollow(t *testing.T) {
asserts := assert.New(t) a := assert.New(t)
monitor := &Monitor{ mockAria2 := &mocks.Aria2Mock{}
Task: &model.Download{ mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Model: gorm.Model{ID: 1}, FollowedBy: []string{"next"},
GID: "gid", }, nil)
Parent: "TestMonitor_ValidateFile", mockNode := &mocks.NodeMock{}
}, mockNode.On("GetAria2Instance").Return(mockAria2)
} m := &Monitor{
node: mockNode,
// 无法创建文件系统 Task: &model.Download{Model: gorm.Model{ID: 1}},
{ }
monitor.Task.User = &model.User{ mock.ExpectBegin()
Policy: model.Policy{ mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
Type: "unknown", mock.ExpectCommit()
},
} a.False(m.Update())
asserts.Error(monitor.ValidateFile()) a.NoError(mock.ExpectationsWereMet())
} a.Equal("next", m.Task.GID)
mockAria2.AssertExpectations(t)
// 文件大小超出容量配额 }
{
cache.Set("pack_size_0", uint64(0), 0)
monitor.Task.TotalSize = 11
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "mock",
},
Group: model.Group{
MaxStorage: 10,
},
}
asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile())
}
// 单文件大小超出容量配额 func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
{ a := assert.New(t)
cache.Set("pack_size_0", uint64(0), 0) mockAria2 := &mocks.Aria2Mock{}
monitor.Task.TotalSize = 10 mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
monitor.Task.StatusInfo.Files = []rpc.FileInfo{ mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
{ mockNode := &mocks.NodeMock{}
Selected: "true", mockNode.On("GetAria2Instance").Return(mockAria2)
Length: "6", m := &Monitor{
}, node: mockNode,
} Task: &model.Download{Model: gorm.Model{ID: 1}},
monitor.Task.User = &model.User{ }
Policy: model.Policy{ mock.ExpectBegin()
Type: "mock", mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
MaxSize: 5, mock.ExpectRollback()
}, mock.ExpectBegin()
Group: model.Group{ mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
MaxStorage: 10, mock.ExpectCommit()
},
} a.True(m.Update())
asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile()) a.NoError(mock.ExpectationsWereMet())
} mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
} }
func TestMonitor_Complete(t *testing.T) { func TestMonitor_UpdateCompleted(t *testing.T) {
asserts := assert.New(t) a := assert.New(t)
monitor := &Monitor{ mockAria2 := &mocks.Aria2Mock{}
Task: &model.Download{ mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Model: gorm.Model{ID: 1}, Status: "complete",
GID: "gid", }, nil)
Parent: "TestMonitor_Complete", mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
StatusInfo: rpc.StatusInfo{ mockNode := &mocks.NodeMock{}
Files: []rpc.FileInfo{ mockNode.On("GetAria2Instance").Return(mockAria2)
{ mockNode.On("ID").Return(uint(1))
Selected: "true", m := &Monitor{
Path: "TestMonitor_Complete", node: mockNode,
}, Task: &model.Download{Model: gorm.Model{ID: 1}},
}, }
}, mock.ExpectBegin()
}, mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
} mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
cache.Set("setting_max_worker_num", "1", 0) func TestMonitor_UpdateError(t *testing.T) {
aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) a := assert.New(t)
task.Init() mockAria2 := &mocks.Aria2Mock{}
aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) Status: "error",
aria2.mock.ExpectBegin() ErrorMessage: "error",
aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) }, nil)
aria2.mock.ExpectCommit() mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
aria2.mock.ExpectBegin() func TestMonitor_UpdateActive(t *testing.T) {
aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) a := assert.New(t)
aria2.mock.ExpectCommit() mockAria2 := &mocks.Aria2Mock{}
asserts.True(monitor.Complete(rpc.StatusInfo{})) mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
asserts.NoError(aria2.mock.ExpectationsWereMet()) Status: "active",
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
} }

@ -0,0 +1,173 @@
package mocks
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
testMock "github.com/stretchr/testify/mock"
)
type SlaveControllerMock struct {
testMock.Mock
}
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
args := s.Called(pingReq)
return args.Get(0).(serializer.NodePingResp), args.Error(1)
}
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
args := s.Called(s2)
return args.Get(0).(common.Aria2), args.Error(1)
}
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
args := s.Called(s3, s2, message)
return args.Error(0)
}
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
args := s.Called(s3, i, s2, f)
return args.Error(0)
}
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
args := s.Called(s2)
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
}
func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) {
args := s.Called(s2, u)
return args.String(0), args.Error(1)
}
type NodePoolMock struct {
testMock.Mock
}
func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) {
args := n.Called(feature, lb)
return args.Error(0), args.Get(1).(cluster.Node)
}
func (n NodePoolMock) GetNodeByID(id uint) cluster.Node {
args := n.Called(id)
if res, ok := args.Get(0).(cluster.Node); ok {
return res
}
return nil
}
func (n NodePoolMock) Add(node *model.Node) {
n.Called(node)
}
func (n NodePoolMock) Delete(id uint) {
n.Called(id)
}
type NodeMock struct {
testMock.Mock
}
func (n NodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n NodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n NodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n NodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n NodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n NodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n NodeMock) Kill() {
n.Called()
}
func (n NodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n NodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n NodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n NodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
type Aria2Mock struct {
testMock.Mock
}
func (a Aria2Mock) Init() error {
args := a.Called()
return args.Error(0)
}
func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
args := a.Called(task, options)
return args.String(0), args.Error(1)
}
func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) {
args := a.Called(task)
return args.Get(0).(rpc.StatusInfo), args.Error(1)
}
func (a Aria2Mock) Cancel(task *model.Download) error {
args := a.Called(task)
return args.Error(0)
}
func (a Aria2Mock) Select(task *model.Download, files []int) error {
args := a.Called(task, files)
return args.Error(0)
}
func (a Aria2Mock) GetConfig() model.Aria2Option {
args := a.Called()
return args.Get(0).(model.Aria2Option)
}
func (a Aria2Mock) DeleteTempFile(download *model.Download) error {
args := a.Called(download)
return args.Error(0)
}

@ -72,7 +72,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
} }
// 创建任务监控 // 创建任务监控
monitor.NewMonitor(task) monitor.NewMonitor(task, cluster.Default, mq.GlobalMQ)
return serializer.Response{} return serializer.Response{}
} }

Loading…
Cancel
Save