diff --git a/models/policy_test.go b/models/policy_test.go index a4c78a4..c253877 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -167,7 +167,7 @@ func TestPolicy_GetUploadURL(t *testing.T) { { cache.Set("setting_siteURL", "http://127.0.0.1", 0) policy := Policy{Type: "local", Server: "http://127.0.0.1"} - asserts.Equal("http://127.0.0.1/api/v3/file/upload", policy.GetUploadURL()) + asserts.Equal("/api/v3/file/upload", policy.GetUploadURL()) } // 远程 diff --git a/models/share.go b/models/share.go index fb3d89a..2df25e3 100644 --- a/models/share.go +++ b/models/share.go @@ -33,7 +33,6 @@ type Share struct { } // Create 创建分享 -// TODO 测试 func (share *Share) Create() (uint, error) { if err := DB.Create(share).Error; err != nil { util.Log().Warning("无法插入数据库记录, %s", err) @@ -43,7 +42,6 @@ func (share *Share) Create() (uint, error) { } // GetShareByHashID 根据HashID查找分享 -// TODO 测试 func GetShareByHashID(hashID string) *Share { id, err := hashid.DecodeHashID(hashID, hashid.ShareID) if err != nil { @@ -59,7 +57,6 @@ func GetShareByHashID(hashID string) *Share { } // IsAvailable 返回此分享是否可用(是否过期) -// TODO 测试 func (share *Share) IsAvailable() bool { if share.RemainDownloads == 0 { return false @@ -71,37 +68,38 @@ func (share *Share) IsAvailable() bool { // 检查源对象是否存在 var sourceID uint if share.IsDir { - folder := share.GetSourceFolder() + folder := share.SourceFolder() sourceID = folder.ID } else { - file := share.GetSourceFile() + file := share.SourceFile() sourceID = file.ID } if sourceID == 0 { + // TODO 是否要在这里删除这个无效分享? return false } return true } -// GetCreator 获取分享的创建者 -func (share *Share) GetCreator() *User { +// Creator 获取分享的创建者 +func (share *Share) Creator() *User { if share.User.ID == 0 { share.User, _ = GetUserByID(share.UserID) } return &share.User } -// GetSource 返回源对象 -func (share *Share) GetSource() interface{} { +// Source 返回源对象 +func (share *Share) Source() interface{} { if share.IsDir { - return share.GetSourceFolder() + return share.SourceFolder() } - return share.GetSourceFile() + return share.SourceFile() } -// GetSourceFolder 获取源目录 -func (share *Share) GetSourceFolder() *Folder { +// SourceFolder 获取源目录 +func (share *Share) SourceFolder() *Folder { if share.Folder.ID == 0 { folders, _ := GetFoldersByIDs([]uint{share.SourceID}, share.UserID) if len(folders) > 0 { @@ -111,8 +109,8 @@ func (share *Share) GetSourceFolder() *Folder { return &share.Folder } -// GetSourceFile 获取源文件 -func (share *Share) GetSourceFile() *File { +// SourceFile 获取源文件 +func (share *Share) SourceFile() *File { if share.File.ID == 0 { files, _ := GetFilesByIDs([]uint{share.SourceID}, share.UserID) if len(files) > 0 { @@ -182,7 +180,7 @@ func (share *Share) Purchase(user *User) error { scoreRate := GetIntSetting("share_score_rate", 100) gainedScore := int(math.Ceil(float64(share.Score*scoreRate) / 100)) - share.GetCreator().AddScore(gainedScore) + share.Creator().AddScore(gainedScore) return nil } diff --git a/models/share_test.go b/models/share_test.go new file mode 100644 index 0000000..de4faf9 --- /dev/null +++ b/models/share_test.go @@ -0,0 +1,322 @@ +package model + +import ( + "errors" + "github.com/DATA-DOG/go-sqlmock" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/conf" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "net/http/httptest" + "testing" + "time" +) + +func TestShare_Create(t *testing.T) { + asserts := assert.New(t) + share := Share{UserID: 1} + + // 成功 + { + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + id, err := share.Create() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.EqualValues(2, id) + } + + // 失败 + { + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + id, err := share.Create() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.EqualValues(0, id) + } +} + +func TestGetShareByHashID(t *testing.T) { + asserts := assert.New(t) + conf.SystemConfig.HashIDSalt = "" + + // 成功 + { + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + res := GetShareByHashID("x9T4") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotNil(res) + } + + // 查询失败 + { + mock.ExpectQuery("SELECT(.+)"). + WillReturnError(errors.New("error")) + res := GetShareByHashID("x9T4") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(res) + } + + // ID解码失败 + { + res := GetShareByHashID("empty") + asserts.Nil(res) + } + +} + +func TestShare_IsAvailable(t *testing.T) { + asserts := assert.New(t) + + // 下载剩余次数为0 + { + share := Share{} + asserts.False(share.IsAvailable()) + } + + // 时效过期 + { + expires := time.Unix(10, 10) + share := Share{ + RemainDownloads: -1, + Expires: &expires, + } + asserts.False(share.IsAvailable()) + } + + // 源对象为目录,但不存在 + { + share := Share{ + RemainDownloads: -1, + SourceID: 2, + IsDir: true, + } + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + asserts.False(share.IsAvailable()) + asserts.NoError(mock.ExpectationsWereMet()) + } + + // 源对象为目录,存在 + { + share := Share{ + RemainDownloads: -1, + SourceID: 2, + IsDir: false, + } + mock.ExpectQuery("SELECT(.+)files(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(13)) + asserts.True(share.IsAvailable()) + asserts.NoError(mock.ExpectationsWereMet()) + } +} + +func TestShare_GetCreator(t *testing.T) { + asserts := assert.New(t) + + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + share := Share{UserID: 1} + res := share.Creator() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.EqualValues(1, res.ID) +} + +func TestShare_Source(t *testing.T) { + asserts := assert.New(t) + + // 目录 + { + share := Share{IsDir: true, SourceID: 3} + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) + asserts.EqualValues(3, share.Source().(*Folder).ID) + asserts.NoError(mock.ExpectationsWereMet()) + } + + // 文件 + { + share := Share{IsDir: false, SourceID: 3} + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) + asserts.EqualValues(3, share.Source().(*File).ID) + asserts.NoError(mock.ExpectationsWereMet()) + } +} + +func TestShare_CanBeDownloadBy(t *testing.T) { + asserts := assert.New(t) + share := Share{} + + // 未登录,无权 + { + user := &User{ + Group: Group{ + OptionsSerialized: GroupOption{ + ShareDownloadEnabled: false, + }, + }, + } + asserts.Error(share.CanBeDownloadBy(user)) + } + + // 已登录,无权 + { + user := &User{ + Model: gorm.Model{ID: 1}, + Group: Group{ + OptionsSerialized: GroupOption{ + ShareDownloadEnabled: false, + }, + }, + } + asserts.Error(share.CanBeDownloadBy(user)) + } + + // 未登录,需要积分 + { + user := &User{ + Group: Group{ + OptionsSerialized: GroupOption{ + ShareDownloadEnabled: true, + }, + }, + } + share.Score = 1 + asserts.Error(share.CanBeDownloadBy(user)) + } + + // 成功 + { + user := &User{ + Model: gorm.Model{ID: 1}, + Group: Group{ + OptionsSerialized: GroupOption{ + ShareDownloadEnabled: true, + }, + }, + } + share.Score = 1 + asserts.NoError(share.CanBeDownloadBy(user)) + } +} + +func TestShare_WasDownloadedBy(t *testing.T) { + asserts := assert.New(t) + share := Share{ + Model: gorm.Model{ID: 1}, + } + + // 已登录,已下载 + { + user := User{ + Model: gorm.Model{ + ID: 1, + }, + } + r := httptest.NewRecorder() + c, _ := gin.CreateTestContext(r) + cache.Set("share_1_1", true, 0) + asserts.True(share.WasDownloadedBy(&user, c)) + } +} + +func TestShare_DownloadBy(t *testing.T) { + asserts := assert.New(t) + share := Share{ + Model: gorm.Model{ID: 1}, + } + user := User{ + Model: gorm.Model{ + ID: 1, + }, + } + cache.Deletes([]string{"1_1"}, "share_") + r := httptest.NewRecorder() + c, _ := gin.CreateTestContext(r) + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := share.DownloadBy(&user, c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + _, ok := cache.Get("share_1_1") + asserts.True(ok) +} + +func TestShare_Purchase(t *testing.T) { + asserts := assert.New(t) + + // 不需要购买 + { + share := Share{} + user := User{} + asserts.NoError(share.Purchase(&user)) + + share.Score = 1 + user.Group.OptionsSerialized.ShareFreeEnabled = true + asserts.NoError(share.Purchase(&user)) + + user.Group.OptionsSerialized.ShareFreeEnabled = false + share.UserID = 1 + user.ID = 1 + asserts.NoError(share.Purchase(&user)) + } + + // 积分不足 + { + share := Share{ + Score: 1, + UserID: 2, + } + user := User{} + asserts.Error(share.Purchase(&user)) + } + + // 成功 + { + cache.Set("setting_share_score_rate", "80", 0) + share := Share{ + Score: 10, + UserID: 2, + User: User{ + Model: gorm.Model{ + ID: 1, + }, + }, + } + user := User{ + Score: 10, + } + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + asserts.NoError(share.Purchase(&user)) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.EqualValues(0, user.Score) + asserts.EqualValues(8, share.User.Score) + } +} + +func TestShare_Viewed(t *testing.T) { + asserts := assert.New(t) + share := Share{} + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + share.Viewed() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.EqualValues(1, share.Views) +} diff --git a/models/user.go b/models/user.go index e4956e4..710a043 100644 --- a/models/user.go +++ b/models/user.go @@ -95,7 +95,6 @@ func (user *User) IncreaseStorage(size uint64) bool { } // PayScore 扣除积分,返回是否成功 -// todo 测试 func (user *User) PayScore(score int) bool { if score == 0 { return true @@ -109,7 +108,6 @@ func (user *User) PayScore(score int) bool { } // AddScore 增加积分 -// todo 测试 func (user *User) AddScore(score int) { user.Score += score DB.Model(user).UpdateColumn("score", gorm.Expr("score + ?", score)) @@ -257,7 +255,6 @@ func (user *User) SetPassword(password string) error { } // NewAnonymousUser 返回一个匿名用户 -// TODO 测试 func NewAnonymousUser() *User { user := User{} user.Policy.Type = "anonymous" @@ -266,7 +263,6 @@ func NewAnonymousUser() *User { } // IsAnonymous 返回是否为未登录用户 -// TODO 测试 func (user *User) IsAnonymous() bool { return user.ID == 0 } diff --git a/models/user_test.go b/models/user_test.go index 63d7212..1dce113 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -340,3 +340,48 @@ func TestUser_Root(t *testing.T) { asserts.Error(err) } } + +func TestUser_PayScore(t *testing.T) { + asserts := assert.New(t) + user := User{Score: 5} + + asserts.True(user.PayScore(0)) + asserts.False(user.PayScore(10)) + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + asserts.True(user.PayScore(5)) + asserts.EqualValues(0, user.Score) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestUser_AddScore(t *testing.T) { + asserts := assert.New(t) + user := User{Score: 5} + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + user.AddScore(5) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.EqualValues(10, user.Score) +} + +func TestNewAnonymousUser(t *testing.T) { + asserts := assert.New(t) + + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) + user := NewAnonymousUser() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotNil(user) + asserts.EqualValues(3, user.Group.ID) +} + +func TestUser_IsAnonymous(t *testing.T) { + asserts := assert.New(t) + user := User{} + asserts.True(user.IsAnonymous()) + user.ID = 1 + asserts.False(user.IsAnonymous()) +} diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index f17d2a7..330af0b 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -240,7 +240,7 @@ func (fs *FileSystem) GetDownloadURL(ctx context.Context, path string, timeout s return source, nil } -// GetSource 获取可直接访问文件的外链地址 +// Source 获取可直接访问文件的外链地址 func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) { // 查找文件记录 err := fs.resetFileIDIfNotExist(ctx, fileID) diff --git a/pkg/serializer/share.go b/pkg/serializer/share.go index f2b6ef3..ddf4621 100644 --- a/pkg/serializer/share.go +++ b/pkg/serializer/share.go @@ -34,7 +34,7 @@ type shareSource struct { // BuildShareResponse 构建获取分享信息响应 func BuildShareResponse(share *model.Share, unlocked bool) Share { - creator := share.GetCreator() + creator := share.Creator() resp := Share{ Key: hashid.HashID(share.ID, hashid.ShareID), Locked: !unlocked, @@ -62,13 +62,13 @@ func BuildShareResponse(share *model.Share, unlocked bool) Share { } if share.IsDir { - source := share.GetSourceFolder() + source := share.SourceFolder() resp.Source = &shareSource{ Name: source.Name, Size: 0, } } else { - source := share.GetSourceFile() + source := share.SourceFile() resp.Source = &shareSource{ Name: source.Name, Size: source.Size, diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 224f57b..0b6ff50 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -64,7 +64,7 @@ func AnonymousGetContent(c *gin.Context) { } } -// GetSource 获取文件的外链地址 +// Source 获取文件的外链地址 func GetSource(c *gin.Context) { // 创建上下文 ctx, cancel := context.WithCancel(context.Background()) diff --git a/service/share/visit.go b/service/share/visit.go index db1164f..ebef557 100644 --- a/service/share/visit.go +++ b/service/share/visit.go @@ -85,7 +85,7 @@ func (service *Service) CreateDownloadSession(c *gin.Context) serializer.Respons defer fs.Recycle() // 重设文件系统处理目标为源文件 - err = fs.SetTargetByInterface(share.GetSource()) + err = fs.SetTargetByInterface(share.Source()) if err != nil { return serializer.Err(serializer.CodePolicyNotAllowed, "源文件不存在", err) } @@ -115,9 +115,9 @@ func (service *Service) PreviewContent(ctx context.Context, c *gin.Context, isTe // 用于调下层service if share.IsDir { - ctx = context.WithValue(ctx, fsctx.FolderModelCtx, share.GetSource()) + ctx = context.WithValue(ctx, fsctx.FolderModelCtx, share.Source()) } else { - ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.GetSource()) + ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.Source()) } subService := explorer.SingleFileService{ Path: service.Path, @@ -134,9 +134,9 @@ func (service *Service) CreateDocPreviewSession(c *gin.Context) serializer.Respo // 用于调下层service ctx := context.Background() if share.IsDir { - ctx = context.WithValue(ctx, fsctx.FolderModelCtx, share.GetSource()) + ctx = context.WithValue(ctx, fsctx.FolderModelCtx, share.Source()) } else { - ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.GetSource()) + ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.Source()) } subService := explorer.SingleFileService{ Path: service.Path, @@ -165,7 +165,7 @@ func (service *Service) SaveToMyFile(c *gin.Context) serializer.Response { defer fs.Recycle() // 重设文件系统处理目标为源文件 - err = fs.SetTargetByInterface(share.GetSource()) + err = fs.SetTargetByInterface(share.Source()) if err != nil { return serializer.Err(serializer.CodePolicyNotAllowed, "源文件不存在", err) } @@ -192,7 +192,7 @@ func (service *Service) List(c *gin.Context) serializer.Response { } // 创建文件系统 - fs, err := filesystem.NewFileSystem(share.GetCreator()) + fs, err := filesystem.NewFileSystem(share.Creator()) if err != nil { return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) } @@ -203,7 +203,7 @@ func (service *Service) List(c *gin.Context) serializer.Response { defer cancel() // 重设根目录 - fs.Root = share.GetSource().(*model.Folder) + fs.Root = share.Source().(*model.Folder) fs.Root.Name = "/" // 分享Key上下文 @@ -231,14 +231,14 @@ func (service *Service) Thumb(c *gin.Context) serializer.Response { } // 创建文件系统 - fs, err := filesystem.NewFileSystem(share.GetCreator()) + fs, err := filesystem.NewFileSystem(share.Creator()) if err != nil { return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) } defer fs.Recycle() // 重设根目录 - fs.Root = share.GetSource().(*model.Folder) + fs.Root = share.Source().(*model.Folder) // 找到缩略图的父目录 exist, parent := fs.IsPathExist(service.Path) @@ -297,7 +297,7 @@ func (service *ArchiveService) Archive(c *gin.Context) serializer.Response { defer fs.Recycle() // 重设根目录 - fs.Root = share.GetSource().(*model.Folder) + fs.Root = share.Source().(*model.Folder) // 找到要打包文件的父目录 exist, parent := fs.IsPathExist(service.Path) @@ -309,7 +309,7 @@ func (service *ArchiveService) Archive(c *gin.Context) serializer.Response { ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, parent) // 用于调下层service - tempUser := share.GetCreator() + tempUser := share.Creator() tempUser.Group.OptionsSerialized.ArchiveDownloadEnabled = true c.Set("user", tempUser)