diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index d6d0aa1..026448c 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -3,6 +3,7 @@ package filesystem import ( "context" "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/filesystem/response" "github.com/gin-gonic/gin" @@ -90,12 +91,15 @@ func NewAnonymousFileSystem() (*FileSystem, error) { User: &model.User{}, } - anonymousGroup, err := model.GetGroupByID(3) - if err != nil { - return nil, err + // 如果是主机模式下,则为匿名文件系统分配游客用户组 + if conf.SystemConfig.Mode == "master" { + anonymousGroup, err := model.GetGroupByID(3) + if err != nil { + return nil, err + } + fs.User.Group = anonymousGroup } - fs.User.Group = anonymousGroup return fs, nil } diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go index 6718bd3..43bda92 100644 --- a/pkg/filesystem/fsctx/context.go +++ b/pkg/filesystem/fsctx/context.go @@ -15,4 +15,6 @@ const ( FileModelCtx // HTTPCtx HTTP请求的上下文 HTTPCtx + // UploadPolicyCtx 上传策略,一般为slave模式下使用 + UploadPolicyCtx ) diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index 16c607b..8158f47 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -6,6 +6,7 @@ import ( model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" + "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "io/ioutil" "strings" @@ -52,6 +53,30 @@ func HookIsFileExist(ctx context.Context, fs *FileSystem) error { return ErrObjectNotExist } +// HookSlaveUploadValidate Slave模式下对文件上传的一系列验证 +// TODO 测试 +func HookSlaveUploadValidate(ctx context.Context, fs *FileSystem) error { + file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) + policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) + + // 验证单文件尺寸 + if file.GetSize() > policy.MaxSize { + return ErrFileSizeTooBig + } + + // 验证文件名 + if !fs.ValidateLegalName(ctx, file.GetFileName()) { + return ErrIllegalObjectName + } + + // 验证扩展名 + if len(policy.AllowedExtension) > 0 && !IsInExtensionList(policy.AllowedExtension, file.GetFileName()) { + return ErrFileExtensionNotAllowed + } + + return nil +} + // HookValidateFile 一系列对文件检验的集合 func HookValidateFile(ctx context.Context, fs *FileSystem) error { file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 65b7cd1..231c790 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -6,8 +6,10 @@ import ( "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/local" + "github.com/HFO4/cloudreve/pkg/serializer" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" @@ -477,3 +479,64 @@ func TestGenericAfterUpdate(t *testing.T) { asserts.Error(err) } } + +func TestHookSlaveUploadValidate(t *testing.T) { + asserts := assert.New(t) + conf.SystemConfig.Mode = "slave" + fs, err := NewAnonymousFileSystem() + conf.SystemConfig.Mode = "master" + asserts.NoError(err) + + // 正常 + { + policy := serializer.UploadPolicy{ + SavePath: "", + MaxSize: 10, + AllowedExtension: nil, + } + file := local.FileStream{Name: "1.txt", Size: 10} + ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) + ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) + asserts.NoError(HookSlaveUploadValidate(ctx, fs)) + } + + // 尺寸太大 + { + policy := serializer.UploadPolicy{ + SavePath: "", + MaxSize: 10, + AllowedExtension: nil, + } + file := local.FileStream{Name: "1.txt", Size: 11} + ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) + ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) + asserts.Equal(ErrFileSizeTooBig, HookSlaveUploadValidate(ctx, fs)) + } + + // 文件名非法 + { + policy := serializer.UploadPolicy{ + SavePath: "", + MaxSize: 10, + AllowedExtension: nil, + } + file := local.FileStream{Name: "/1.txt", Size: 10} + ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) + ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) + asserts.Equal(ErrIllegalObjectName, HookSlaveUploadValidate(ctx, fs)) + } + + // 扩展名非法 + { + policy := serializer.UploadPolicy{ + SavePath: "", + MaxSize: 10, + AllowedExtension: []string{"jpg"}, + } + file := local.FileStream{Name: "1.txt", Size: 10} + ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) + ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) + asserts.Equal(ErrFileExtensionNotAllowed, HookSlaveUploadValidate(ctx, fs)) + } + +} diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 05793db..0cce03f 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -24,14 +24,15 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { return err } - // 生成文件名和路径, 如果是更新操作就从原始文件获取 + // 生成文件名和路径, var savePath string - originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if ok { + // 如果是更新操作就从上下文中获取 + if originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { savePath = originFile.SourceName } else { savePath = fs.GenerateSavePath(ctx, file) } + ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath) // 处理客户端未完成上传时,关闭连接 go fs.CancelUpload(ctx, savePath, file) @@ -43,7 +44,6 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { } // 上传完成后的钩子 - ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath) err = fs.Trigger(ctx, fs.AfterUpload) if err != nil { @@ -57,21 +57,42 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { return err } - util.Log().Info("新文件PUT:%s , 大小:%d, 上传者:%s", file.GetFileName(), file.GetSize(), fs.User.Nick) + util.Log().Info( + "新文件PUT:%s , 大小:%d, 上传者:%s", + file.GetFileName(), + file.GetSize(), + fs.User.Nick, + ) return nil } // GenerateSavePath 生成要存放文件的路径 +// TODO 完善测试 func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) string { + if fs.User.Model.ID != 0 { + return filepath.Join( + fs.User.Policy.GeneratePath( + fs.User.Model.ID, + file.GetVirtualPath(), + ), + fs.User.Policy.GenerateFileName( + fs.User.Model.ID, + file.GetFileName(), + ), + ) + } + + // 匿名文件系统使用空上传策略生成路径 + nilPolicy := model.Policy{} return filepath.Join( - fs.User.Policy.GeneratePath( - fs.User.Model.ID, - file.GetVirtualPath(), + nilPolicy.GeneratePath( + 0, + "", ), - fs.User.Policy.GenerateFileName( - fs.User.Model.ID, - file.GetFileName(), + nilPolicy.GenerateFileName( + 0, + "", ), ) } diff --git a/pkg/serializer/file.go b/pkg/serializer/file.go index 1103778..d20cbf2 100644 --- a/pkg/serializer/file.go +++ b/pkg/serializer/file.go @@ -1,10 +1,33 @@ package serializer +import ( + "encoding/base64" + "encoding/json" +) + // UploadPolicy slave模式下传递的上传策略 type UploadPolicy struct { SavePath string `json:"save_path"` - MaxSize uint64 `json:"save_path"` + MaxSize uint64 `json:"max_size"` AllowedExtension []string `json:"allowed_extension"` CallbackURL string `json:"callback_url"` CallbackKey string `json:"callback_key"` } + +// DecodeUploadPolicy 反序列化Header中携带的上传策略 +// TODO 测试 +func DecodeUploadPolicy(raw string) (*UploadPolicy, error) { + var res UploadPolicy + + rawJSON, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, err + } + + err = json.Unmarshal(rawJSON, &res) + if err != nil { + return nil, err + } + + return &res, err +} diff --git a/pkg/serializer/file_test.go b/pkg/serializer/file_test.go new file mode 100644 index 0000000..2068e7a --- /dev/null +++ b/pkg/serializer/file_test.go @@ -0,0 +1,55 @@ +package serializer + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDecodeUploadPolicy(t *testing.T) { + asserts := assert.New(t) + + testCases := []struct { + input string + expectError bool + expectNil bool + expectRes *UploadPolicy + }{ + { + "错误的base64字符", + true, + true, + &UploadPolicy{}, + }, + { + "6ZSZ6K+v55qESlNPTuWtl+espg==", + true, + true, + &UploadPolicy{}, + }, + { + "e30=", + false, + false, + &UploadPolicy{}, + }, + { + "eyJjYWxsYmFja19rZXkiOiJ0ZXN0In0=", + false, + false, + &UploadPolicy{CallbackKey: "test"}, + }, + } + + for _, testCase := range testCases { + res, err := DecodeUploadPolicy(testCase.input) + if testCase.expectError { + asserts.Error(err) + } + if testCase.expectNil { + asserts.Nil(res) + } + if !testCase.expectNil { + asserts.Equal(testCase.expectRes, res) + } + } +} diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 524a686..d4ede24 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -1,12 +1,75 @@ package controllers import ( + "context" + "github.com/HFO4/cloudreve/pkg/filesystem" + "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" + "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/gin-gonic/gin" + "net/url" + "strconv" ) // SlaveUpload 从机文件上传 func SlaveUpload(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, fsctx.GinCtx, c) + defer cancel() + + // 创建匿名文件系统 + fs, err := filesystem.NewAnonymousFileSystem() + if err != nil { + c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)) + return + } + + // 从请求中取得上传策略 + uploadPolicyRaw := c.GetHeader("X-Policy") + if uploadPolicyRaw == "" { + c.JSON(200, serializer.ParamErr("未指定上传策略", nil)) + } + + // 解析上传策略 + uploadPolicy, err := serializer.DecodeUploadPolicy(uploadPolicyRaw) + if err != nil { + c.JSON(200, serializer.ParamErr("上传策略格式有误", err)) + } + ctx = context.WithValue(ctx, fsctx.UploadPolicyCtx, uploadPolicy) + + // 取得文件大小 + fileSize, err := strconv.ParseUint(c.Request.Header.Get("Content-Length"), 10, 64) + if err != nil { + c.JSON(200, ErrorResponse(err)) + return + } + + // 解码文件名和路径 + fileName, err := url.QueryUnescape(c.Request.Header.Get("X-FileName")) + if err != nil { + c.JSON(200, ErrorResponse(err)) + return + } + + fileData := local.FileStream{ + MIMEType: c.Request.Header.Get("Content-Type"), + File: c.Request.Body, + Name: fileName, + Size: fileSize, + } + + // 给文件系统分配钩子 + fs.Use("BeforeUpload", filesystem.HookSlaveUploadValidate) + fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) + fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) + + // 执行上传 + err = fs.Upload(ctx, fileData) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err)) + return + } c.JSON(200, serializer.Response{ Code: 0,