From d0bb123e03709f03ff1daa0f21cf30ddcf098c3b Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Wed, 11 Dec 2019 12:24:09 +0800 Subject: [PATCH] Test: get source URL of files --- conf/conf.ini | 22 +++--- middleware/auth.go | 1 - models/group.go | 1 - models/migration.go | 3 +- models/setting_test.go | 26 +++++++ pkg/auth/auth.go | 1 - pkg/cache/driver.go | 7 ++ pkg/cache/memo.go | 8 ++ pkg/cache/memo_test.go | 19 +++++ pkg/cache/redis.go | 31 ++++++-- pkg/cache/redis_test.go | 44 +++++++++++ pkg/filesystem/file.go | 8 +- pkg/filesystem/file_test.go | 108 ++++++++++++++++++++++++++ pkg/filesystem/filesystem.go | 7 +- pkg/filesystem/filesystem_test.go | 61 +++++++++++++++ pkg/filesystem/image.go | 1 - pkg/filesystem/local/handller_test.go | 39 ++++++++++ routers/router_test.go | 4 +- 18 files changed, 363 insertions(+), 28 deletions(-) diff --git a/conf/conf.ini b/conf/conf.ini index f9b0ebb..25ff95b 100644 --- a/conf/conf.ini +++ b/conf/conf.ini @@ -7,18 +7,18 @@ MaxWidth = 400 MaxHeight = 300 FileSuffix = ._thumb -; [Database] -; Type = mysql -; User = root -; Password = root -; Host = 127.0.0.1:3306 -; Name = v3 -; TablePrefix = v3_ +[Database] +Type = mysql +User = root +Password = root +Host = 127.0.0.1:3306 +Name = v3 +TablePrefix = v3_ -; [Redis] -; Server = 127.0.0.1:6379 -; Password = -; DB = 0 +[Redis] +Server = 127.0.0.1:6379 +Password = +DB = 0 [Captcha] Height = 60 diff --git a/middleware/auth.go b/middleware/auth.go index 963364c..a9b2e8a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -9,7 +9,6 @@ import ( ) // SignRequired 验证请求签名 -// TODO 测试 func SignRequired() gin.HandlerFunc { return func(c *gin.Context) { err := auth.CheckURI(c.Request.URL) diff --git a/models/group.go b/models/group.go index 64bae46..51ce165 100644 --- a/models/group.go +++ b/models/group.go @@ -22,7 +22,6 @@ type Group struct { } // GetAria2Option 获取用户离线下载设备 -// TODO:测试 func (group *Group) GetAria2Option() [3]bool { if len(group.Aria2Option) != 5 { return [3]bool{false, false, false} diff --git a/models/migration.go b/models/migration.go index b0b700d..349e146 100644 --- a/models/migration.go +++ b/models/migration.go @@ -195,7 +195,8 @@ func addDefaultGroups() { // 未找到初始游客用户组时,则创建 if gorm.IsRecordNotFoundError(err) { defaultAdminGroup := Group{ - Name: "游客", + Name: "游客", + Policies: "[]", } if err := DB.Create(&defaultAdminGroup).Error; err != nil { util.Log().Panic("无法创建初始游客用户组, %s", err) diff --git a/models/setting_test.go b/models/setting_test.go index c4ebe99..d9b1b9b 100644 --- a/models/setting_test.go +++ b/models/setting_test.go @@ -140,3 +140,29 @@ func TestIsTrueVal(t *testing.T) { asserts.False(IsTrueVal("0")) asserts.False(IsTrueVal("false")) } + +func TestGetSiteURL(t *testing.T) { + asserts := assert.New(t) + + // 正常 + { + err := cache.Deletes([]string{"siteURL"}, "setting_") + asserts.NoError(err) + + mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://drive.cloudreve.org")) + siteURL := GetSiteURL() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("https://drive.cloudreve.org", siteURL.String()) + } + + // 失败 返回默认值 + { + err := cache.Deletes([]string{"siteURL"}, "setting_") + asserts.NoError(err) + + mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, ":][\\/\\]sdf")) + siteURL := GetSiteURL() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("https://cloudreve.org", siteURL.String()) + } +} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 11cbc27..2c9cc75 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -23,7 +23,6 @@ type Auth interface { } // SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证 -// TODO 测试 func SignURI(uri string, expires int64) (*url.URL, error) { base, err := url.Parse(uri) if err != nil { diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 8a59e92..33e6206 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -33,6 +33,8 @@ type Driver interface { Gets(keys []string, prefix string) (map[string]interface{}, []string) // 批量设置值 Sets(values map[string]interface{}, prefix string) error + // 删除值 + Delete(keys []string, prefix string) error } // Set 设置缓存值 @@ -45,6 +47,11 @@ func Get(key string) (interface{}, bool) { return Store.Get(key) } +// Deletes 删除值 +func Deletes(keys []string, prefix string) error { + return Store.Delete(keys, prefix) +} + // GetSettings 根据名称批量获取设置项缓存 func GetSettings(keys []string, prefix string) (map[string]string, []string) { raw, miss := Store.Gets(keys, prefix) diff --git a/pkg/cache/memo.go b/pkg/cache/memo.go index 858665e..4975544 100644 --- a/pkg/cache/memo.go +++ b/pkg/cache/memo.go @@ -48,3 +48,11 @@ func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error } return nil } + +// Delete 批量删除值 +func (store *MemoStore) Delete(keys []string, prefix string) error { + for _, key := range keys { + store.Store.Delete(prefix + key) + } + return nil +} diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go index 2f0be0b..bfce404 100644 --- a/pkg/cache/memo_test.go +++ b/pkg/cache/memo_test.go @@ -106,3 +106,22 @@ func TestMemoStore_Sets(t *testing.T) { "4": "4.val", }, vals) } + +func TestMemoStore_Delete(t *testing.T) { + asserts := assert.New(t) + store := NewMemoStore() + + err := store.Sets(map[string]interface{}{ + "1": "1.val", + "2": "2.val", + "3": "3.val", + "4": "4.val", + }, "test_") + asserts.NoError(err) + + err = store.Delete([]string{"1", "2"}, "test_") + asserts.NoError(err) + values, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_") + asserts.Equal([]string{"1", "2"}, miss) + asserts.Equal(map[string]interface{}{"3": "3.val", "4": "4.val"}, values) +} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index f7fa63f..fa7819e 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -169,13 +169,30 @@ func (store *RedisStore) Sets(values map[string]interface{}, prefix string) erro setValues[prefix+key] = serialized } - if rc.Err() == nil { - _, err := rc.Do("MSET", redis.Args{}.AddFlat(setValues)...) - if err != nil { - return err - } - return nil + _, err := rc.Do("MSET", redis.Args{}.AddFlat(setValues)...) + if err != nil { + return err + } + return nil + +} + +// Delete 批量删除给定的键 +func (store *RedisStore) Delete(keys []string, prefix string) error { + rc := store.pool.Get() + defer rc.Close() + if rc.Err() != nil { + return rc.Err() + } + + // 处理前缀 + for i := 0; i < len(keys); i++ { + keys[i] = prefix + keys[i] } - return rc.Err() + _, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...) + if err != nil { + return err + } + return nil } diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go index ebdd69b..e2cdac9 100644 --- a/pkg/cache/redis_test.go +++ b/pkg/cache/redis_test.go @@ -266,3 +266,47 @@ func TestRedisStore_Sets(t *testing.T) { asserts.Error(err) } } + +func TestRedisStore_Delete(t *testing.T) { + asserts := assert.New(t) + conn := redigomock.NewConn() + pool := &redis.Pool{ + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + store := &RedisStore{pool: pool} + + // 正常 + { + cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK") + err := store.Delete([]string{"1", "2", "3", "4"}, "test_") + asserts.NoError(err) + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + } + + // 命令执行失败 + { + conn.Clear() + cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error")) + err := store.Delete([]string{"1", "2", "3", "4"}, "test_") + asserts.Error(err) + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + } + + // 连接失败 + { + conn.Clear() + store.pool = &redis.Pool{ + Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, + MaxIdle: 10, + } + err := store.Delete([]string{"1", "2", "3", "4"}, "test_") + asserts.Error(err) + } +} diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 610e2f8..99fbe40 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -72,7 +72,6 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R } // GetContent 获取文件内容,path为虚拟路径 -// TODO:测试 func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) { // 触发`下载前`钩子 err := fs.Trigger(ctx, fs.BeforeFileDownload) @@ -94,6 +93,7 @@ func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeke // 将当前存储策略重设为文件使用的 fs.Policy = fs.FileTarget[0].GetPolicy() err = fs.dispatchHandler() + defer fs.CleanTargets() if err != nil { return nil, err } @@ -176,7 +176,11 @@ func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error // 检查存储策略是否可以获得外链 if !fs.Policy.IsOriginLinkEnable { - return "", serializer.NewError(serializer.CodePolicyNotAllowed, "当前存储策略无法获得外链", nil) + return "", serializer.NewError( + serializer.CodePolicyNotAllowed, + "当前存储策略无法获得外链", + nil, + ) } // 生成外链地址 diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go index 6b1dd87..6d91abb 100644 --- a/pkg/filesystem/file_test.go +++ b/pkg/filesystem/file_test.go @@ -4,6 +4,8 @@ import ( "context" "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/auth" + "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/serializer" @@ -263,3 +265,109 @@ func TestFileSystem_deleteGroupedFile(t *testing.T) { }, failed) } } + +func TestFileSystem_GetSource(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{Model: gorm.Model{ID: 1}}, + } + auth.General = auth.HMACAuth{SecretKey: []byte("123")} + + // 正常 + { + // 清空缓存 + err := cache.Deletes([]string{"siteURL"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(2, 1). + WillReturnRows( + sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). + AddRow(2, 35, "1.txt"), + ) + // 查找上传策略 + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). + AddRow(35, "local", true), + ) + // 查找站点URL + mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://cloudreve.org")) + + sourceURL, err := fs.GetSource(ctx, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.NotEmpty(sourceURL) + } + + // 文件不存在 + { + // 清空缓存 + err := cache.Deletes([]string{"siteURL"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(2, 1). + WillReturnRows( + sqlmock.NewRows([]string{"id", "policy_id", "source_name"}), + ) + + sourceURL, err := fs.GetSource(ctx, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(ErrObjectNotExist.Code, err.(serializer.AppError).Code) + asserts.Empty(sourceURL) + } + + // 未知上传策略 + { + // 清空缓存 + err := cache.Deletes([]string{"siteURL"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(2, 1). + WillReturnRows( + sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). + AddRow(2, 35, "1.txt"), + ) + // 查找上传策略 + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). + AddRow(35, "?", true), + ) + + sourceURL, err := fs.GetSource(ctx, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Empty(sourceURL) + } + + // 不允许获取外链 + { + // 清空缓存 + err := cache.Deletes([]string{"siteURL"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(2, 1). + WillReturnRows( + sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). + AddRow(2, 35, "1.txt"), + ) + // 查找上传策略 + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). + AddRow(35, "local", false), + ) + + sourceURL, err := fs.GetSource(ctx, 2) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(serializer.CodePolicyNotAllowed, err.(serializer.AppError).Code) + asserts.Empty(sourceURL) + } +} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index b52aea3..1ab7369 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -80,7 +80,6 @@ func NewFileSystem(user *model.User) (*FileSystem, error) { } // NewAnonymousFileSystem 初始化匿名文件系统 -// TODO 测试 func NewAnonymousFileSystem() (*FileSystem, error) { fs := &FileSystem{ User: &model.User{}, @@ -160,3 +159,9 @@ func (fs *FileSystem) SetTargetFileByIDs(ids []uint) error { fs.SetTargetFile(&files) return nil } + +// CleanTargets 清空目标 +func (fs *FileSystem) CleanTargets() { + fs.FileTarget = []model.File{} + fs.DirTarget = []model.Folder{} +} diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go index b5c34aa..5e97e2a 100644 --- a/pkg/filesystem/filesystem_test.go +++ b/pkg/filesystem/filesystem_test.go @@ -1,6 +1,7 @@ package filesystem import ( + "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/gin-gonic/gin" @@ -63,3 +64,63 @@ func TestDispatchHandler(t *testing.T) { err = fs.dispatchHandler() asserts.Error(err) } + +func TestFileSystem_SetTargetFileByIDs(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + fs := &FileSystem{} + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, 2). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) + err := fs.SetTargetFileByIDs([]uint{1, 2}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Len(fs.FileTarget, 1) + asserts.NoError(err) + } + + // 未找到 + { + fs := &FileSystem{} + mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + err := fs.SetTargetFileByIDs([]uint{1, 2}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Len(fs.FileTarget, 0) + asserts.Error(err) + } +} + +func TestFileSystem_CleanTargets(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{ + FileTarget: []model.File{{}, {}}, + DirTarget: []model.Folder{{}, {}}, + } + + fs.CleanTargets() + asserts.Len(fs.FileTarget, 0) + asserts.Len(fs.DirTarget, 0) +} + +func TestNewAnonymousFileSystem(t *testing.T) { + asserts := assert.New(t) + + // 正常 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"}).AddRow(3, "游客", "[]")) + fs, err := NewAnonymousFileSystem() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Equal("游客", fs.User.Group.Name) + } + + // 游客用户组不存在 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"})) + fs, err := NewAnonymousFileSystem() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Nil(fs) + } +} diff --git a/pkg/filesystem/image.go b/pkg/filesystem/image.go index 35586a2..53522fc 100644 --- a/pkg/filesystem/image.go +++ b/pkg/filesystem/image.go @@ -29,7 +29,6 @@ func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentR } fs.FileTarget = []model.File{file[0]} - res, err := fs.Handler.Thumb(ctx, file[0].SourceName) // TODO 出错时重新生成缩略图 diff --git a/pkg/filesystem/local/handller_test.go b/pkg/filesystem/local/handller_test.go index 71da4ca..258c02f 100644 --- a/pkg/filesystem/local/handller_test.go +++ b/pkg/filesystem/local/handller_test.go @@ -2,11 +2,16 @@ package local import ( "context" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/conf" + "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/util" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "io" "io/ioutil" + "net/url" "os" "strings" "testing" @@ -112,3 +117,37 @@ func TestHandler_Thumb(t *testing.T) { asserts.Error(err) } } + +func TestHandler_Source(t *testing.T) { + asserts := assert.New(t) + handler := Handler{} + ctx := context.Background() + auth.General = auth.HMACAuth{SecretKey: []byte("test")} + + // 成功 + { + file := model.File{ + Model: gorm.Model{ + ID: 1, + }, + Name: "test.jpg", + } + ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0) + asserts.NoError(err) + asserts.NotEmpty(sourceURL) + asserts.Contains(sourceURL, "sign=") + asserts.Contains(sourceURL, "https://cloudreve.org") + } + + // 无法获取上下文 + { + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0) + asserts.Error(err) + asserts.Empty(sourceURL) + } +} diff --git a/routers/router_test.go b/routers/router_test.go index 572564c..0b0c56c 100644 --- a/routers/router_test.go +++ b/routers/router_test.go @@ -28,7 +28,7 @@ func TestCaptcha(t *testing.T) { req, _ := http.NewRequest( "GET", - "/api/v3/captcha", + "/api/v3/site/captcha", nil, ) @@ -239,7 +239,7 @@ func TestSiteConfigRoute(t *testing.T) { ) router.ServeHTTP(w, req) asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "\"title\":\"\"") + asserts.Contains(w.Body.String(), "\"title\"") model.DB.Model(&model.Setting{ Model: gorm.Model{