diff --git a/.gitignore b/.gitignore index ebff2af..17e970d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ # Development enviroment .idea/* +uploads/* # Version control version.lock \ No newline at end of file diff --git a/models/policy.go b/models/policy.go index e4bdb19..9fd6401 100644 --- a/models/policy.go +++ b/models/policy.go @@ -2,7 +2,10 @@ package model import ( "encoding/json" + "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" + "strconv" + "time" ) // Policy 存储策略 @@ -64,3 +67,50 @@ func (policy *Policy) SerializeOptions() (err error) { policy.Options = string(optionsValue) return err } + +// GeneratePath 生成存储文件的路径 +func (policy *Policy) GeneratePath(uid uint) string { + dirRule := policy.DirNameRule + replaceTable := map[string]string{ + "{randomkey16}": util.RandStringRunes(16), + "{randomkey8}": util.RandStringRunes(8), + "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), + "{uid}": strconv.Itoa(int(uid)), + "{datetime}": time.Now().Format("20060102150405"), + "{date}": time.Now().Format("20060102"), + } + dirRule = util.Replace(replaceTable, dirRule) + return dirRule +} + +// GenerateFileName 生成存储文件名 +func (policy *Policy) GenerateFileName(uid uint, origin string) string { + fileRule := policy.FileNameRule + + replaceTable := map[string]string{ + "{randomkey16}": util.RandStringRunes(16), + "{randomkey8}": util.RandStringRunes(8), + "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), + "{uid}": strconv.Itoa(int(uid)), + "{datetime}": time.Now().Format("20060102150405"), + "{date}": time.Now().Format("20060102"), + } + + // 部分存储策略可以使用{origin}代表原始文件名 + switch policy.Type { + case "qiniu": + // 七牛会将$(fname)自动替换为原始文件名 + replaceTable["{originname}"] = "$(fname)" + case "local": + replaceTable["{originname}"] = origin + case "oss": + // OSS会将${filename}自动替换为原始文件名 + replaceTable["{originname}"] = "${filename}" + case "upyun": + // Upyun会将{filename}{.suffix}自动替换为原始文件名 + replaceTable["{originname}"] = "{filename}{.suffix}" + } + + fileRule = util.Replace(replaceTable, fileRule) + return fileRule +} diff --git a/models/policy_test.go b/models/policy_test.go index f0742a1..6a637f9 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -4,7 +4,9 @@ import ( "encoding/json" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "strconv" "testing" + "time" ) func TestGetPolicyByID(t *testing.T) { @@ -38,3 +40,72 @@ func TestPolicy_BeforeSave(t *testing.T) { asserts.Equal(string(expected), testPolicy.Options) } + +func TestPolicy_GeneratePath(t *testing.T) { + asserts := assert.New(t) + testPolicy := Policy{} + + testPolicy.DirNameRule = "{randomkey16}" + asserts.Len(testPolicy.GeneratePath(1), 16) + + testPolicy.DirNameRule = "{randomkey8}" + asserts.Len(testPolicy.GeneratePath(1), 8) + + testPolicy.DirNameRule = "{timestamp}" + asserts.Equal(testPolicy.GeneratePath(1), strconv.FormatInt(time.Now().Unix(), 10)) + + testPolicy.DirNameRule = "{uid}" + asserts.Equal(testPolicy.GeneratePath(1), strconv.Itoa(int(1))) + + testPolicy.DirNameRule = "{datetime}" + asserts.Len(testPolicy.GeneratePath(1), 14) + + testPolicy.DirNameRule = "{date}" + asserts.Len(testPolicy.GeneratePath(1), 8) + + testPolicy.DirNameRule = "123{date}ss{datetime}" + asserts.Len(testPolicy.GeneratePath(1), 27) +} + +func TestPolicy_GenerateFileName(t *testing.T) { + asserts := assert.New(t) + testPolicy := Policy{} + + testPolicy.FileNameRule = "{randomkey16}" + asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16) + + testPolicy.FileNameRule = "{randomkey8}" + asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8) + + testPolicy.FileNameRule = "{timestamp}" + asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.FormatInt(time.Now().Unix(), 10)) + + testPolicy.FileNameRule = "{uid}" + asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.Itoa(int(1))) + + testPolicy.FileNameRule = "{datetime}" + asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 14) + + testPolicy.FileNameRule = "{date}" + asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8) + + testPolicy.FileNameRule = "123{date}ss{datetime}" + asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 27) + + // 支持{originname}的策略 + testPolicy.Type = "local" + testPolicy.FileNameRule = "123{originname}" + asserts.Equal("123123.txt", testPolicy.GenerateFileName(1, "123.txt")) + + testPolicy.Type = "qiniu" + testPolicy.FileNameRule = "{uid}123{originname}" + asserts.Equal("1123$(fname)", testPolicy.GenerateFileName(1, "123.txt")) + + testPolicy.Type = "oss" + testPolicy.FileNameRule = "{uid}123{originname}" + asserts.Equal("1123${filename}", testPolicy.GenerateFileName(1, "")) + + testPolicy.Type = "upyun" + testPolicy.FileNameRule = "{uid}123{originname}" + asserts.Equal("1123{filename}{.suffix}", testPolicy.GenerateFileName(1, "")) +} diff --git a/pkg/filesystem/errors.go b/pkg/filesystem/errors.go new file mode 100644 index 0000000..211ff72 --- /dev/null +++ b/pkg/filesystem/errors.go @@ -0,0 +1,10 @@ +package filesystem + +import "errors" + +var ( + UnknownPolicyTypeError = errors.New("未知存储策略类型") + FileSizeTooBigError = errors.New("单个文件尺寸太大") + FileExtensionNotAllowedError = errors.New("不允许上传此类型的文件") + InsufficientCapacityError = errors.New("容量空间不足") +) diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 4b8d687..91fcfc3 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -3,7 +3,9 @@ package filesystem import ( "context" "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/filesystem/local" "io" + "path/filepath" ) // FileData 上传来的文件数据处理器 @@ -15,6 +17,11 @@ type FileData interface { GetFileName() string } +// Handler 存储策略适配器 +type Handler interface { + Put(ctx context.Context, file io.ReadCloser, dst string) error +} + // FileSystem 管理文件的文件系统 type FileSystem struct { /* @@ -29,27 +36,58 @@ type FileSystem struct { BeforeUpload func(ctx context.Context, fs *FileSystem, file FileData) error // 上传文件后 AfterUpload func(ctx context.Context, fs *FileSystem) error - // 文件验证失败后 + // 文件保存成功,插入数据库验证失败后 ValidateFailed func(ctx context.Context, fs *FileSystem) error /* 文件系统处理适配器 */ - + Handler Handler } // NewFileSystem 初始化一个文件系统 -func NewFileSystem(user *model.User) *FileSystem { - return &FileSystem{ - User: user, +func NewFileSystem(user *model.User) (*FileSystem, error) { + var handler Handler + + // 根据存储策略类型分配适配器 + switch user.Policy.Type { + case "local": + handler = local.Handler{} + default: + return nil, UnknownPolicyTypeError } + + // TODO 分配默认钩子 + return &FileSystem{ + User: user, + Handler: handler, + }, nil } // Upload 上传文件 func (fs *FileSystem) Upload(ctx context.Context, file FileData) (err error) { + // 上传前的钩子 err = fs.BeforeUpload(ctx, fs, file) if err != nil { return err } + + // 生成文件名和路径 + savePath := fs.GenerateSavePath(file) + + // 保存文件 + err = fs.Handler.Put(ctx, file, savePath) + if err != nil { + return err + } + return nil } + +// GenerateSavePath 生成要存放文件的路径 +func (fs *FileSystem) GenerateSavePath(file FileData) string { + return filepath.Join( + fs.User.Policy.GeneratePath(fs.User.Model.ID), + fs.User.Policy.GenerateFileName(fs.User.Model.ID, file.GetFileName()), + ) +} diff --git a/pkg/filesystem/hook.go b/pkg/filesystem/hook.go index c577c53..a0d606e 100644 --- a/pkg/filesystem/hook.go +++ b/pkg/filesystem/hook.go @@ -2,24 +2,23 @@ package filesystem import ( "context" - "errors" ) // GenericBeforeUpload 通用上传前处理钩子,包含数据库操作 func GenericBeforeUpload(ctx context.Context, fs *FileSystem, file FileData) error { // 验证单文件尺寸 if !fs.ValidateFileSize(ctx, file.GetSize()) { - return errors.New("单个文件尺寸太大") - } - - // 验证并扣除容量 - if !fs.ValidateCapacity(ctx, file.GetSize()) { - return errors.New("容量空间不足") + return FileSizeTooBigError } // 验证扩展名 if !fs.ValidateExtension(ctx, file.GetFileName()) { - return errors.New("不允许上传此类型的文件") + return FileExtensionNotAllowedError + } + + // 验证并扣除容量 + if !fs.ValidateCapacity(ctx, file.GetSize()) { + return InsufficientCapacityError } return nil } diff --git a/pkg/filesystem/local/file.go b/pkg/filesystem/local/file.go index 2f950d0..043fc0c 100644 --- a/pkg/filesystem/local/file.go +++ b/pkg/filesystem/local/file.go @@ -11,7 +11,7 @@ type FileData struct { } func (file FileData) Read(p []byte) (n int, err error) { - return file.Read(p) + return file.File.Read(p) } func (file FileData) GetMIMEType() string { @@ -23,7 +23,7 @@ func (file FileData) GetSize() uint64 { } func (file FileData) Close() error { - return file.Close() + return file.File.Close() } func (file FileData) GetFileName() string { diff --git a/pkg/filesystem/local/handler.go b/pkg/filesystem/local/handler.go new file mode 100644 index 0000000..9a0b9d6 --- /dev/null +++ b/pkg/filesystem/local/handler.go @@ -0,0 +1,37 @@ +package local + +import ( + "context" + "fmt" + "github.com/HFO4/cloudreve/pkg/util" + "io" + "os" + "path/filepath" +) + +type Handler struct { +} + +// Put 将文件流保存到指定目录 +func (handler Handler) Put(ctx context.Context, file io.ReadCloser, dst string) error { + defer file.Close() + + // 如果目标目录不存在,创建 + basePath := filepath.Dir(dst) + if !util.Exists(basePath) { + fmt.Println("创建", basePath) + err := os.MkdirAll(basePath, 0666) + if err != nil { + return err + } + } + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, file) + return err +} diff --git a/pkg/filesystem/local/local.go b/pkg/filesystem/local/local.go deleted file mode 100644 index 469c3dc..0000000 --- a/pkg/filesystem/local/local.go +++ /dev/null @@ -1 +0,0 @@ -package local diff --git a/pkg/util/common.go b/pkg/util/common.go index e136eee..ff974d8 100644 --- a/pkg/util/common.go +++ b/pkg/util/common.go @@ -2,6 +2,7 @@ package util import ( "math/rand" + "strings" ) // RandStringRunes 返回随机字符串 @@ -34,3 +35,11 @@ func ContainsString(s []string, e string) bool { } return false } + +// Replace 根据替换表执行批量替换 +func Replace(table map[string]string, s string) string { + for key, value := range table { + s = strings.Replace(s, key, value, -1) + } + return s +} diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 2567e02..e2bd5f1 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -17,11 +17,8 @@ func FileUpload(c *gin.Context) { return } - var ( - ctx context.Context - cancel context.CancelFunc - ) - ctx, cancel = context.WithCancel(context.Background()) + // 建立上下文 + ctx, cancel := context.WithCancel(context.Background()) var service file.UploadService defer cancel() diff --git a/service/file/upload.go b/service/file/upload.go index 498925b..053dfe3 100644 --- a/service/file/upload.go +++ b/service/file/upload.go @@ -2,7 +2,7 @@ package file import ( "context" - "github.com/HFO4/cloudreve/models" + model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/serializer" @@ -32,11 +32,17 @@ func (service *UploadService) Upload(ctx context.Context, c *gin.Context) serial Name: service.Name, } user, _ := c.Get("user") - fs := filesystem.FileSystem{ - BeforeUpload: filesystem.GenericBeforeUpload, - User: user.(*model.User), + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(user.(*model.User)) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) } + // 给文件系统分配钩子 + fs.BeforeUpload = filesystem.GenericBeforeUpload + + // 执行上传 err = fs.Upload(ctx, fileData) if err != nil { return serializer.Err(serializer.CodeUploadFailed, err.Error(), err)