diff --git a/middleware/auth.go b/middleware/auth.go index 82533a3..457c974 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -3,6 +3,7 @@ package middleware import ( "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/auth" + "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -25,6 +26,7 @@ func SignRequired() gin.HandlerFunc { if err != nil { c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err)) c.Abort() + return } c.Next() } @@ -103,3 +105,55 @@ func WebDAVAuth() gin.HandlerFunc { c.Next() } } + +// RemoteCallbackAuth 远程回调签名验证 +// TODO 测试 +func RemoteCallbackAuth() gin.HandlerFunc { + return func(c *gin.Context) { + // 验证 Callback Key + callbackKey := c.Param("key") + if callbackKey == "" { + c.JSON(200, serializer.ParamErr("Callback Key 不能为空", nil)) + c.Abort() + return + } + callbackSessionRaw, exist := cache.Get("callback_" + callbackKey) + if !exist { + c.JSON(200, serializer.ParamErr("回调会话不存在或已过期", nil)) + c.Abort() + return + } + callbackSession := callbackSessionRaw.(serializer.UploadSession) + c.Set("callbackSession", &callbackSession) + + // 清理回调会话 + _ = cache.Deletes([]string{callbackKey}, "callback_") + + // 查找用户 + user, err := model.GetUserByID(callbackSession.UID) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeCheckLogin, "找不到用户", err)) + c.Abort() + return + } + c.Set("user", &user) + + // 检查存储策略是否一致 + if user.GetPolicyID() != callbackSession.PolicyID { + c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, "存储策略已变更,请重新上传", nil)) + c.Abort() + return + } + + // 验证签名 + authInstance := auth.HMACAuth{SecretKey: []byte(user.Policy.SecretKey)} + if err := auth.CheckRequest(authInstance, c.Request); err != nil { + c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err)) + c.Abort() + return + } + + c.Next() + + } +} diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 7ce2d26..06d6b39 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -94,7 +94,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (resp // GetContent 获取文件内容,path为虚拟路径 func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) { // 触发`下载前`钩子 - err := fs.Trigger(ctx, fs.BeforeFileDownload) + err := fs.Trigger(ctx, "BeforeFileDownload") if err != nil { util.Log().Debug("BeforeFileDownload 钩子执行失败,%s", err) return nil, err diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 8ef232a..e644d34 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -68,16 +68,7 @@ type FileSystem struct { /* 钩子函数 */ - // 上传文件前 - BeforeUpload []Hook - // 上传文件后 - AfterUpload []Hook - // 文件保存成功,插入数据库验证失败后 - AfterValidateFailed []Hook - // 用户取消上传后 - AfterUploadCanceled []Hook - // 文件下载前 - BeforeFileDownload []Hook + Hooks map[string][]Hook /* 文件系统处理适配器 @@ -102,11 +93,7 @@ func (fs *FileSystem) reset() { fs.User = nil fs.CleanTargets() fs.Policy = nil - fs.BeforeUpload = fs.BeforeUpload[:0] - fs.AfterUpload = fs.AfterUpload[:0] - fs.AfterValidateFailed = fs.AfterValidateFailed[:0] - fs.AfterUploadCanceled = fs.AfterUploadCanceled[:0] - fs.BeforeFileDownload = fs.BeforeFileDownload[:0] + fs.Hooks = nil fs.Handler = nil } diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go index 6ed6afb..4aaa26b 100644 --- a/pkg/filesystem/filesystem_test.go +++ b/pkg/filesystem/filesystem_test.go @@ -136,11 +136,11 @@ func TestNewAnonymousFileSystem(t *testing.T) { func TestFileSystem_Recycle(t *testing.T) { fs := &FileSystem{ - User: &model.User{}, - Policy: &model.Policy{}, - FileTarget: []model.File{model.File{}}, - DirTarget: []model.Folder{model.Folder{}}, - AfterUpload: []Hook{GenericAfterUpdate}, + User: &model.User{}, + Policy: &model.Policy{}, + FileTarget: []model.File{model.File{}}, + DirTarget: []model.Folder{model.Folder{}}, + Hooks: map[string][]Hook{"AfterUpload": []Hook{GenericAfterUpdate}}, } fs.Recycle() newFS := getEmptyFS() diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index e666471..2acfbc4 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -18,28 +18,26 @@ type Hook func(ctx context.Context, fs *FileSystem) error // Use 注入钩子 func (fs *FileSystem) Use(name string, hook Hook) { - switch name { - case "BeforeUpload": - fs.BeforeUpload = append(fs.BeforeUpload, hook) - case "AfterUpload": - fs.AfterUpload = append(fs.AfterUpload, hook) - case "AfterValidateFailed": - fs.AfterValidateFailed = append(fs.AfterValidateFailed, hook) - case "AfterUploadCanceled": - fs.AfterUploadCanceled = append(fs.AfterUploadCanceled, hook) - case "BeforeFileDownload": - fs.BeforeFileDownload = append(fs.BeforeFileDownload, hook) + if fs.Hooks == nil { + fs.Hooks = make(map[string][]Hook) } + if _, ok := fs.Hooks[name]; ok { + fs.Hooks[name] = append(fs.Hooks[name], hook) + return + } + fs.Hooks[name] = []Hook{hook} } // Trigger 触发钩子,遇到第一个错误时 // 返回错误,后续钩子不会继续执行 -func (fs *FileSystem) Trigger(ctx context.Context, hooks []Hook) error { - for _, hook := range hooks { - err := hook(ctx, fs) - if err != nil { - util.Log().Warning("钩子执行失败:%s", err) - return err +func (fs *FileSystem) Trigger(ctx context.Context, name string) error { + if hooks, ok := fs.Hooks[name]; ok { + for _, hook := range hooks { + err := hook(ctx, fs) + if err != nil { + util.Log().Warning("钩子执行失败:%s", err) + return err + } } } return nil @@ -223,7 +221,7 @@ func SlaveAfterUpload(ctx context.Context, fs *FileSystem) error { fs.GenerateThumbnail(ctx, &file) // 发送回调请求 - callbackBody := serializer.UploadCallback{ + callbackBody := serializer.RemoteUploadCallback{ Name: file.Name, SourceName: file.SourceName, PicInfo: file.PicInfo, diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index ece638e..3533284 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -183,11 +183,11 @@ func TestFileSystem_Use(t *testing.T) { // 添加一个 fs.Use("BeforeUpload", hook) - asserts.Len(fs.BeforeUpload, 1) + asserts.Len(fs.Hooks["BeforeUpload"], 1) // 添加一个 fs.Use("BeforeUpload", hook) - asserts.Len(fs.BeforeUpload, 2) + asserts.Len(fs.Hooks["BeforeUpload"], 2) // 不存在 fs.Use("BeforeUpload2333", hook) @@ -219,14 +219,14 @@ func TestFileSystem_Trigger(t *testing.T) { // 一个 fs.Use("BeforeUpload", hook) - err := fs.Trigger(ctx, fs.BeforeUpload) + err := fs.Trigger(ctx, "BeforeUpload") asserts.NoError(err) asserts.Equal(uint64(1), fs.User.Storage) // 多个 fs.Use("BeforeUpload", hook) fs.Use("BeforeUpload", hook) - err = fs.Trigger(ctx, fs.BeforeUpload) + err = fs.Trigger(ctx, "BeforeUpload") asserts.NoError(err) asserts.Equal(uint64(4), fs.User.Storage) } diff --git a/pkg/filesystem/remote/handler.go b/pkg/filesystem/remote/handler.go index 823636c..216db83 100644 --- a/pkg/filesystem/remote/handler.go +++ b/pkg/filesystem/remote/handler.go @@ -57,7 +57,7 @@ func (handler Handler) Source( func (handler Handler) Token(ctx context.Context, TTL int64, key string) (serializer.UploadCredential, error) { // 生成回调地址 siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/callback/upload/" + key) + apiBaseURI, _ := url.Parse("/api/v3/callback/remote/" + key) apiURL := siteURL.ResolveReference(apiBaseURI) // 生成上传策略 diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 5a4b720..5f8d16d 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -22,7 +22,7 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) // 上传前的钩子 - err = fs.Trigger(ctx, fs.BeforeUpload) + err = fs.Trigger(ctx, "BeforeUpload") if err != nil { return err } @@ -47,11 +47,11 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { } // 上传完成后的钩子 - err = fs.Trigger(ctx, fs.AfterUpload) + err = fs.Trigger(ctx, "AfterUpload") if err != nil { // 上传完成后续处理失败 - followUpErr := fs.Trigger(ctx, fs.AfterValidateFailed) + followUpErr := fs.Trigger(ctx, "AfterValidateFailed") // 失败后再失败... if followUpErr != nil { util.Log().Debug("AfterValidateFailed 钩子执行失败,%s", followUpErr) @@ -125,11 +125,11 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHe default: // 客户端取消上传,删除临时文件 util.Log().Debug("客户端取消上传") - if fs.AfterUploadCanceled == nil { + if fs.Hooks["AfterUploadCanceled"] == nil { return } ctx = context.WithValue(ctx, fsctx.SavePathCtx, path) - err := fs.Trigger(ctx, fs.AfterUploadCanceled) + err := fs.Trigger(ctx, "AfterUploadCanceled") if err != nil { util.Log().Debug("执行 AfterUploadCanceled 钩子出错,%s", err) } @@ -176,6 +176,7 @@ func (fs *FileSystem) GetUploadToken(ctx context.Context, path string, size uint "callback_"+callbackKey, serializer.UploadSession{ UID: fs.User.ID, + PolicyID: fs.User.GetPolicyID(), VirtualPath: path, }, int(callBackSessionTTL), diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go index 21081e0..70e0902 100644 --- a/pkg/filesystem/upload_test.go +++ b/pkg/filesystem/upload_test.go @@ -120,7 +120,7 @@ func TestFileSystem_Upload(t *testing.T) { }) err = fs.Upload(ctx, file) asserts.Error(err) - fs.BeforeUpload = nil + fs.Hooks["BeforeUpload"] = nil testHandller.AssertExpectations(t) // 上传文件失败 diff --git a/pkg/request/callback.go b/pkg/request/callback.go index d438832..6458881 100644 --- a/pkg/request/callback.go +++ b/pkg/request/callback.go @@ -13,8 +13,12 @@ import ( ) // RemoteCallback 发送远程存储策略上传回调请求 -func RemoteCallback(url string, body serializer.UploadCallback) error { - callbackBody, err := json.Marshal(body) +func RemoteCallback(url string, body serializer.RemoteUploadCallback) error { + callbackBody, err := json.Marshal(struct { + Data serializer.RemoteUploadCallback `json:"data"` + }{ + Data: body, + }) if err != nil { return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) } diff --git a/pkg/request/callback_test.go b/pkg/request/callback_test.go index 67e4d99..957a47b 100644 --- a/pkg/request/callback_test.go +++ b/pkg/request/callback_test.go @@ -34,7 +34,7 @@ func TestRemoteCallback(t *testing.T) { }, }) GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + resp := RemoteCallback("http://test/test/url", serializer.RemoteUploadCallback{ SourceName: "source", }) asserts.NoError(resp) @@ -59,7 +59,7 @@ func TestRemoteCallback(t *testing.T) { }, }) GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + resp := RemoteCallback("http://test/test/url", serializer.RemoteUploadCallback{ SourceName: "source", }) asserts.EqualValues(401, resp.(serializer.AppError).Code) @@ -83,7 +83,7 @@ func TestRemoteCallback(t *testing.T) { }, }) GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + resp := RemoteCallback("http://test/test/url", serializer.RemoteUploadCallback{ SourceName: "source", }) asserts.Error(resp) @@ -107,7 +107,7 @@ func TestRemoteCallback(t *testing.T) { }, }) GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + resp := RemoteCallback("http://test/test/url", serializer.RemoteUploadCallback{ SourceName: "source", }) asserts.Error(resp) @@ -127,7 +127,7 @@ func TestRemoteCallback(t *testing.T) { Err: errors.New("error"), }) GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + resp := RemoteCallback("http://test/test/url", serializer.RemoteUploadCallback{ SourceName: "source", }) asserts.Error(resp) diff --git a/pkg/serializer/upload.go b/pkg/serializer/upload.go index b56dc8e..0d430b1 100644 --- a/pkg/serializer/upload.go +++ b/pkg/serializer/upload.go @@ -25,11 +25,12 @@ type UploadCredential struct { // UploadSession 上传会话 type UploadSession struct { UID uint + PolicyID uint VirtualPath string } -// UploadCallback 远程存储策略上传回调正文 -type UploadCallback struct { +// RemoteUploadCallback 远程存储策略上传回调正文 +type RemoteUploadCallback struct { Name string `json:"name"` SourceName string `json:"source_name"` PicInfo string `json:"pic_info"` diff --git a/routers/controllers/callback.go b/routers/controllers/callback.go new file mode 100644 index 0000000..53b5a18 --- /dev/null +++ b/routers/controllers/callback.go @@ -0,0 +1,17 @@ +package controllers + +import ( + "github.com/HFO4/cloudreve/service/callback" + "github.com/gin-gonic/gin" +) + +// RemoteCallback 远程上传回调 +func RemoteCallback(c *gin.Context) { + var callbackBody callback.RemoteUploadCallbackService + if err := c.ShouldBindJSON(&callbackBody); err == nil { + res := callbackBody.Process(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 731eccc..816c0be 100644 --- a/routers/router.go +++ b/routers/router.go @@ -114,6 +114,17 @@ func InitMasterRouter() *gin.Engine { } } + // 回调接口 + callback := v3.Group("callback") + { + // 远程上传回调 + callback.POST( + "remote/:key", + middleware.RemoteCallbackAuth(), + controllers.RemoteCallback, + ) + } + // 需要登录保护的 auth := v3.Group("") auth.Use(middleware.AuthRequired()) diff --git a/service/callback/upload.go b/service/callback/upload.go new file mode 100644 index 0000000..afbbe6c --- /dev/null +++ b/service/callback/upload.go @@ -0,0 +1,26 @@ +package callback + +import ( + "github.com/HFO4/cloudreve/pkg/filesystem" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/gin-gonic/gin" +) + +// RemoteUploadCallbackService 远程存储上传回调请求服务 +type RemoteUploadCallbackService struct { + Data serializer.RemoteUploadCallback `json:"data" binding:"required"` +} + +// Process 处理远程策略上传结果回调 +func (service *RemoteUploadCallbackService) Process(c *gin.Context) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewFileSystemFromContext(c) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + return serializer.Response{ + Code: 0, + } +}