diff --git a/models/policy.go b/models/policy.go index 40de2db..ed4ea4e 100644 --- a/models/policy.go +++ b/models/policy.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" + "path/filepath" "strconv" "time" ) @@ -69,7 +70,7 @@ func (policy *Policy) SerializeOptions() (err error) { } // GeneratePath 生成存储文件的路径 -func (policy *Policy) GeneratePath(uid uint) string { +func (policy *Policy) GeneratePath(uid uint, path string) string { dirRule := policy.DirNameRule replaceTable := map[string]string{ "{randomkey16}": util.RandStringRunes(16), @@ -78,9 +79,10 @@ func (policy *Policy) GeneratePath(uid uint) string { "{uid}": strconv.Itoa(int(uid)), "{datetime}": time.Now().Format("20060102150405"), "{date}": time.Now().Format("20060102"), + "{path}": path + "/", } dirRule = util.Replace(replaceTable, dirRule) - return dirRule + return filepath.Clean(dirRule) } // GenerateFileName 生成存储文件名 diff --git a/models/policy_test.go b/models/policy_test.go index 6a637f9..d6571d8 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -46,25 +46,32 @@ func TestPolicy_GeneratePath(t *testing.T) { testPolicy := Policy{} testPolicy.DirNameRule = "{randomkey16}" - asserts.Len(testPolicy.GeneratePath(1), 16) + asserts.Len(testPolicy.GeneratePath(1, "/"), 16) testPolicy.DirNameRule = "{randomkey8}" - asserts.Len(testPolicy.GeneratePath(1), 8) + asserts.Len(testPolicy.GeneratePath(1, "/"), 8) testPolicy.DirNameRule = "{timestamp}" - asserts.Equal(testPolicy.GeneratePath(1), strconv.FormatInt(time.Now().Unix(), 10)) + asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.FormatInt(time.Now().Unix(), 10)) testPolicy.DirNameRule = "{uid}" - asserts.Equal(testPolicy.GeneratePath(1), strconv.Itoa(int(1))) + asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.Itoa(int(1))) testPolicy.DirNameRule = "{datetime}" - asserts.Len(testPolicy.GeneratePath(1), 14) + asserts.Len(testPolicy.GeneratePath(1, "/"), 14) testPolicy.DirNameRule = "{date}" - asserts.Len(testPolicy.GeneratePath(1), 8) + asserts.Len(testPolicy.GeneratePath(1, "/"), 8) testPolicy.DirNameRule = "123{date}ss{datetime}" - asserts.Len(testPolicy.GeneratePath(1), 27) + asserts.Len(testPolicy.GeneratePath(1, "/"), 27) + + testPolicy.DirNameRule = "/1/{path}/456" + asserts.Condition(func() (success bool) { + res := testPolicy.GeneratePath(1, "/23") + return res == "/1/23/456" || res == "\\1\\23\\456" + }) + } func TestPolicy_GenerateFileName(t *testing.T) { diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 4be41fc..d7ad378 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -10,13 +10,14 @@ import ( "path/filepath" ) -// FileData 上传来的文件数据处理器 -type FileData interface { +// FileHeader 上传来的文件数据处理器 +type FileHeader interface { io.Reader io.Closer GetSize() uint64 GetMIMEType() string GetFileName() string + GetVirtualPath() string } // Handler 存储策略适配器 @@ -77,7 +78,7 @@ func NewFileSystem(user *model.User) (*FileSystem, error) { */ // Upload 上传文件 -func (fs *FileSystem) Upload(ctx context.Context, file FileData) (err error) { +func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { ctx = context.WithValue(ctx, FileCtx, file) // 上传前的钩子 @@ -89,7 +90,7 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileData) (err error) { } // 生成文件名和路径 - savePath := fs.GenerateSavePath(file) + savePath := fs.GenerateSavePath(ctx, file) // 处理客户端未完成上传时,关闭连接 go fs.CancelUpload(ctx, savePath, file) @@ -122,15 +123,21 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileData) (err error) { } // GenerateSavePath 生成要存放文件的路径 -func (fs *FileSystem) GenerateSavePath(file FileData) string { +func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) string { return filepath.Join( - fs.User.Policy.GeneratePath(fs.User.Model.ID), - fs.User.Policy.GenerateFileName(fs.User.Model.ID, file.GetFileName()), + fs.User.Policy.GeneratePath( + fs.User.Model.ID, + file.GetVirtualPath(), + ), + fs.User.Policy.GenerateFileName( + fs.User.Model.ID, + file.GetFileName(), + ), ) } // CancelUpload 监测客户端取消上传 -func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileData) { +func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHeader) { ginCtx := ctx.Value(GinCtx).(*gin.Context) select { case <-ctx.Done(): diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index d957bad..ad561da 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -4,12 +4,11 @@ import ( "context" "errors" "github.com/HFO4/cloudreve/pkg/util" - "github.com/gin-gonic/gin" ) // GenericBeforeUpload 通用上传前处理钩子,包含数据库操作 func GenericBeforeUpload(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(FileCtx).(FileData) + file := ctx.Value(FileCtx).(FileHeader) // 验证单文件尺寸 if !fs.ValidateFileSize(ctx, file.GetSize()) { @@ -35,7 +34,7 @@ func GenericBeforeUpload(ctx context.Context, fs *FileSystem) error { // GenericAfterUploadCanceled 通用上传取消处理钩子,包含数据库操作 func GenericAfterUploadCanceled(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(FileCtx).(FileData) + file := ctx.Value(FileCtx).(FileHeader) filePath := ctx.Value(SavePathCtx).(string) // 删除临时文件 @@ -55,10 +54,8 @@ func GenericAfterUploadCanceled(ctx context.Context, fs *FileSystem) error { // GenericAfterUpload 文件上传完成后,包含数据库操作 func GenericAfterUpload(ctx context.Context, fs *FileSystem) error { - // 获取Gin的上下文 - ginCtx := ctx.Value(GinCtx).(*gin.Context) // 文件存放的虚拟路径 - virtualPath := util.DotPathToStandardPath(ginCtx.GetHeader("X-Path")) + virtualPath := ctx.Value(FileCtx).(FileHeader).GetVirtualPath() // 检查路径是否存在 if !fs.IsPathExist(virtualPath) { diff --git a/pkg/filesystem/local/file.go b/pkg/filesystem/local/file.go index 9ff4d57..c605079 100644 --- a/pkg/filesystem/local/file.go +++ b/pkg/filesystem/local/file.go @@ -33,11 +33,16 @@ func (file FileData) GetFileName() string { return file.Name } +func (file FileData) GetVirtualPath() string { + return file.Name +} + type FileStream struct { - File io.ReadCloser - Size uint64 - Name string - MIMEType string + File io.ReadCloser + Size uint64 + VirtualPath string + Name string + MIMEType string } func (file FileStream) Read(p []byte) (n int, err error) { @@ -59,3 +64,7 @@ func (file FileStream) Close() error { func (file FileStream) GetFileName() string { return file.Name } + +func (file FileStream) GetVirtualPath() string { + return file.VirtualPath +} diff --git a/pkg/filesystem/local/handler.go b/pkg/filesystem/local/handler.go index 389a5e74..23f56b6 100644 --- a/pkg/filesystem/local/handler.go +++ b/pkg/filesystem/local/handler.go @@ -8,8 +8,8 @@ import ( "path/filepath" ) -type Handler struct { -} +// Handler 本地策略适配器 +type Handler struct{} // Put 将文件流保存到指定目录 func (handler Handler) Put(ctx context.Context, file io.ReadCloser, dst string) error { diff --git a/pkg/util/path_test.go b/pkg/util/path_test.go new file mode 100644 index 0000000..6978d3a --- /dev/null +++ b/pkg/util/path_test.go @@ -0,0 +1,14 @@ +package util + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDotPathToStandardPath(t *testing.T) { + asserts := assert.New(t) + + asserts.Equal("/", DotPathToStandardPath("")) + asserts.Equal("/目录", DotPathToStandardPath("目录")) + asserts.Equal("/目录/目录2", DotPathToStandardPath("目录,目录2")) +} diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 4f8d0bb..422b0d1 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -6,6 +6,7 @@ import ( "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/service/file" "github.com/gin-gonic/gin" "strconv" @@ -54,10 +55,11 @@ func FileUploadStream(c *gin.Context) { } fileData := local.FileStream{ - MIMEType: c.Request.Header.Get("Content-Type"), - File: c.Request.Body, - Size: fileSize, - Name: c.Request.Header.Get("X-FileName"), + MIMEType: c.Request.Header.Get("Content-Type"), + File: c.Request.Body, + Size: fileSize, + Name: c.Request.Header.Get("X-FileName"), + VirtualPath: util.DotPathToStandardPath(c.Request.Header.Get("X-Path")), } user, _ := c.Get("user")