diff --git a/models/user.go b/models/user.go index e4748bc..49ed1b1 100644 --- a/models/user.go +++ b/models/user.go @@ -54,6 +54,24 @@ type UserOption struct { WebDAVKey string `json:"webdav_key"` } +// DeductionCapacity 扣除用户容量配额 +func (user *User) DeductionCapacity(size uint64) bool { + if size <= user.GetRemainingCapacity() { + user.Storage += size + DB.Save(user) + return true + } + return false +} + +// GetRemainingCapacity 获取剩余配额 +func (user *User) GetRemainingCapacity() uint64 { + if user.Group.MaxStorage <= user.Storage { + return 0 + } + return user.Group.MaxStorage - user.Storage +} + // GetPolicyID 获取用户当前的上传策略ID func (user *User) GetPolicyID() uint { // 用户未指定时,返回可用的第一个 diff --git a/models/user_test.go b/models/user_test.go index 986899c..7346314 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -164,3 +164,55 @@ func TestUser_GetPolicyID(t *testing.T) { asserts.Equal(testCase.expected, newUser.GetPolicyID(), "测试用例 #%d 未通过", key) } } + +func TestUser_GetRemainingCapacity(t *testing.T) { + asserts := assert.New(t) + newUser := NewUser() + + newUser.Group.MaxStorage = 100 + asserts.Equal(uint64(100), newUser.GetRemainingCapacity()) + + newUser.Group.MaxStorage = 100 + newUser.Storage = 1 + asserts.Equal(uint64(99), newUser.GetRemainingCapacity()) + + newUser.Group.MaxStorage = 100 + newUser.Storage = 100 + asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) + + newUser.Group.MaxStorage = 100 + newUser.Storage = 200 + asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) +} + +func TestUser_DeductionCapacity(t *testing.T) { + asserts := assert.New(t) + + userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}). + AddRow(1, nil, 0, "{}", 1) + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) + groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). + AddRow(1, "管理员", "[1]") + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) + + policyRows := sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "默认上传策略") + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) + + newUser, err := GetUserByID(1) + newUser.Group.MaxStorage = 100 + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + + asserts.Equal(false, newUser.DeductionCapacity(101)) + asserts.Equal(uint64(0), newUser.Storage) + + asserts.Equal(true, newUser.DeductionCapacity(1)) + asserts.Equal(uint64(1), newUser.Storage) + + asserts.Equal(true, newUser.DeductionCapacity(99)) + asserts.Equal(uint64(100), newUser.Storage) + + asserts.Equal(false, newUser.DeductionCapacity(1)) + asserts.Equal(uint64(100), newUser.Storage) +} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 16debd4..f4226d9 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -11,18 +11,37 @@ type FileData interface { io.Closer GetSize() uint64 GetMIMEType() string + GetFileName() string } // FileSystem 管理文件的文件系统 type FileSystem struct { - // 文件系统所有者 + /* + 文件系统所有者 + */ User *model.User - // 文件系统处理适配器 + /* + 钩子函数 + */ + // 上传文件前 + BeforeUpload func(fs *FileSystem, file FileData) error + // 上传文件后 + AfterUpload func(fs *FileSystem) error + // 文件验证失败后 + ValidateFailed func(fs *FileSystem) error + + /* + 文件系统处理适配器 + */ } // Upload 上传文件 -func (fs *FileSystem) Upload(File FileData) (err error) { +func (fs *FileSystem) Upload(file FileData) (err error) { + err = fs.BeforeUpload(fs, file) + if err != nil { + return err + } return nil } diff --git a/pkg/filesystem/hook.go b/pkg/filesystem/hook.go new file mode 100644 index 0000000..e4e1616 --- /dev/null +++ b/pkg/filesystem/hook.go @@ -0,0 +1,22 @@ +package filesystem + +import "errors" + +// GenericBeforeUpload 通用上传前处理钩子,包含数据库操作 +func GenericBeforeUpload(fs *FileSystem, file FileData) error { + // 验证单文件尺寸 + if !fs.ValidateFileSize(file.GetSize()) { + return errors.New("单个文件尺寸太大") + } + + // 验证并扣除容量 + if !fs.ValidateCapacity(file.GetSize()) { + return errors.New("容量空间不足") + } + + // 验证扩展名 + if !fs.ValidateExtension(file.GetFileName()) { + return errors.New("不允许上传此类型的文件") + } + return nil +} diff --git a/pkg/filesystem/local/file.go b/pkg/filesystem/local/file.go index 38698dc..2f950d0 100644 --- a/pkg/filesystem/local/file.go +++ b/pkg/filesystem/local/file.go @@ -6,6 +6,7 @@ import "mime/multipart" type FileData struct { File multipart.File Size uint64 + Name string MIMEType string } @@ -24,3 +25,7 @@ func (file FileData) GetSize() uint64 { func (file FileData) Close() error { return file.Close() } + +func (file FileData) GetFileName() string { + return file.Name +} diff --git a/pkg/filesystem/validator.go b/pkg/filesystem/validator.go new file mode 100644 index 0000000..cc0ffda --- /dev/null +++ b/pkg/filesystem/validator.go @@ -0,0 +1,40 @@ +package filesystem + +import ( + "cloudreve/pkg/util" + "path/filepath" +) + +// ValidateFileSize 验证上传的文件大小是否超出限制 +func (fs *FileSystem) ValidateFileSize(size uint64) bool { + return size <= fs.User.Policy.MaxSize +} + +// ValidateCapacity 验证并扣除用户容量 +func (fs *FileSystem) ValidateCapacity(size uint64) bool { + if fs.User.DeductionCapacity(size) { + return true + } + return false +} + +// ValidateExtension 验证文件扩展名 +func (fs *FileSystem) ValidateExtension(fileName string) bool { + // 不需要验证 + if len(fs.User.Policy.OptionsSerialized.FileType) == 0 { + return true + } + + ext := filepath.Ext(fileName) + + // 无扩展名时 + if len(ext) == 0 { + return false + } + + if util.ContainsString(fs.User.Policy.OptionsSerialized.FileType, ext[1:]) { + return true + } + + return false +} diff --git a/pkg/serializer/common.go b/pkg/serializer/common.go index 9819472..e3d156f 100644 --- a/pkg/serializer/common.go +++ b/pkg/serializer/common.go @@ -19,6 +19,8 @@ const ( CodeCheckLogin = 401 // CodeNoRightErr 未授权访问 CodeNoRightErr = 403 + // CodeUploadFailed 上传出错 + CodeUploadFailed = 4001 // CodeDBError 数据库操作失败 CodeDBError = 50001 // CodeEncryptError 加密失败 diff --git a/pkg/util/common.go b/pkg/util/common.go index 4f058fc..e136eee 100644 --- a/pkg/util/common.go +++ b/pkg/util/common.go @@ -24,3 +24,13 @@ func ContainsUint(s []uint, e uint) bool { } return false } + +// ContainsString 返回list中是否包含 +func ContainsString(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/service/file/upload.go b/service/file/upload.go index 68403df..f9ce63c 100644 --- a/service/file/upload.go +++ b/service/file/upload.go @@ -28,14 +28,18 @@ func (service *UploadService) Upload(c *gin.Context) serializer.Response { MIMEType: service.File.Header.Get("Content-Type"), File: file, Size: uint64(service.File.Size), + Name: service.Name, } - user, _ := c.Get("user") - fs := filesystem.FileSystem{ - User: user.(*model.User), + BeforeUpload: filesystem.GenericBeforeUpload, + User: user.(*model.User), } + err = fs.Upload(fileData) + if err != nil { + return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) + } return serializer.Response{ Code: 0,