diff --git a/middleware/explorer.go b/middleware/hahsid.go similarity index 90% rename from middleware/explorer.go rename to middleware/hahsid.go index 8ea9add..f66d945 100644 --- a/middleware/explorer.go +++ b/middleware/hahsid.go @@ -6,7 +6,7 @@ import ( "github.com/gin-gonic/gin" ) -// HashID 将给定文件的HashID转换为真实ID +// HashID 将给定对象的HashID转换为真实ID func HashID(IDType int) gin.HandlerFunc { return func(c *gin.Context) { if c.Param("id") != "" { diff --git a/middleware/hashid_test.go b/middleware/hashid_test.go new file mode 100644 index 0000000..4cdfe83 --- /dev/null +++ b/middleware/hashid_test.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "github.com/HFO4/cloudreve/pkg/hashid" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHashID(t *testing.T) { + asserts := assert.New(t) + rec := httptest.NewRecorder() + TestFunc := HashID(hashid.FolderID) + + // 未给定ID对象,跳过 + { + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{} + c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) + TestFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.False(c.IsAborted()) + } + + // 给定ID,解析失败 + { + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"id", "2333"}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) + TestFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.True(c.IsAborted()) + } + + // 给定ID,解析成功 + { + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"id", hashid.HashID(1, hashid.FolderID)}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) + TestFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.False(c.IsAborted()) + } +} diff --git a/models/download.go b/models/download.go index 207f560..d32064f 100644 --- a/models/download.go +++ b/models/download.go @@ -78,7 +78,6 @@ func GetDownloadsByStatus(status ...int) []Download { // GetDownloadsByStatusAndUser 根据状态检索和用户ID下载 // page 为 0 表示列出所有,非零时分页 -// TODO 测试 func GetDownloadsByStatusAndUser(page, uid uint, status ...int) []Download { var tasks []Download dbChain := DB diff --git a/models/download_test.go b/models/download_test.go index f0a43b1..d24531b 100644 --- a/models/download_test.go +++ b/models/download_test.go @@ -62,6 +62,15 @@ func TestDownload_AfterFind(t *testing.T) { asserts.Error(err) asserts.Equal("", download.StatusInfo.Gid) } + + // 关联任务 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "error"}).AddRow(1, "error")) + download := Download{TaskID: 1} + download.BeforeSave() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("error", download.Task.Error) + } } func TestDownload_Save(t *testing.T) { @@ -140,3 +149,23 @@ func TestDownload_GetOwner(t *testing.T) { asserts.Equal("nick", user.Nick) } } + +func TestGetDownloadsByStatusAndUser(t *testing.T) { + asserts := assert.New(t) + + // 列出全部 + { + mock.ExpectQuery("SELECT(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3)) + res := GetDownloadsByStatusAndUser(0, 1, 1, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Len(res, 2) + } + + // 列出全部,分页 + { + mock.ExpectQuery("SELECT(.+)DESC(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3)) + res := GetDownloadsByStatusAndUser(2, 1, 1, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Len(res, 2) + } +} diff --git a/models/file.go b/models/file.go index 49a5c77..e5c3dd7 100644 --- a/models/file.go +++ b/models/file.go @@ -80,7 +80,6 @@ func GetFilesByIDs(ids []uint, uid uint) ([]File, error) { // GetFilesByKeywords 根据关键字搜索文件, // UID为0表示忽略用户,只根据文件ID检索 -// TODO 测试 func GetFilesByKeywords(uid uint, keywords ...interface{}) ([]File, error) { var ( files []File diff --git a/models/file_test.go b/models/file_test.go index 61d5c6b..7e53c49 100644 --- a/models/file_test.go +++ b/models/file_test.go @@ -385,3 +385,25 @@ func TestFile_FileInfoInterface(t *testing.T) { asserts.False(file.IsDir()) asserts.Equal("/test", file.GetPosition()) } + +func TestGetFilesByKeywords(t *testing.T) { + asserts := assert.New(t) + + // 未指定用户 + { + mock.ExpectQuery("SELECT(.+)").WithArgs("k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetFilesByKeywords(0, "k1", "k2") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Len(res, 1) + } + + // 指定用户 + { + mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetFilesByKeywords(1, "k1", "k2") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Len(res, 1) + } +} diff --git a/models/tag_test.go b/models/tag_test.go new file mode 100644 index 0000000..4ecdc5d --- /dev/null +++ b/models/tag_test.go @@ -0,0 +1,63 @@ +package model + +import ( + "errors" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestTag_Create(t *testing.T) { + asserts := assert.New(t) + tag := Tag{} + + // 成功 + { + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + id, err := tag.Create() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.EqualValues(1, id) + } + + // 失败 + { + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + id, err := tag.Create() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.EqualValues(0, id) + } +} + +func TestDeleteTagByID(t *testing.T) { + asserts := assert.New(t) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err := DeleteTagByID(1, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) +} + +func TestGetTagsByUID(t *testing.T) { + asserts := assert.New(t) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetTagsByUID(1) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Len(res, 1) +} + +func TestGetTagsByID(t *testing.T) { + asserts := assert.New(t) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetTasksByID(1) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.EqualValues(1, res.ID) +} diff --git a/models/task_test.go b/models/task_test.go index 9978641..32619ed 100644 --- a/models/task_test.go +++ b/models/task_test.go @@ -70,3 +70,12 @@ func TestTask_SetProgress(t *testing.T) { asserts.NoError(task.SetProgress(1)) asserts.NoError(mock.ExpectationsWereMet()) } + +func TestGetTasksByID(t *testing.T) { + asserts := assert.New(t) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := GetTasksByID(1) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.EqualValues(1, res.ID) +} diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go index e059c24..aebbe37 100644 --- a/pkg/aria2/aria2_test.go +++ b/pkg/aria2/aria2_test.go @@ -36,6 +36,7 @@ func TestDummyAria2(t *testing.T) { } func TestInit(t *testing.T) { + MAX_RETRY = 0 asserts := assert.New(t) cache.Set("setting_aria2_token", "1", 0) cache.Set("setting_aria2_call_timeout", "5", 0) diff --git a/pkg/aria2/monitor.go b/pkg/aria2/monitor.go index 9579051..09eb343 100644 --- a/pkg/aria2/monitor.go +++ b/pkg/aria2/monitor.go @@ -32,6 +32,8 @@ type StatusEvent struct { Status int } +var MAX_RETRY = 10 + // NewMonitor 新建上传状态监控 func NewMonitor(task *model.Download) { monitor := &Monitor{ @@ -73,7 +75,7 @@ func (monitor *Monitor) Update() bool { util.Log().Warning("无法获取下载任务[%s]的状态,%s", monitor.Task.GID, err) // 十次重试后认定为任务失败 - if monitor.retried > 10 { + if monitor.retried > MAX_RETRY { util.Log().Warning("无法获取下载任务[%s]的状态,超过最大重试次数限制,%s", monitor.Task.GID, err) monitor.setErrorStatus(err) monitor.RemoveTempFolder() diff --git a/pkg/aria2/monitor_test.go b/pkg/aria2/monitor_test.go index d85e7d7..3596208 100644 --- a/pkg/aria2/monitor_test.go +++ b/pkg/aria2/monitor_test.go @@ -50,6 +50,7 @@ func TestNewMonitor(t *testing.T) { func TestMonitor_Loop(t *testing.T) { asserts := assert.New(t) notifier := make(chan StatusEvent) + MAX_RETRY = 0 monitor := &Monitor{ Task: &model.Download{GID: "gid"}, Interval: time.Duration(1) * time.Second, @@ -72,11 +73,13 @@ func TestMonitor_Update(t *testing.T) { // 无法获取状态 { + MAX_RETRY = 1 testInstance := new(InstanceMock) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) file, _ := util.CreatNestedFile("TestMonitor_Update/1") file.Close() Instance = testInstance + asserts.False(monitor.Update()) asserts.True(monitor.Update()) testInstance.AssertExpectations(t) asserts.False(util.Exists("TestMonitor_Update")) diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go index 2c5348d..f4f590a 100644 --- a/pkg/filesystem/file_test.go +++ b/pkg/filesystem/file_test.go @@ -73,7 +73,7 @@ func TestFileSystem_GetContent(t *testing.T) { } // 文件不存在 - rs, err := fs.GetContent(ctx, "not exist file") + rs, err := fs.GetContent(ctx, 1) asserts.Equal(ErrObjectNotExist, err) asserts.Nil(rs) fs.CleanTargets() @@ -84,39 +84,30 @@ func TestFileSystem_GetContent(t *testing.T) { _ = file.Close() cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent.txt", 1)) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "unknown")) - rs, err = fs.GetContent(ctx, "/TestFileSystem_GetContent.txt") + rs, err = fs.GetContent(ctx, 1) asserts.Error(err) asserts.NoError(mock.ExpectationsWereMet()) fs.CleanTargets() // 打开文件失败 cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent.txt", 1)) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "source_name"}).AddRow(1, "local", "not exist")) - rs, err = fs.GetContent(ctx, "/TestFileSystem_GetContent.txt") + rs, err = fs.GetContent(ctx, 1) asserts.Equal(serializer.CodeIOFailed, err.(serializer.AppError).Code) asserts.NoError(mock.ExpectationsWereMet()) fs.CleanTargets() // 打开成功 cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetContent.txt", 1, "TestFileSystem_GetContent.txt")) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - rs, err = fs.GetContent(ctx, "/TestFileSystem_GetContent.txt") + rs, err = fs.GetContent(ctx, 1) asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) } @@ -141,29 +132,23 @@ func TestFileSystem_GetDownloadContent(t *testing.T) { _ = file.Close() cache.Deletes([]string{"599"}, "policy_") - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) // 无限速 cache.Deletes([]string{"599"}, "policy_") - _, err = fs.GetDownloadContent(ctx, "/TestFileSystem_GetDownloadContent.txt") + _, err = fs.GetDownloadContent(ctx, 1) asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) fs.CleanTargets() // 有限速 cache.Deletes([]string{"599"}, "policy_") - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) fs.User.Group.SpeedLimit = 1 - _, err = fs.GetDownloadContent(ctx, "/TestFileSystem_GetDownloadContent.txt") + _, err = fs.GetDownloadContent(ctx, 1) asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) } @@ -411,9 +396,6 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { cache.Set("setting_siteURL", "https://cloudreve.org", 0) asserts.NoError(err) // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). @@ -422,7 +404,7 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { AddRow(35, "local", true), ) // 相关设置 - downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") + downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(err) asserts.NotEmpty(downloadURL) @@ -436,12 +418,9 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { err = cache.Deletes([]string{"download_timeout"}, "setting_") asserts.NoError(err) // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"})) - downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") + downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) asserts.Empty(downloadURL) @@ -455,9 +434,6 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { err = cache.Deletes([]string{"download_timeout"}, "setting_") asserts.NoError(err) // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). @@ -466,7 +442,7 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { AddRow(35, "unknown", true), ) - downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") + downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) asserts.Empty(downloadURL) @@ -511,7 +487,7 @@ func TestFileSystem_Preview(t *testing.T) { User: &model.User{}, } mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - resp, err := fs.Preview(ctx, "/1.txt", false) + resp, err := fs.Preview(ctx, 1, false) asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) asserts.Nil(resp) @@ -531,7 +507,7 @@ func TestFileSystem_Preview(t *testing.T) { }, }, } - resp, err := fs.Preview(ctx, "/1.txt", false) + resp, err := fs.Preview(ctx, 1, false) asserts.Error(err) asserts.Nil(resp) } @@ -551,7 +527,7 @@ func TestFileSystem_Preview(t *testing.T) { }, }, } - resp, err := fs.Preview(ctx, "/1.txt", false) + resp, err := fs.Preview(ctx, 1, false) asserts.NoError(err) asserts.NotNil(resp) asserts.False(resp.Redirect) @@ -574,7 +550,7 @@ func TestFileSystem_Preview(t *testing.T) { }, } asserts.NoError(cache.Set("setting_preview_timeout", "233", 0)) - resp, err := fs.Preview(ctx, "/1.txt", false) + resp, err := fs.Preview(ctx, 1, false) asserts.NoError(err) asserts.NotNil(resp) asserts.True(resp.Redirect) @@ -597,7 +573,7 @@ func TestFileSystem_Preview(t *testing.T) { }, } asserts.NoError(cache.Set("setting_maxEditSize", "10", 0)) - resp, err := fs.Preview(ctx, "/1.txt", true) + resp, err := fs.Preview(ctx, 1, true) asserts.Equal(ErrFileSizeTooBig, err) asserts.Nil(resp) } @@ -615,3 +591,18 @@ func TestFileSystem_ResetFileIDIfNotExist(t *testing.T) { } asserts.Equal(ErrObjectNotExist, fs.resetFileIDIfNotExist(ctx, 1)) } + +func TestFileSystem_Search(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := &FileSystem{ + User: &model.User{}, + } + fs.User.ID = 1 + + mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res, err := fs.Search(ctx, "k1", "k2") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Len(res, 1) +} diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 55e3a5e..0ebce5c 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -136,7 +136,7 @@ func TestGenericAfterUpload(t *testing.T) { mock.NewRows([]string{"name"}), ) err = GenericAfterUpload(ctx, &fs) - asserts.Equal(ErrPathNotExist, err) + asserts.Equal(ErrRootProtected, err) asserts.NoError(mock.ExpectationsWereMet()) // 文件已存在 diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go index 4ec9c25..af4e728 100644 --- a/pkg/filesystem/manage.go +++ b/pkg/filesystem/manage.go @@ -390,7 +390,6 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*mo } // SaveTo 将别人分享的文件转存到目标路径下 -// TODO 测试 func (fs *FileSystem) SaveTo(ctx context.Context, path string) error { // 获取父目录 isExist, folder := fs.IsPathExist(path) diff --git a/pkg/serializer/aria2_test.go b/pkg/serializer/aria2_test.go new file mode 100644 index 0000000..1c7b2b2 --- /dev/null +++ b/pkg/serializer/aria2_test.go @@ -0,0 +1,92 @@ +package serializer + +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "github.com/zyxar/argo/rpc" + "testing" +) + +func TestBuildFinishedListResponse(t *testing.T) { + asserts := assert.New(t) + tasks := []model.Download{ + { + StatusInfo: rpc.StatusInfo{ + Files: []rpc.FileInfo{ + { + Path: "/file/name.txt", + }, + }, + }, + Task: &model.Task{ + Model: gorm.Model{}, + Error: "error", + }, + }, + { + StatusInfo: rpc.StatusInfo{ + Files: []rpc.FileInfo{ + { + Path: "/file/name1.txt", + }, + { + Path: "/file/name2.txt", + }, + }, + }, + }, + } + tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" + res := BuildFinishedListResponse(tasks).Data.([]FinishedListResponse) + asserts.Len(res, 2) + asserts.Equal("name.txt", res[1].Name) + asserts.Equal("name.txt", res[0].Name) + asserts.Equal("name.txt", res[0].Files[0].Path) + asserts.Equal("name1.txt", res[1].Files[0].Path) + asserts.Equal("name2.txt", res[1].Files[1].Path) + asserts.EqualValues(0, res[0].TaskStatus) + asserts.Equal("error", res[0].TaskError) +} + +func TestBuildDownloadingResponse(t *testing.T) { + asserts := assert.New(t) + cache.Set("setting_aria2_interval", "10", 0) + tasks := []model.Download{ + { + StatusInfo: rpc.StatusInfo{ + Files: []rpc.FileInfo{ + { + Path: "/file/name.txt", + }, + }, + }, + Task: &model.Task{ + Model: gorm.Model{}, + Error: "error", + }, + }, + { + StatusInfo: rpc.StatusInfo{ + Files: []rpc.FileInfo{ + { + Path: "/file/name1.txt", + }, + { + Path: "/file/name2.txt", + }, + }, + }, + }, + } + tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" + + res := BuildDownloadingResponse(tasks).Data.([]DownloadListResponse) + asserts.Len(res, 2) + asserts.Equal("name1.txt", res[1].Name) + asserts.Equal("name.txt", res[0].Name) + asserts.Equal("name.txt", res[0].Info.Files[0].Path) + asserts.Equal("name1.txt", res[1].Info.Files[0].Path) + asserts.Equal("name2.txt", res[1].Info.Files[1].Path) +} diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go index e84905e..c118e2b 100644 --- a/pkg/serializer/user.go +++ b/pkg/serializer/user.go @@ -92,7 +92,7 @@ func BuildUser(user model.User) User { ShareDownload: user.Group.OptionsSerialized.ShareDownload, CompressEnabled: user.Group.OptionsSerialized.ArchiveTask, }, - Tags: BuildTagRes(tags), + Tags: buildTagRes(tags), } } @@ -121,20 +121,20 @@ func BuildUserStorageResponse(user model.User) Response { } } -// BuildTagRes 构建标签列表 -func BuildTagRes(tags []model.Tag) []tag { +// buildTagRes 构建标签列表 +func buildTagRes(tags []model.Tag) []tag { res := make([]tag, 0, len(tags)) for i := 0; i < len(tags); i++ { newTag := tag{ - ID: hashid.HashID(tags[i].ID, hashid.TagID), - Name: tags[i].Name, - Icon: tags[i].Icon, - Color: tags[i].Color, - Type: tags[i].Type, - Expression: tags[i].Expression, + ID: hashid.HashID(tags[i].ID, hashid.TagID), + Name: tags[i].Name, + Icon: tags[i].Icon, + Color: tags[i].Color, + Type: tags[i].Type, } - if newTag.Type == 0 { - newTag.Expression = "" + if newTag.Type != 0 { + newTag.Expression = tags[i].Expression + } res = append(res, newTag) } diff --git a/pkg/serializer/user_test.go b/pkg/serializer/user_test.go index 06eafb4..1c2159a 100644 --- a/pkg/serializer/user_test.go +++ b/pkg/serializer/user_test.go @@ -1,18 +1,38 @@ package serializer import ( + "database/sql" + "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/cache" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" ) +var mock sqlmock.Sqlmock + +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + model.DB, _ = gorm.Open("mysql", db) + defer db.Close() + m.Run() +} + func TestBuildUser(t *testing.T) { asserts := assert.New(t) user := model.User{ Policy: model.Policy{MaxSize: 1024 * 1024}, } + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) res := BuildUser(user) + asserts.NoError(mock.ExpectationsWereMet()) asserts.Equal("1.00mb", res.Policy.MaxSize) } @@ -72,3 +92,21 @@ func TestBuildUserStorageResponse(t *testing.T) { asserts.Equal(uint64(5), res.Data.(storage).Free) } } + +func TestBuildTagRes(t *testing.T) { + asserts := assert.New(t) + tags := []model.Tag{ + { + Type: 0, + Expression: "exp", + }, + { + Type: 1, + Expression: "exp", + }, + } + res := buildTagRes(tags) + asserts.Len(res, 2) + asserts.Equal("", res[0].Expression) + asserts.Equal("exp", res[1].Expression) +}