From 2b853dddd3a877f2ec2959f8d3b76573ba692c9a Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 6 Feb 2020 13:53:47 +0800 Subject: [PATCH] Test: aria2 related --- models/setting.go | 8 +- pkg/aria2/aria2.go | 8 +- pkg/aria2/aria2_test.go | 89 ++++++++ pkg/aria2/caller.go | 3 +- pkg/aria2/caller_test.go | 51 +++++ pkg/aria2/{Monitor.go => monitor.go} | 4 +- pkg/aria2/monitor_test.go | 317 +++++++++++++++++++++++++++ 7 files changed, 471 insertions(+), 9 deletions(-) create mode 100644 pkg/aria2/aria2_test.go create mode 100644 pkg/aria2/caller_test.go rename pkg/aria2/{Monitor.go => monitor.go} (98%) create mode 100644 pkg/aria2/monitor_test.go diff --git a/models/setting.go b/models/setting.go index 0de472c..d14cf72 100644 --- a/models/setting.go +++ b/models/setting.go @@ -43,9 +43,11 @@ func GetSettingByNames(names ...string) map[string]string { var queryRes []Setting res, miss := cache.GetSettings(names, "setting_") - DB.Where("name IN (?)", miss).Find(&queryRes) - for _, setting := range queryRes { - res[setting.Name] = setting.Value + if len(miss) > 0 { + DB.Where("name IN (?)", miss).Find(&queryRes) + for _, setting := range queryRes { + res[setting.Name] = setting.Value + } } _ = cache.SetSettings(res, "setting_") diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index d8b8038..5f31967 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -87,7 +87,7 @@ func Init() { options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") timeout := model.GetIntSetting("aria2_call_timeout", 5) if options["aria2_rpcurl"] == "" { - // 未开启Aria2服务 + Instance = &DummyAria2{} return } @@ -101,19 +101,23 @@ func Init() { server, err := url.Parse(options["aria2_rpcurl"]) if err != nil { util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err) + Instance = &DummyAria2{} return } server.Path = "/jsonrpc" - // todo 加载自定义下载配置 + // 加载自定义下载配置 var globalOptions []interface{} err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions) if err != nil { util.Log().Warning("无法解析 aria2 全局配置,%s", err) + Instance = &DummyAria2{} + return } if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil { util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err) + Instance = &DummyAria2{} return } diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go new file mode 100644 index 0000000..e059c24 --- /dev/null +++ b/pkg/aria2/aria2_test.go @@ -0,0 +1,89 @@ +package aria2 + +import ( + "database/sql" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "testing" +) + +var mock sqlmock.Sqlmock + +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + model.DB, _ = gorm.Open("mysql", db) + defer db.Close() + m.Run() +} + +func TestDummyAria2(t *testing.T) { + asserts := assert.New(t) + instance := DummyAria2{} + asserts.Error(instance.CreateTask(nil, nil)) + _, err := instance.Status(nil) + asserts.Error(err) + asserts.Error(instance.Cancel(nil)) + asserts.Error(instance.Select(nil, nil)) +} + +func TestInit(t *testing.T) { + asserts := assert.New(t) + cache.Set("setting_aria2_token", "1", 0) + cache.Set("setting_aria2_call_timeout", "5", 0) + cache.Set("setting_aria2_options", `[]`, 0) + + // 未指定RPC地址,跳过 + { + cache.Set("setting_aria2_rpcurl", "", 0) + Init() + asserts.IsType(&DummyAria2{}, Instance) + } + + // 无法解析服务器地址 + { + cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0) + Init() + asserts.IsType(&DummyAria2{}, Instance) + } + + // 无法解析全局配置 + { + Instance = &RPCService{} + cache.Set("setting_aria2_options", "?", 0) + cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0) + Init() + asserts.IsType(&DummyAria2{}, Instance) + } + + // 连接失败 + { + cache.Set("setting_aria2_options", "[]", 0) + cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0) + cache.Set("setting_aria2_call_timeout", "1", 0) + cache.Set("setting_aria2_interval", "100", 0) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1")) + Init() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.IsType(&RPCService{}, Instance) + } +} + +func TestGetStatus(t *testing.T) { + asserts := assert.New(t) + asserts.Equal(4, getStatus("complete")) + asserts.Equal(1, getStatus("active")) + asserts.Equal(0, getStatus("waiting")) + asserts.Equal(2, getStatus("paused")) + asserts.Equal(3, getStatus("error")) + asserts.Equal(5, getStatus("removed")) + asserts.Equal(6, getStatus("?")) +} diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 65513f7..be13dcc 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -72,8 +72,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error { for i := 0; i < len(files); i++ { selected[i] = strconv.Itoa(files[i]) } - ok, err := client.caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) - util.Log().Debug(ok) + _, err := client.caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) return err } diff --git a/pkg/aria2/caller_test.go b/pkg/aria2/caller_test.go new file mode 100644 index 0000000..7bfec67 --- /dev/null +++ b/pkg/aria2/caller_test.go @@ -0,0 +1,51 @@ +package aria2 + +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRPCService_Init(t *testing.T) { + asserts := assert.New(t) + caller := &RPCService{} + asserts.Error(caller.Init("ws://", "", 1, nil)) + asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) +} + +func TestRPCService_Status(t *testing.T) { + asserts := assert.New(t) + caller := &RPCService{} + asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) + + _, err := caller.Status(&model.Download{}) + asserts.Error(err) +} + +func TestRPCService_Cancel(t *testing.T) { + asserts := assert.New(t) + caller := &RPCService{} + asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) + + err := caller.Cancel(&model.Download{Parent: "test"}) + asserts.Error(err) +} + +func TestRPCService_Select(t *testing.T) { + asserts := assert.New(t) + caller := &RPCService{} + asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) + + err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3}) + asserts.Error(err) +} + +func TestRPCService_CreateTask(t *testing.T) { + asserts := assert.New(t) + caller := &RPCService{} + asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) + cache.Set("setting_aria2_temp_path", "test", 0) + err := caller.CreateTask(&model.Download{Parent: "test"}, []interface{}{map[string]string{"1": "1"}}) + asserts.Error(err) +} diff --git a/pkg/aria2/Monitor.go b/pkg/aria2/monitor.go similarity index 98% rename from pkg/aria2/Monitor.go rename to pkg/aria2/monitor.go index ee36147..253c71c 100644 --- a/pkg/aria2/Monitor.go +++ b/pkg/aria2/monitor.go @@ -89,7 +89,7 @@ func (monitor *Monitor) Update() bool { return true } - util.Log().Debug(status.Status) + util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status) switch status.Status { case "complete": @@ -140,7 +140,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { monitor.Task.Attrs = string(attrs) if err := monitor.Task.Save(); err != nil { - return nil + return err } if originSize != monitor.Task.TotalSize { diff --git a/pkg/aria2/monitor_test.go b/pkg/aria2/monitor_test.go new file mode 100644 index 0000000..d85e7d7 --- /dev/null +++ b/pkg/aria2/monitor_test.go @@ -0,0 +1,317 @@ +package aria2 + +import ( + "errors" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/filesystem" + "github.com/HFO4/cloudreve/pkg/task" + "github.com/HFO4/cloudreve/pkg/util" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "github.com/zyxar/argo/rpc" + "testing" + "time" +) + +type InstanceMock struct { + testMock.Mock +} + +func (m InstanceMock) CreateTask(task *model.Download, options []interface{}) error { + args := m.Called(task, options) + return args.Error(0) +} + +func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) { + args := m.Called(task) + return args.Get(0).(rpc.StatusInfo), args.Error(1) +} + +func (m InstanceMock) Cancel(task *model.Download) error { + args := m.Called(task) + return args.Error(0) +} + +func (m InstanceMock) Select(task *model.Download, files []int) error { + args := m.Called(task, files) + return args.Error(0) +} + +func TestNewMonitor(t *testing.T) { + asserts := assert.New(t) + NewMonitor(&model.Download{GID: "gid"}) + _, ok := EventNotifier.Subscribes.Load("gid") + asserts.True(ok) +} + +func TestMonitor_Loop(t *testing.T) { + asserts := assert.New(t) + notifier := make(chan StatusEvent) + monitor := &Monitor{ + Task: &model.Download{GID: "gid"}, + Interval: time.Duration(1) * time.Second, + notifier: notifier, + } + asserts.NotPanics(func() { + monitor.Loop() + }) +} + +func TestMonitor_Update(t *testing.T) { + asserts := assert.New(t) + monitor := &Monitor{ + Task: &model.Download{ + GID: "gid", + Parent: "TestMonitor_Update", + }, + Interval: time.Duration(1) * time.Second, + } + + // 无法获取状态 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) + file, _ := util.CreatNestedFile("TestMonitor_Update/1") + file.Close() + Instance = testInstance + asserts.True(monitor.Update()) + testInstance.AssertExpectations(t) + asserts.False(util.Exists("TestMonitor_Update")) + } + + // 磁力链下载重定向 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{ + FollowedBy: []string{"1"}, + }, nil) + monitor.Task.ID = 1 + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + Instance = testInstance + asserts.False(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + asserts.EqualValues("1", monitor.Task.GID) + } + + // 无法更新任务信息 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) + monitor.Task.ID = 1 + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + Instance = testInstance + asserts.True(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + } + + // 返回未知状态 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + Instance = testInstance + asserts.True(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + } + + // 返回被取消状态 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + Instance = testInstance + asserts.True(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + } + + // 返回活跃状态 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + Instance = testInstance + asserts.False(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + } + + // 返回错误状态 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + Instance = testInstance + asserts.True(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + } + + // 返回完成 + { + testInstance := new(InstanceMock) + testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + Instance = testInstance + asserts.True(monitor.Update()) + asserts.NoError(mock.ExpectationsWereMet()) + testInstance.AssertExpectations(t) + } +} + +func TestMonitor_UpdateTaskInfo(t *testing.T) { + asserts := assert.New(t) + monitor := &Monitor{ + Task: &model.Download{ + Model: gorm.Model{ID: 1}, + GID: "gid", + Parent: "TestMonitor_UpdateTaskInfo", + }, + } + + // 失败 + { + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + } + + // 更新成功,无需校验 + { + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + } + + // 更新成功,大小改变,需要校验,校验失败 + { + testInstance := new(InstanceMock) + testInstance.On("Cancel", testMock.Anything).Return(nil) + Instance = testInstance + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + testInstance.AssertExpectations(t) + } +} + +func TestMonitor_ValidateFile(t *testing.T) { + asserts := assert.New(t) + monitor := &Monitor{ + Task: &model.Download{ + Model: gorm.Model{ID: 1}, + GID: "gid", + Parent: "TestMonitor_ValidateFile", + }, + } + + // 无法创建文件系统 + { + monitor.Task.User = &model.User{ + Policy: model.Policy{ + Type: "unknown", + }, + } + asserts.Error(monitor.ValidateFile()) + } + + // 文件大小超出容量配额 + { + cache.Set("pack_size_0", uint64(0), 0) + monitor.Task.TotalSize = 11 + monitor.Task.User = &model.User{ + Policy: model.Policy{ + Type: "mock", + }, + Group: model.Group{ + MaxStorage: 10, + }, + } + asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile()) + } + + // 单文件大小超出容量配额 + { + cache.Set("pack_size_0", uint64(0), 0) + monitor.Task.TotalSize = 10 + monitor.Task.StatusInfo.Files = []rpc.FileInfo{ + { + Selected: "true", + Length: "6", + }, + } + monitor.Task.User = &model.User{ + Policy: model.Policy{ + Type: "mock", + MaxSize: 5, + }, + Group: model.Group{ + MaxStorage: 10, + }, + } + asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile()) + } +} + +func TestMonitor_Complete(t *testing.T) { + asserts := assert.New(t) + monitor := &Monitor{ + Task: &model.Download{ + Model: gorm.Model{ID: 1}, + GID: "gid", + Parent: "TestMonitor_Complete", + StatusInfo: rpc.StatusInfo{ + Files: []rpc.FileInfo{ + { + Selected: "true", + Path: "TestMonitor_Complete", + }, + }, + }, + }, + } + + cache.Set("setting_max_worker_num", "1", 0) + task.Init() + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + asserts.True(monitor.Complete(rpc.StatusInfo{})) + asserts.NoError(mock.ExpectationsWereMet()) +}