Fix: tmp file not deleted after transfer task failed to createpull/1048/head
parent
eeee43d569
commit
4d7b8685b9
@ -0,0 +1,45 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDummyAria2(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
d := &DummyAria2{}
|
||||
|
||||
a.NoError(d.Init())
|
||||
|
||||
res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
|
||||
a.Empty(res)
|
||||
a.Error(err)
|
||||
|
||||
_, err = d.Status(&model.Download{})
|
||||
a.Error(err)
|
||||
|
||||
err = d.Cancel(&model.Download{})
|
||||
a.Error(err)
|
||||
|
||||
err = d.Select(&model.Download{}, []int{})
|
||||
a.Error(err)
|
||||
|
||||
configRes := d.GetConfig()
|
||||
a.NotEmpty(configRes)
|
||||
|
||||
err = d.DeleteTempFile(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(GetStatus("complete"), Complete)
|
||||
a.Equal(GetStatus("active"), Downloading)
|
||||
a.Equal(GetStatus("waiting"), Ready)
|
||||
a.Equal(GetStatus("paused"), Paused)
|
||||
a.Equal(GetStatus("error"), Error)
|
||||
a.Equal(GetStatus("removed"), Canceled)
|
||||
a.Equal(GetStatus("unknown"), Unknown)
|
||||
}
|
@ -1,326 +1,252 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type InstanceMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
|
||||
args := m.Called(task, options)
|
||||
return args.Error(0)
|
||||
}
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
args := m.Called(task)
|
||||
return args.Get(0).(rpc.StatusInfo), args.Error(1)
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
|
||||
func (m InstanceMock) Cancel(task *model.Download) error {
|
||||
args := m.Called(task)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m InstanceMock) Select(task *model.Download, files []int) error {
|
||||
args := m.Called(task, files)
|
||||
return args.Error(0)
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestNewMonitor(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
NewMonitor(&model.Download{GID: "gid"})
|
||||
_, ok := common.EventNotifier.Subscribes.Load("gid")
|
||||
asserts.True(ok)
|
||||
}
|
||||
|
||||
func TestMonitor_Loop(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
notifier := make(chan common.StatusEvent)
|
||||
MAX_RETRY = 0
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{GID: "gid"},
|
||||
Interval: time.Duration(1) * time.Second,
|
||||
notifier: notifier,
|
||||
}
|
||||
asserts.NotPanics(func() {
|
||||
monitor.Loop()
|
||||
})
|
||||
}
|
||||
|
||||
func TestMonitor_Update(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_Update",
|
||||
},
|
||||
Interval: time.Duration(1) * time.Second,
|
||||
}
|
||||
|
||||
// 无法获取状态
|
||||
{
|
||||
MAX_RETRY = 1
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
|
||||
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
|
||||
file.Close()
|
||||
aria2.Instance = testInstance
|
||||
asserts.False(monitor.Update())
|
||||
asserts.True(monitor.Update())
|
||||
testInstance.AssertExpectations(t)
|
||||
asserts.False(util.Exists("TestMonitor_Update"))
|
||||
}
|
||||
|
||||
// 磁力链下载重定向
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
|
||||
FollowedBy: []string{"1"},
|
||||
}, nil)
|
||||
monitor.Task.ID = 1
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.Instance = testInstance
|
||||
asserts.False(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
asserts.EqualValues("1", monitor.Task.GID)
|
||||
}
|
||||
|
||||
// 无法更新任务信息
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
monitor.Task.ID = 1
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
aria2.mock.ExpectRollback()
|
||||
aria2.Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 返回未知状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
a := assert.New(t)
|
||||
mockMQ := mq.NewMQ()
|
||||
|
||||
// 返回被取消状态
|
||||
// node not available
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
mockPool := &mocks.NodePoolMock{}
|
||||
mockPool.On("GetNodeByID", uint(1)).Return(nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
// 返回活跃状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.Instance = testInstance
|
||||
asserts.False(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
task := &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
}
|
||||
|
||||
// 返回错误状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
NewMonitor(task, mockPool, mockMQ)
|
||||
mockPool.AssertExpectations(t)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(task.Error)
|
||||
}
|
||||
|
||||
// 返回完成
|
||||
// success
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
aria2.Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
mockPool := &mocks.NodePoolMock{}
|
||||
mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
|
||||
|
||||
func TestMonitor_UpdateTaskInfo(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
task := &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_UpdateTaskInfo",
|
||||
},
|
||||
}
|
||||
|
||||
// 失败
|
||||
{
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
aria2.mock.ExpectRollback()
|
||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
asserts.Error(err)
|
||||
NewMonitor(task, mockPool, mockMQ)
|
||||
mockNode.AssertExpectations(t)
|
||||
mockPool.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 更新成功,无需校验
|
||||
{
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
// 更新成功,大小改变,需要校验,校验失败
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("SlaveCancel", testMock.Anything).Return(nil)
|
||||
aria2.Instance = testInstance
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
asserts.Error(err)
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_ValidateFile(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_ValidateFile",
|
||||
},
|
||||
func TestMonitor_Loop(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockMQ := mq.NewMQ()
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
retried: MAX_RETRY,
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
notifier: mockMQ.Subscribe("test", 1),
|
||||
}
|
||||
|
||||
// 无法创建文件系统
|
||||
// into interval loop
|
||||
{
|
||||
monitor.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "unknown",
|
||||
},
|
||||
}
|
||||
asserts.Error(monitor.ValidateFile())
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
m.Loop(mockMQ)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
// 文件大小超出容量配额
|
||||
// into notifier loop
|
||||
{
|
||||
cache.Set("pack_size_0", uint64(0), 0)
|
||||
monitor.Task.TotalSize = 11
|
||||
monitor.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "mock",
|
||||
},
|
||||
Group: model.Group{
|
||||
MaxStorage: 10,
|
||||
},
|
||||
m.Task.Error = ""
|
||||
mockMQ.Publish("test", mq.Message{})
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
m.Loop(mockMQ)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile())
|
||||
}
|
||||
|
||||
// 单文件大小超出容量配额
|
||||
{
|
||||
cache.Set("pack_size_0", uint64(0), 0)
|
||||
monitor.Task.TotalSize = 10
|
||||
monitor.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Selected: "true",
|
||||
Length: "6",
|
||||
},
|
||||
}
|
||||
monitor.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "mock",
|
||||
MaxSize: 5,
|
||||
},
|
||||
Group: model.Group{
|
||||
MaxStorage: 10,
|
||||
},
|
||||
}
|
||||
asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile())
|
||||
}
|
||||
func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
func TestMonitor_Complete(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_Complete",
|
||||
StatusInfo: rpc.StatusInfo{
|
||||
Files: []rpc.FileInfo{
|
||||
{
|
||||
Selected: "true",
|
||||
Path: "TestMonitor_Complete",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
for i := 0; i < MAX_RETRY; i++ {
|
||||
a.False(m.Update())
|
||||
}
|
||||
|
||||
cache.Set("setting_max_worker_num", "1", 0)
|
||||
aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
task.Init()
|
||||
aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
mockNode.AssertExpectations(t)
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
asserts.True(monitor.Complete(rpc.StatusInfo{}))
|
||||
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||
func TestMonitor_UpdateMagentoFollow(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
FollowedBy: []string{"next"},
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.False(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.Equal("next", m.Task.GID)
|
||||
mockAria2.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateCompleted(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "complete",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
mockNode.On("ID").Return(uint(1))
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateError(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "error",
|
||||
ErrorMessage: "error",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateActive(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "active",
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.False(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
@ -0,0 +1,173 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type SlaveControllerMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||
args := s.Called(pingReq)
|
||||
return args.Get(0).(serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
|
||||
args := s.Called(s2)
|
||||
return args.Get(0).(common.Aria2), args.Error(1)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
|
||||
args := s.Called(s3, s2, message)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
|
||||
args := s.Called(s3, i, s2, f)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
|
||||
args := s.Called(s2)
|
||||
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) {
|
||||
args := s.Called(s2, u)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
type NodePoolMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) {
|
||||
args := n.Called(feature, lb)
|
||||
return args.Error(0), args.Get(1).(cluster.Node)
|
||||
}
|
||||
|
||||
func (n NodePoolMock) GetNodeByID(id uint) cluster.Node {
|
||||
args := n.Called(id)
|
||||
if res, ok := args.Get(0).(cluster.Node); ok {
|
||||
return res
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NodePoolMock) Add(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n NodePoolMock) Delete(id uint) {
|
||||
n.Called(id)
|
||||
}
|
||||
|
||||
type NodeMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n NodeMock) Init(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n NodeMock) IsFeatureEnabled(feature string) bool {
|
||||
args := n.Called(feature)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n NodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
n.Called(callback)
|
||||
}
|
||||
|
||||
func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
args := n.Called(req)
|
||||
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (n NodeMock) IsActive() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n NodeMock) GetAria2Instance() common.Aria2 {
|
||||
args := n.Called()
|
||||
return args.Get(0).(common.Aria2)
|
||||
}
|
||||
|
||||
func (n NodeMock) ID() uint {
|
||||
args := n.Called()
|
||||
return args.Get(0).(uint)
|
||||
}
|
||||
|
||||
func (n NodeMock) Kill() {
|
||||
n.Called()
|
||||
}
|
||||
|
||||
func (n NodeMock) IsMater() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n NodeMock) MasterAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n NodeMock) SlaveAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n NodeMock) DBModel() *model.Node {
|
||||
args := n.Called()
|
||||
return args.Get(0).(*model.Node)
|
||||
}
|
||||
|
||||
type Aria2Mock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Init() error {
|
||||
args := a.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
args := a.Called(task, options)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
args := a.Called(task)
|
||||
return args.Get(0).(rpc.StatusInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Cancel(task *model.Download) error {
|
||||
args := a.Called(task)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Select(task *model.Download, files []int) error {
|
||||
args := a.Called(task, files)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) GetConfig() model.Aria2Option {
|
||||
args := a.Called()
|
||||
return args.Get(0).(model.Aria2Option)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) DeleteTempFile(download *model.Download) error {
|
||||
args := a.Called(download)
|
||||
return args.Error(0)
|
||||
}
|
Loading…
Reference in new issue