diff --git a/models/user.go b/models/user.go index 5a49002..2b704fe 100644 --- a/models/user.go +++ b/models/user.go @@ -54,8 +54,18 @@ type UserOption struct { WebDAVKey string `json:"webdav_key"` } -// DeductionCapacity 扣除用户容量配额 -func (user *User) DeductionCapacity(size uint64) bool { +// DeductionStorage 减少用户已用容量 +func (user *User) DeductionStorage(size uint64) bool { + if size <= user.Storage { + user.Storage -= size + DB.Save(user) + return true + } + return false +} + +// IncreaseStorage 检查并增加用户已用容量 +func (user *User) IncreaseStorage(size uint64) bool { if size <= user.GetRemainingCapacity() { user.Storage += size DB.Save(user) diff --git a/models/user_test.go b/models/user_test.go index 7346314..d769662 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -204,15 +204,15 @@ func TestUser_DeductionCapacity(t *testing.T) { asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(false, newUser.DeductionCapacity(101)) + asserts.Equal(false, newUser.IncreaseStorage(101)) asserts.Equal(uint64(0), newUser.Storage) - asserts.Equal(true, newUser.DeductionCapacity(1)) + asserts.Equal(true, newUser.IncreaseStorage(1)) asserts.Equal(uint64(1), newUser.Storage) - asserts.Equal(true, newUser.DeductionCapacity(99)) + asserts.Equal(true, newUser.IncreaseStorage(99)) asserts.Equal(uint64(100), newUser.Storage) - asserts.Equal(false, newUser.DeductionCapacity(1)) + asserts.Equal(false, newUser.IncreaseStorage(1)) asserts.Equal(uint64(100), newUser.Storage) } diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index a6e3f7e..eef5e53 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem/local" + "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "io" "path/filepath" @@ -21,7 +22,10 @@ type FileData interface { // Handler 存储策略适配器 type Handler interface { + // 上传文件 Put(ctx context.Context, file io.ReadCloser, dst string) error + // 删除一个或多个文件 + Delete(ctx context.Context, files []string) ([]string, error) } // FileSystem 管理文件的文件系统 @@ -39,7 +43,9 @@ type FileSystem struct { // 上传文件后 AfterUpload func(ctx context.Context, fs *FileSystem) error // 文件保存成功,插入数据库验证失败后 - ValidateFailed func(ctx context.Context, fs *FileSystem) error + AfterValidateFailed func(ctx context.Context, fs *FileSystem) error + // 用户取消上传后 + AfterUploadCanceled func(ctx context.Context, fs *FileSystem, file FileData) error /* 文件系统处理适配器 @@ -73,9 +79,11 @@ func NewFileSystem(user *model.User) (*FileSystem, error) { // Upload 上传文件 func (fs *FileSystem) Upload(ctx context.Context, file FileData) (err error) { // 上传前的钩子 - err = fs.BeforeUpload(ctx, fs, file) - if err != nil { - return err + if fs.BeforeUpload != nil { + err = fs.BeforeUpload(ctx, fs, file) + if err != nil { + return err + } } // 生成文件名和路径 @@ -106,10 +114,17 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileDa ginCtx := ctx.Value("ginCtx").(*gin.Context) select { case <-ctx.Done(): + fmt.Println("正常关闭") // 客户端正常关闭,不执行操作 case <-ginCtx.Request.Context().Done(): - // 客户端取消了上传,删除保存的文件 - fmt.Println("取消上传") - // 归还空间 + // 客户端取消了上传 + if fs.AfterUploadCanceled == nil { + return + } + ctx = context.WithValue(ctx, "path", path) + err := fs.AfterUploadCanceled(ctx, fs, file) + if err != nil { + util.Log().Warning("执行 AfterUploadCanceled 钩子出错,%s", err) + } } } diff --git a/pkg/filesystem/hook.go b/pkg/filesystem/hook.go index 854bc48..b87c709 100644 --- a/pkg/filesystem/hook.go +++ b/pkg/filesystem/hook.go @@ -2,6 +2,7 @@ package filesystem import ( "context" + "errors" ) // GenericBeforeUpload 通用上传前处理钩子,包含数据库操作 @@ -27,3 +28,19 @@ func GenericBeforeUpload(ctx context.Context, fs *FileSystem, file FileData) err } return nil } + +// GenericAfterUploadCanceled 通用上传取消处理钩子,包含数据库操作 +func GenericAfterUploadCanceled(ctx context.Context, fs *FileSystem, file FileData) error { + filePath := ctx.Value("path").(string) + // 删除临时文件 + _, err := fs.Handler.Delete(ctx, []string{filePath}) + if err != nil { + return err + } + + // 归还用户容量 + if !fs.User.DeductionStorage(file.GetSize()) { + return errors.New("无法继续降低用户已用存储") + } + return nil +} diff --git a/pkg/filesystem/hook_test.go b/pkg/filesystem/hook_test.go index 3d22e43..528820c 100644 --- a/pkg/filesystem/hook_test.go +++ b/pkg/filesystem/hook_test.go @@ -36,4 +36,6 @@ func TestGenericBeforeUpload(t *testing.T) { asserts.Error(GenericBeforeUpload(ctx, &fs, file)) file.Name = "1.txt" asserts.NoError(GenericBeforeUpload(ctx, &fs, file)) + file.Name = "1.t/xt" + asserts.Error(GenericBeforeUpload(ctx, &fs, file)) } diff --git a/pkg/filesystem/local/handler.go b/pkg/filesystem/local/handler.go index d2ce283..c8e7110 100644 --- a/pkg/filesystem/local/handler.go +++ b/pkg/filesystem/local/handler.go @@ -20,6 +20,7 @@ func (handler Handler) Put(ctx context.Context, file io.ReadCloser, dst string) if !util.Exists(basePath) { err := os.MkdirAll(basePath, 0700) if err != nil { + util.Log().Warning("无法创建目录,%s", err) return err } } @@ -27,6 +28,7 @@ func (handler Handler) Put(ctx context.Context, file io.ReadCloser, dst string) // 创建目标文件 out, err := os.Create(dst) if err != nil { + util.Log().Warning("无法创建文件,%s", err) return err } defer out.Close() @@ -35,3 +37,22 @@ func (handler Handler) Put(ctx context.Context, file io.ReadCloser, dst string) _, err = io.Copy(out, file) return err } + +// Delete 删除一个或多个文件, +// 返回已删除的文件,及遇到的最后一个错误 +func (handler Handler) Delete(ctx context.Context, files []string) ([]string, error) { + deleted := make([]string, 0, len(files)) + var retErr error + + for _, value := range files { + err := os.Remove(value) + if err == nil { + deleted = append(deleted, value) + util.Log().Warning("无法删除文件,%s", err) + } else { + retErr = err + } + } + + return deleted, retErr +} diff --git a/pkg/filesystem/validator.go b/pkg/filesystem/validator.go index 2b0d129..96fb525 100644 --- a/pkg/filesystem/validator.go +++ b/pkg/filesystem/validator.go @@ -27,7 +27,7 @@ func (fs *FileSystem) ValidateFileSize(ctx context.Context, size uint64) bool { // ValidateCapacity 验证并扣除用户容量 func (fs *FileSystem) ValidateCapacity(ctx context.Context, size uint64) bool { - if fs.User.DeductionCapacity(size) { + if fs.User.IncreaseStorage(size) { return true } return false diff --git a/pkg/filesystem/validator_test.go b/pkg/filesystem/validator_test.go index 6d68f07..015613f 100644 --- a/pkg/filesystem/validator_test.go +++ b/pkg/filesystem/validator_test.go @@ -25,6 +25,19 @@ func TestMain(m *testing.M) { m.Run() } +func TestFileSystem_ValidateLegalName(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{} + asserts.True(fs.ValidateLegalName(ctx, "1.txt")) + asserts.True(fs.ValidateLegalName(ctx, "1-1.txt")) + asserts.True(fs.ValidateLegalName(ctx, "1?1.txt")) + asserts.False(fs.ValidateLegalName(ctx, "1:1.txt")) + asserts.False(fs.ValidateLegalName(ctx, "../11.txt")) + asserts.False(fs.ValidateLegalName(ctx, "/11.txt")) + asserts.False(fs.ValidateLegalName(ctx, "\\11.txt")) +} + func TestFileSystem_ValidateCapacity(t *testing.T) { asserts := assert.New(t) ctx := context.Background() diff --git a/routers/controllers/file.go b/routers/controllers/file.go index d28ede0..aa19c6d 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -70,6 +70,7 @@ func FileUploadStream(c *gin.Context) { // 给文件系统分配钩子 fs.BeforeUpload = filesystem.GenericBeforeUpload + fs.AfterUploadCanceled = filesystem.GenericAfterUploadCanceled // 执行上传 uploadCtx := context.WithValue(ctx, "ginCtx", c)