From 68d4a86166a974bf55d4ddbac9abfbef4e2cba37 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Fri, 17 Jan 2020 14:35:21 +0800 Subject: [PATCH] Fix: storage policy should be re-dispatched according to policy id in upload session --- middleware/auth.go | 5 ----- middleware/auth_test.go | 30 ------------------------------ pkg/filesystem/archive.go | 2 +- pkg/filesystem/file.go | 6 +++--- pkg/filesystem/filesystem.go | 6 +++--- pkg/filesystem/filesystem_test.go | 4 ++-- pkg/filesystem/hooks.go | 2 +- service/callback/upload.go | 12 ++++++++++++ 8 files changed, 22 insertions(+), 45 deletions(-) diff --git a/middleware/auth.go b/middleware/auth.go index 79788a2..bdceac1 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -132,11 +132,6 @@ func uploadCallbackCheck(c *gin.Context) (serializer.Response, *model.User) { } c.Set("user", &user) - // 检查存储策略是否一致 - if user.GetPolicyID() != callbackSession.PolicyID { - return serializer.Err(serializer.CodePolicyNotAllowed, "存储策略已变更,请重新上传", nil), nil - } - return serializer.Response{}, &user } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index ca595b9..196795b 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -277,36 +277,6 @@ func TestRemoteCallbackAuth(t *testing.T) { asserts.True(c.IsAborted()) } - // 存储策略不一致 - { - cache.Set( - "callback_testCallBackRemote", - serializer.UploadSession{ - UID: 1, - PolicyID: 2, - VirtualPath: "/", - }, - 0, - ) - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[3]")) - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(3, "123")) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"key", "testCallBackRemote"}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(c.IsAborted()) - } - // 签名错误 { cache.Set( diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index df746ed..e988c70 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -105,7 +105,7 @@ func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder * if file != nil { // 切换上传策略 fs.Policy = file.GetPolicy() - err := fs.dispatchHandler() + err := fs.DispatchHandler() if err != nil { util.Log().Warning("无法压缩文件%s,%s", file.Name, err) return diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index cc7fed0..6eebc02 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -79,7 +79,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) { // 重设上传策略 fs.Policy = &model.Policy{Type: "local"} - _ = fs.dispatchHandler() + _ = fs.DispatchHandler() // 获取文件流 rs, err := fs.Handler.Get(ctx, path) @@ -184,7 +184,7 @@ func (fs *FileSystem) deleteGroupedFile(ctx context.Context, files map[uint][]*m // 切换上传策略 fs.Policy = toBeDeletedFiles[0].GetPolicy() - err := fs.dispatchHandler() + err := fs.DispatchHandler() if err != nil { failed[policyID] = sourceNames continue @@ -327,7 +327,7 @@ func (fs *FileSystem) resetPolicyToFirstFile(ctx context.Context) error { } fs.Policy = fs.FileTarget[0].GetPolicy() - err := fs.dispatchHandler() + err := fs.DispatchHandler() if err != nil { return err } diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 7410fd8..cd8c331 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -109,7 +109,7 @@ func NewFileSystem(user *model.User) (*FileSystem, error) { fs := getEmptyFS() fs.User = user // 分配存储策略适配器 - err := fs.dispatchHandler() + err := fs.DispatchHandler() // TODO 分配默认钩子 return fs, err @@ -135,9 +135,9 @@ func NewAnonymousFileSystem() (*FileSystem, error) { return fs, nil } -// dispatchHandler 根据存储策略分配文件适配器 +// DispatchHandler 根据存储策略分配文件适配器 // TODO 完善测试 -func (fs *FileSystem) dispatchHandler() error { +func (fs *FileSystem) DispatchHandler() error { var policyType string var currentPolicy *model.Policy diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go index 2dc91d0..94404b0 100644 --- a/pkg/filesystem/filesystem_test.go +++ b/pkg/filesystem/filesystem_test.go @@ -64,13 +64,13 @@ func TestDispatchHandler(t *testing.T) { } // 未指定,使用用户默认 - err := fs.dispatchHandler() + err := fs.DispatchHandler() asserts.NoError(err) asserts.IsType(local.Driver{}, fs.Handler) // 已指定,发生错误 fs.Policy = &model.Policy{Type: "unknown"} - err = fs.dispatchHandler() + err = fs.DispatchHandler() asserts.Error(err) } diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index aab2355..15c9d20 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -106,7 +106,7 @@ func HookResetPolicy(ctx context.Context, fs *FileSystem) error { } fs.Policy = originFile.GetPolicy() - return fs.dispatchHandler() + return fs.DispatchHandler() } // HookValidateCapacity 验证并扣除用户容量,包含数据库操作 diff --git a/service/callback/upload.go b/service/callback/upload.go index eb15598..90dc110 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -2,6 +2,7 @@ package callback import ( "context" + model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/local" @@ -61,6 +62,17 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer. } callbackSession := callbackSessionRaw.(*serializer.UploadSession) + // 重新指向上传策略 + policy, err := model.GetPolicyByID(callbackSession.PolicyID) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + fs.Policy = &policy + err = fs.DispatchHandler() + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + // 获取父目录 exist, parentFolder := fs.IsPathExist(callbackSession.VirtualPath) if !exist {