From 868a88e5fc5a1162276c33e4a8c12f0044883b60 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sun, 27 Feb 2022 14:03:07 +0800 Subject: [PATCH] Refactor: use universal FileHeader when handling file upload, remove usage of global ctx with FileHeader, SavePath, DisableOverwrite --- models/file.go | 23 ++++- pkg/aria2/monitor/monitor.go | 13 ++- pkg/filesystem/archive.go | 8 +- pkg/filesystem/driver/cos/handler.go | 16 ++-- pkg/filesystem/driver/handler.go | 6 +- pkg/filesystem/driver/local/file.go | 38 -------- pkg/filesystem/driver/local/handler.go | 10 +-- pkg/filesystem/driver/onedrive/api.go | 7 +- pkg/filesystem/driver/onedrive/handler.go | 23 ++--- .../driver/onedrive/handler_test.go | 10 +-- pkg/filesystem/driver/oss/handler.go | 20 ++--- pkg/filesystem/driver/oss/handler_test.go | 4 +- pkg/filesystem/driver/qiniu/handler.go | 23 ++--- pkg/filesystem/driver/remote/handler.go | 19 ++-- pkg/filesystem/driver/remote/handler_test.go | 2 +- pkg/filesystem/driver/s3/handler.go | 17 ++-- .../driver/shadow/masterinslave/handler.go | 8 +- .../driver/shadow/slaveinmaster/handler.go | 7 +- pkg/filesystem/driver/upyun/handler.go | 25 ++---- pkg/filesystem/file.go | 34 +++----- pkg/filesystem/file_test.go | 3 +- pkg/filesystem/filesystem.go | 26 ++---- pkg/filesystem/fsctx/context.go | 2 - pkg/filesystem/fsctx/stream.go | 87 +++++++++++++++++++ .../file_test.go => fsctx/stream_test.go} | 2 +- pkg/filesystem/hooks.go | 75 ++++++---------- pkg/filesystem/hooks_test.go | 34 ++++---- pkg/filesystem/upload.go | 60 ++++++------- pkg/filesystem/upload_test.go | 7 +- pkg/task/compress.go | 3 +- pkg/task/decompress.go | 7 +- pkg/task/import.go | 8 +- pkg/task/slavetask/transfer.go | 8 +- pkg/task/tranfer.go | 12 ++- pkg/webdav/webdav.go | 9 +- routers/controllers/file.go | 7 +- routers/controllers/slave.go | 6 +- service/callback/upload.go | 10 +-- service/explorer/file.go | 11 ++- 39 files changed, 331 insertions(+), 359 deletions(-) delete mode 100644 pkg/filesystem/driver/local/file.go create mode 100644 pkg/filesystem/fsctx/stream.go rename pkg/filesystem/{driver/local/file_test.go => fsctx/stream_test.go} (98%) diff --git a/models/file.go b/models/file.go index 453e76f..39f50c4 100644 --- a/models/file.go +++ b/models/file.go @@ -2,6 +2,7 @@ package model import ( "encoding/gob" + "encoding/json" "path" "time" @@ -20,12 +21,15 @@ type File struct { PicInfo string FolderID uint `gorm:"index:folder_id;unique_index:idx_only_one"` PolicyID uint + Hidden bool + Metadata string `gorm:"type:text"` // 关联模型 Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` // 数据库忽略字段 - Position string `gorm:"-"` + Position string `gorm:"-"` + MetadataSerialized map[string]string `gorm:"-"` } func init() { @@ -42,6 +46,23 @@ func (file *File) Create() (uint, error) { return file.ID, nil } +// AfterFind 找到文件后的钩子 +func (file *File) AfterFind() (err error) { + // 反序列化文件元数据 + if file.Metadata != "" { + err = json.Unmarshal([]byte(file.Metadata), &file.MetadataSerialized) + } + + return +} + +// BeforeSave Save策略前的钩子 +func (file *File) BeforeSave() (err error) { + metaValue, err := json.Marshal(&file.MetadataSerialized) + file.Metadata = string(metaValue) + return err +} + // GetChildFile 查找目录下名为name的子文件 func (folder *Folder) GetChildFile(name string) (*File, error) { var file File diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index b989826..1f8954b 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -13,7 +13,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/task" @@ -191,12 +190,12 @@ func (monitor *Monitor) ValidateFile() error { defer fs.Recycle() // 创建上下文环境 - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + file := &fsctx.FileStream{ Size: monitor.Task.TotalSize, - }) + } // 验证用户容量 - if err := filesystem.HookValidateCapacityWithoutIncrease(ctx, fs); err != nil { + if err := filesystem.HookValidateCapacityWithoutIncrease(context.Background(), fs, file); err != nil { return err } @@ -205,11 +204,11 @@ func (monitor *Monitor) ValidateFile() error { if fileInfo.Selected == "true" { // 创建上下文环境 fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64) - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + file := &fsctx.FileStream{ Size: fileSize, Name: filepath.Base(fileInfo.Path), - }) - if err := filesystem.HookValidateFile(ctx, fs); err != nil { + } + if err := filesystem.HookValidateFile(context.Background(), fs, file); err != nil { return err } } diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 1359e28..8d4dcfd 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -303,7 +303,13 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst string) error { } }() - err = fs.UploadFromStream(ctx, fileStream, savePath, uint64(size)) + err = fs.UploadFromStream(ctx, &fsctx.FileStream{ + File: fileStream, + Size: uint64(size), + Name: path.Base(dst), + VirtualPath: path.Dir(dst), + Mode: fsctx.Create, + }) fileStream.Close() if err != nil { util.Log().Debug("无法上传压缩包内的文件%s , %s , 跳过", rawPath, err) diff --git a/pkg/filesystem/driver/cos/handler.go b/pkg/filesystem/driver/cos/handler.go index 63c0057..589c2fd 100644 --- a/pkg/filesystem/driver/cos/handler.go +++ b/pkg/filesystem/driver/cos/handler.go @@ -183,9 +183,9 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { opt := &cossdk.ObjectPutOptions{} - _, err := handler.Client.Object.Put(ctx, dst, file, opt) + _, err := handler.Client.Object.Put(ctx, file.GetSavePath(), file, opt) return err } @@ -324,13 +324,7 @@ func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64, } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { - // 读取上下文中生成的存储路径 - savePath, ok := ctx.Value(fsctx.SavePathCtx).(string) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取存储路径") - } - +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 生成回调地址 siteURL := model.GetSiteURL() apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + uploadSession.Key) @@ -338,13 +332,13 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria // 上传策略 startTime := time.Now() - endTime := startTime.Add(time.Duration(TTL) * time.Second) + endTime := startTime.Add(time.Duration(ttl) * time.Second) keyTime := fmt.Sprintf("%d;%d", startTime.Unix(), endTime.Unix()) postPolicy := UploadPolicy{ Expiration: endTime.UTC().Format(time.RFC3339), Conditions: []interface{}{ map[string]string{"bucket": handler.Policy.BucketName}, - map[string]string{"$key": savePath}, + map[string]string{"$key": file.GetSavePath()}, map[string]string{"x-cos-meta-callback": apiURL}, map[string]string{"x-cos-meta-key": uploadSession.Key}, map[string]string{"q-sign-algorithm": "sha1"}, diff --git a/pkg/filesystem/driver/handler.go b/pkg/filesystem/driver/handler.go index 279253f..5608764 100644 --- a/pkg/filesystem/driver/handler.go +++ b/pkg/filesystem/driver/handler.go @@ -2,9 +2,9 @@ package driver import ( "context" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io" "net/url" ) @@ -12,7 +12,7 @@ import ( type Handler interface { // 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭 // 时,应取消上传并清理临时文件 - Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error + Put(ctx context.Context, file fsctx.FileHeader) error // 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误 Delete(ctx context.Context, files []string) ([]string, error) @@ -30,7 +30,7 @@ type Handler interface { Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) // Token 获取有效期为ttl的上传凭证和签名 - Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) + Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) // List 递归列取远程端path路径下文件、目录,不包含path本身, // 返回的对象路径以path作为起始根目录. diff --git a/pkg/filesystem/driver/local/file.go b/pkg/filesystem/driver/local/file.go deleted file mode 100644 index 6bad732..0000000 --- a/pkg/filesystem/driver/local/file.go +++ /dev/null @@ -1,38 +0,0 @@ -package local - -import ( - "io" -) - -// FileStream 用户传来的文件 -type FileStream struct { - File io.ReadCloser - Size uint64 - VirtualPath string - Name string - MIMEType string -} - -func (file FileStream) Read(p []byte) (n int, err error) { - return file.File.Read(p) -} - -func (file FileStream) GetMIMEType() string { - return file.MIMEType -} - -func (file FileStream) GetSize() uint64 { - return file.Size -} - -func (file FileStream) Close() error { - return file.File.Close() -} - -func (file FileStream) GetFileName() string { - return file.Name -} - -func (file FileStream) GetVirtualPath() string { - return file.VirtualPath -} diff --git a/pkg/filesystem/driver/local/handler.go b/pkg/filesystem/driver/local/handler.go index 02417fc..b010b10 100644 --- a/pkg/filesystem/driver/local/handler.go +++ b/pkg/filesystem/driver/local/handler.go @@ -83,12 +83,12 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() - dst = util.RelativePath(filepath.FromSlash(dst)) + dst := util.RelativePath(filepath.FromSlash(file.GetSavePath())) - // 如果禁止了 Overwrite,则检查是否有重名冲突 - if ctx.Value(fsctx.DisableOverwrite) != nil { + // 如果非 Overwrite,则检查是否有重名冲突 + if file.GetMode() != fsctx.Overwrite { if util.Exists(dst) { util.Log().Warning("物理同名文件已存在或不可用: %s", dst) return errors.New("物理同名文件已存在或不可用") @@ -214,7 +214,7 @@ func (handler Driver) Source( } // Token 获取上传策略和认证Token,本地策略直接返回空值 -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { return serializer.UploadCredential{ SessionID: uploadSession.Key, }, nil diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 784d906..906c89a 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -257,13 +257,16 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, chunk * } // Upload 上传文件 -func (client *Client) Upload(ctx context.Context, dst string, size int, file io.Reader) error { +func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error { // 决定是否覆盖文件 overwrite := "replace" - if ctx.Value(fsctx.DisableOverwrite) != nil { + if file.GetMode() != fsctx.Overwrite { overwrite = "fail" } + size := int(file.GetSize()) + dst := file.GetSavePath() + // 小文件,使用简单上传接口上传 if size <= int(SmallFileSize) { _, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite)) diff --git a/pkg/filesystem/driver/onedrive/handler.go b/pkg/filesystem/driver/onedrive/handler.go index 15c7f77..df92549 100644 --- a/pkg/filesystem/driver/onedrive/handler.go +++ b/pkg/filesystem/driver/onedrive/handler.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/url" "path" "path/filepath" @@ -121,9 +120,9 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() - return handler.Client.Upload(ctx, dst, int(size), file) + return handler.Client.Upload(ctx, file) } // Delete 删除一个或多个文件, @@ -223,20 +222,10 @@ func (handler Driver) replaceSourceHost(origin string) (string, error) { } // Token 获取上传会话URL -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { - - // 读取上下文中生成的存储路径和文件大小 - savePath, ok := ctx.Value(fsctx.SavePathCtx).(string) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取存储路径") - } - fileSize, ok := ctx.Value(fsctx.FileSizeCtx).(uint64) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取文件大小") - } +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 如果小于4MB,则由服务端中转 - if fileSize <= SmallFileSize { + if file.GetSize() <= SmallFileSize { return serializer.UploadCredential{}, nil } @@ -245,13 +234,13 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria apiBaseURI, _ := url.Parse("/api/v3/callback/onedrive/finish/" + uploadSession.Key) apiURL := siteURL.ResolveReference(apiBaseURI) - uploadURL, err := handler.Client.CreateUploadSession(ctx, savePath, WithConflictBehavior("fail")) + uploadURL, err := handler.Client.CreateUploadSession(ctx, file.GetSavePath(), WithConflictBehavior("fail")) if err != nil { return serializer.UploadCredential{}, err } // 监控回调及上传 - go handler.Client.MonitorUpload(uploadURL, uploadSession.Key, savePath, fileSize, ttl) + go handler.Client.MonitorUpload(uploadURL, uploadSession.Key, file.GetSavePath(), file.GetSize(), ttl) return serializer.UploadCredential{ Policy: uploadURL, diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go index ce64048..c2b00ae 100644 --- a/pkg/filesystem/driver/onedrive/handler_test.go +++ b/pkg/filesystem/driver/onedrive/handler_test.go @@ -36,7 +36,7 @@ func TestDriver_Token(t *testing.T) { // 无法获取文件路径 { ctx := context.WithValue(context.Background(), fsctx.FileSizeCtx, uint64(10)) - res, err := handler.Token(ctx, 10, "key") + res, err := handler.Token(ctx, 10, "key", nil) asserts.Error(err) asserts.Equal(serializer.UploadCredential{}, res) } @@ -44,7 +44,7 @@ func TestDriver_Token(t *testing.T) { // 无法获取文件大小 { ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") - res, err := handler.Token(ctx, 10, "key") + res, err := handler.Token(ctx, 10, "key", nil) asserts.Error(err) asserts.Equal(serializer.UploadCredential{}, res) } @@ -53,7 +53,7 @@ func TestDriver_Token(t *testing.T) { { ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(10)) - res, err := handler.Token(ctx, 10, "key") + res, err := handler.Token(ctx, 10, "key", nil) asserts.NoError(err) asserts.Equal(serializer.UploadCredential{}, res) } @@ -80,7 +80,7 @@ func TestDriver_Token(t *testing.T) { handler.Client.Request = clientMock ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024)) - res, err := handler.Token(ctx, 10, "key") + res, err := handler.Token(ctx, 10, "key", nil) asserts.Error(err) asserts.Equal(serializer.UploadCredential{}, res) } @@ -114,7 +114,7 @@ func TestDriver_Token(t *testing.T) { time.Sleep(time.Duration(1) * time.Second) FinishCallback("key") }() - res, err := handler.Token(ctx, 10, "key") + res, err := handler.Token(ctx, 10, "key", nil) asserts.NoError(err) asserts.Equal("123321", res.Policy) } diff --git a/pkg/filesystem/driver/oss/handler.go b/pkg/filesystem/driver/oss/handler.go index b7ec204..b8fd9e4 100644 --- a/pkg/filesystem/driver/oss/handler.go +++ b/pkg/filesystem/driver/oss/handler.go @@ -224,7 +224,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() // 初始化客户端 @@ -237,7 +237,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s // 是否允许覆盖 overwrite := true - if ctx.Value(fsctx.DisableOverwrite) != nil { + if file.GetMode() != fsctx.Overwrite { overwrite = false } @@ -247,7 +247,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s } // 上传文件 - err := handler.bucket.PutObject(dst, file, options...) + err := handler.bucket.PutObject(file.GetSavePath(), file, options...) if err != nil { return err } @@ -397,13 +397,7 @@ func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64, } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { - // 读取上下文中生成的存储路径 - savePath, ok := ctx.Value(fsctx.SavePathCtx).(string) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取存储路径") - } - +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 生成回调地址 siteURL := model.GetSiteURL() apiBaseURI, _ := url.Parse("/api/v3/callback/oss/" + uploadSession.Key) @@ -418,10 +412,10 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria // 上传策略 postPolicy := UploadPolicy{ - Expiration: time.Now().UTC().Add(time.Duration(TTL) * time.Second).Format(time.RFC3339), + Expiration: time.Now().UTC().Add(time.Duration(ttl) * time.Second).Format(time.RFC3339), Conditions: []interface{}{ map[string]string{"bucket": handler.Policy.BucketName}, - []string{"starts-with", "$key", path.Dir(savePath)}, + []string{"starts-with", "$key", path.Dir(file.GetSavePath())}, }, } @@ -430,7 +424,7 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria []interface{}{"content-length-range", 0, handler.Policy.MaxSize}) } - return handler.getUploadCredential(ctx, postPolicy, callbackPolicy, TTL) + return handler.getUploadCredential(ctx, postPolicy, callbackPolicy, ttl) } func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, callback CallbackPolicy, TTL int64) (serializer.UploadCredential, error) { diff --git a/pkg/filesystem/driver/oss/handler_test.go b/pkg/filesystem/driver/oss/handler_test.go index 58401f3..d69c931 100644 --- a/pkg/filesystem/driver/oss/handler_test.go +++ b/pkg/filesystem/driver/oss/handler_test.go @@ -80,7 +80,7 @@ func TestDriver_Token(t *testing.T) { { ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) - res, err := handler.Token(ctx, 10, "key") + res, err := handler.Token(ctx, 10, "key", nil) asserts.NoError(err) asserts.NotEmpty(res.Policy) asserts.NotEmpty(res.Token) @@ -91,7 +91,7 @@ func TestDriver_Token(t *testing.T) { // 上下文错误 { ctx := context.Background() - _, err := handler.Token(ctx, 10, "key") + _, err := handler.Token(ctx, 10, "key", nil) asserts.Error(err) } diff --git a/pkg/filesystem/driver/qiniu/handler.go b/pkg/filesystem/driver/qiniu/handler.go index f3cb279..3859acd 100644 --- a/pkg/filesystem/driver/qiniu/handler.go +++ b/pkg/filesystem/driver/qiniu/handler.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/http" "net/url" "path" @@ -144,7 +143,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() // 凭证有效期 @@ -153,10 +152,10 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s // 生成上传策略 putPolicy := storage.PutPolicy{ // 指定为覆盖策略 - Scope: fmt.Sprintf("%s:%s", handler.Policy.BucketName, dst), - SaveKey: dst, + Scope: fmt.Sprintf("%s:%s", handler.Policy.BucketName, file.GetSavePath()), + SaveKey: file.GetSavePath(), ForceSaveKey: true, - FsizeLimit: int64(size), + FsizeLimit: int64(file.GetSize()), } // 是否开启了MIMEType限制 if handler.Policy.OptionsSerialized.MimeType != "" { @@ -178,7 +177,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s } // 开始上传 - err = formUploader.Put(ctx, &ret, token.Token, dst, file, int64(size), &putExtra) + err = formUploader.Put(ctx, &ret, token.Token, file.GetSavePath(), file, int64(file.GetSize()), &putExtra) if err != nil { return err } @@ -274,25 +273,19 @@ func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64) } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 生成回调地址 siteURL := model.GetSiteURL() apiBaseURI, _ := url.Parse("/api/v3/callback/qiniu/" + uploadSession.Key) apiURL := siteURL.ResolveReference(apiBaseURI) - // 读取上下文中生成的存储路径 - savePath, ok := ctx.Value(fsctx.SavePathCtx).(string) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取存储路径") - } - // 创建上传策略 putPolicy := storage.PutPolicy{ Scope: handler.Policy.BucketName, CallbackURL: apiURL.String(), CallbackBody: `{"name":"$(fname)","source_name":"$(key)","size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`, CallbackBodyType: "application/json", - SaveKey: savePath, + SaveKey: file.GetSavePath(), ForceSaveKey: true, FsizeLimit: int64(handler.Policy.MaxSize), } @@ -301,7 +294,7 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType } - return handler.getUploadCredential(ctx, putPolicy, TTL) + return handler.getUploadCredential(ctx, putPolicy, ttl) } // getUploadCredential 签名上传策略 diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go index 40bd31d..2451b2e 100644 --- a/pkg/filesystem/driver/remote/handler.go +++ b/pkg/filesystem/driver/remote/handler.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "path" @@ -134,7 +133,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() // 凭证有效期 @@ -142,10 +141,10 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s // 生成上传策略 policy := serializer.UploadPolicy{ - SavePath: path.Dir(dst), - FileName: path.Base(dst), + SavePath: path.Dir(file.GetSavePath()), + FileName: path.Base(file.GetSavePath()), AutoRename: false, - MaxSize: size, + MaxSize: file.GetSize(), } credential, err := handler.getUploadCredential(ctx, policy, int64(credentialTTL)) if err != nil { @@ -153,11 +152,11 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s } // 对文件名进行URLEncode - fileName := url.QueryEscape(path.Base(dst)) + fileName := url.QueryEscape(path.Base(file.GetSavePath())) // 决定是否要禁用文件覆盖 overwrite := "true" - if ctx.Value(fsctx.DisableOverwrite) != nil { + if file.GetMode() != fsctx.Overwrite { overwrite = "false" } @@ -171,7 +170,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s "X-Cr-FileName": {fileName}, "X-Cr-Overwrite": {overwrite}, }), - request.WithContentLength(int64(size)), + request.WithContentLength(int64(file.GetSize())), request.WithTimeout(time.Duration(0)), request.WithMasterMeta(), request.WithSlaveMeta(handler.Policy.AccessKey), @@ -305,7 +304,7 @@ func (handler Driver) Source( } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 生成回调地址 siteURL := model.GetSiteURL() apiBaseURI, _ := url.Parse("/api/v3/callback/remote/" + uploadSession.Key) @@ -320,7 +319,7 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria AllowedExtension: handler.Policy.OptionsSerialized.FileType, CallbackURL: apiURL.String(), } - return handler.getUploadCredential(ctx, policy, TTL) + return handler.getUploadCredential(ctx, policy, ttl) } func (handler Driver) getUploadCredential(ctx context.Context, policy serializer.UploadPolicy, TTL int64) (serializer.UploadCredential, error) { diff --git a/pkg/filesystem/driver/remote/handler_test.go b/pkg/filesystem/driver/remote/handler_test.go index 0a565f7..478b290 100644 --- a/pkg/filesystem/driver/remote/handler_test.go +++ b/pkg/filesystem/driver/remote/handler_test.go @@ -40,7 +40,7 @@ func TestHandler_Token(t *testing.T) { // 成功 { cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) - credential, err := handler.Token(ctx, 10, "123") + credential, err := handler.Token(ctx, 10, "123", nil) asserts.NoError(err) policy, err := serializer.DecodeUploadPolicy(credential.Policy) asserts.NoError(err) diff --git a/pkg/filesystem/driver/s3/handler.go b/pkg/filesystem/driver/s3/handler.go index b924183..e5641de 100644 --- a/pkg/filesystem/driver/s3/handler.go +++ b/pkg/filesystem/driver/s3/handler.go @@ -9,7 +9,6 @@ import ( "encoding/json" "errors" "github.com/cloudreve/Cloudreve/v3/pkg/util" - "io" "net/http" "net/url" "path" @@ -198,7 +197,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { // 初始化客户端 if err := handler.InitS3Client(); err != nil { @@ -207,6 +206,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s uploader := s3manager.NewUploader(handler.sess) + dst := file.GetSavePath() _, err := uploader.Upload(&s3manager.UploadInput{ Bucket: &handler.Policy.BucketName, Key: &dst, @@ -324,14 +324,7 @@ func (handler Driver) Source( } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { - - // 读取上下文中生成的存储路径和文件大小 - savePath, ok := ctx.Value(fsctx.SavePathCtx).(string) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取存储路径") - } - +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 生成回调地址 siteURL := model.GetSiteURL() apiBaseURI, _ := url.Parse("/api/v3/callback/s3/" + uploadSession.Key) @@ -339,10 +332,10 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria // 上传策略 putPolicy := UploadPolicy{ - Expiration: time.Now().UTC().Add(time.Duration(TTL) * time.Second).Format(time.RFC3339), + Expiration: time.Now().UTC().Add(time.Duration(ttl) * time.Second).Format(time.RFC3339), Conditions: []interface{}{ map[string]string{"bucket": handler.Policy.BucketName}, - []string{"starts-with", "$key", savePath}, + []string{"starts-with", "$key", file.GetSavePath()}, []string{"starts-with", "$success_action_redirect", apiURL.String()}, []string{"starts-with", "$name", ""}, []string{"starts-with", "$Content-Type", ""}, diff --git a/pkg/filesystem/driver/shadow/masterinslave/handler.go b/pkg/filesystem/driver/shadow/masterinslave/handler.go index 8496750..2382435 100644 --- a/pkg/filesystem/driver/shadow/masterinslave/handler.go +++ b/pkg/filesystem/driver/shadow/masterinslave/handler.go @@ -5,9 +5,9 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io" "net/url" ) @@ -27,8 +27,8 @@ func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy } } -func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { - return d.handler.Put(ctx, file, dst, size) +func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { + return d.handler.Put(ctx, file) } func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { @@ -47,7 +47,7 @@ func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64 return "", ErrNotImplemented } -func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { +func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { return serializer.UploadCredential{}, ErrNotImplemented } diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index fcb0d1f..ef6d640 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -13,7 +13,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io" "net/url" "time" ) @@ -50,7 +49,7 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) } // Put 将ctx中指定的从机物理文件由从机上传到目标存储策略 -func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { src, ok := ctx.Value(fsctx.SlaveSrcPath).(string) if !ok { return ErrSlaveSrcPathNotExist @@ -58,7 +57,7 @@ func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size u req := serializer.SlaveTransferReq{ Src: src, - Dst: dst, + Dst: file.GetSavePath(), Policy: d.policy, } @@ -112,7 +111,7 @@ func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64 return "", ErrNotImplemented } -func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { +func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { return serializer.UploadCredential{}, ErrNotImplemented } diff --git a/pkg/filesystem/driver/upyun/handler.go b/pkg/filesystem/driver/upyun/handler.go index 3881d1d..1d5f55d 100644 --- a/pkg/filesystem/driver/upyun/handler.go +++ b/pkg/filesystem/driver/upyun/handler.go @@ -9,7 +9,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "path" @@ -146,7 +145,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { +func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() up := upyun.NewUpYun(&upyun.UpYunConfig{ @@ -155,7 +154,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s Password: handler.Policy.SecretKey, }) err := up.Put(&upyun.PutObjectConfig{ - Path: dst, + Path: file.GetSavePath(), Reader: file, }) @@ -311,17 +310,7 @@ func (handler Driver) signURL(ctx context.Context, path *url.URL, TTL int64) (st } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *serializer.UploadSession) (serializer.UploadCredential, error) { - // 读取上下文中生成的存储路径和文件大小 - savePath, ok := ctx.Value(fsctx.SavePathCtx).(string) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取存储路径") - } - fileSize, ok := ctx.Value(fsctx.FileSizeCtx).(uint64) - if !ok { - return serializer.UploadCredential{}, errors.New("无法获取文件大小") - } - +func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (serializer.UploadCredential, error) { // 检查文件大小 // 生成回调地址 @@ -333,11 +322,11 @@ func (handler Driver) Token(ctx context.Context, TTL int64, uploadSession *seria putPolicy := UploadPolicy{ Bucket: handler.Policy.BucketName, // TODO escape - SaveKey: savePath, - Expiration: time.Now().Add(time.Duration(TTL) * time.Second).Unix(), + SaveKey: file.GetSavePath(), + Expiration: time.Now().Add(time.Duration(ttl) * time.Second).Unix(), CallbackURL: apiURL.String(), - ContentLength: fileSize, - ContentLengthRange: fmt.Sprintf("0,%d", fileSize), + ContentLength: file.GetSize(), + ContentLengthRange: fmt.Sprintf("0,%d", file.GetSize()), AllowFileType: strings.Join(handler.Policy.OptionsSerialized.FileType, ","), } diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index df9107f..df40c0a 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -43,26 +43,25 @@ func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser { } // AddFile 新增文件记录 -func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model.File, error) { +func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fsctx.FileHeader) (*model.File, error) { // 添加文件记录前的钩子 - err := fs.Trigger(ctx, "BeforeAddFile") + err := fs.Trigger(ctx, "BeforeAddFile", file) if err != nil { - if err := fs.Trigger(ctx, "BeforeAddFileFailed"); err != nil { + if err := fs.Trigger(ctx, "BeforeAddFileFailed", file); err != nil { util.Log().Debug("BeforeAddFileFailed 钩子执行失败,%s", err) } return nil, err } - file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) - filePath := ctx.Value(fsctx.SavePathCtx).(string) - newFile := model.File{ - Name: file.GetFileName(), - SourceName: filePath, - UserID: fs.User.ID, - Size: file.GetSize(), - FolderID: parent.ID, - PolicyID: fs.Policy.ID, + Name: file.GetFileName(), + SourceName: file.GetSavePath(), + UserID: fs.User.ID, + Size: file.GetSize(), + FolderID: parent.ID, + PolicyID: fs.Policy.ID, + Hidden: file.IsHidden(), + MetadataSerialized: file.GetMetadata(), } if fs.Policy.IsThumbExist(file.GetFileName()) { @@ -72,7 +71,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model _, err = newFile.Create() if err != nil { - if err := fs.Trigger(ctx, "AfterValidateFailed"); err != nil { + if err := fs.Trigger(ctx, "AfterValidateFailed", file); err != nil { util.Log().Debug("AfterValidateFailed 钩子执行失败,%s", err) } return nil, ErrFileExisted.WithError(err) @@ -153,14 +152,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, id uint) (response // GetContent 获取文件内容,path为虚拟路径 func (fs *FileSystem) GetContent(ctx context.Context, id uint) (response.RSCloser, error) { - // 触发`下载前`钩子 - err := fs.Trigger(ctx, "BeforeFileDownload") - if err != nil { - util.Log().Debug("BeforeFileDownload 钩子执行失败,%s", err) - return nil, err - } - - err = fs.resetFileIDIfNotExist(ctx, id) + err := fs.resetFileIDIfNotExist(ctx, id) if err != nil { return nil, err } diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go index 14f43cd..cb3d24f 100644 --- a/pkg/filesystem/file_test.go +++ b/pkg/filesystem/file_test.go @@ -9,7 +9,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -19,7 +18,7 @@ import ( func TestFileSystem_AddFile(t *testing.T) { asserts := assert.New(t) - file := local.FileStream{ + file := fsctx.FileStream{ Size: 5, Name: "1.png", } diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 8e45ccd..8597960 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -3,18 +3,11 @@ package filesystem import ( "errors" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" - "io" - "net/http" - "net/url" - "sync" - model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" @@ -22,11 +15,16 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/qiniu" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" cossdk "github.com/tencentyun/cos-go-sdk-v5" + "net/http" + "net/url" + "sync" ) // FSPool 文件系统资源池 @@ -36,16 +34,6 @@ var FSPool = sync.Pool{ }, } -// FileHeader 上传来的文件数据处理器 -type FileHeader interface { - io.Reader - io.Closer - GetSize() uint64 - GetMIMEType() string - GetFileName() string - GetVirtualPath() string -} - // FileSystem 管理文件的文件系统 type FileSystem struct { // 文件系统所有者 diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go index d280638..b334435 100644 --- a/pkg/filesystem/fsctx/context.go +++ b/pkg/filesystem/fsctx/context.go @@ -39,8 +39,6 @@ const ( CancelFuncCtx // ValidateCapacityOnceCtx 限定归还容量的操作只執行一次 ValidateCapacityOnceCtx - // 禁止上传时同名覆盖操作 - DisableOverwrite // 文件在从机节点中的路径 SlaveSrcPath ) diff --git a/pkg/filesystem/fsctx/stream.go b/pkg/filesystem/fsctx/stream.go new file mode 100644 index 0000000..573eef3 --- /dev/null +++ b/pkg/filesystem/fsctx/stream.go @@ -0,0 +1,87 @@ +package fsctx + +import ( + "io" + "time" +) + +type WriteMode int + +const ( + Overwrite WriteMode = iota + Append + Create +) + +// FileStream 用户传来的文件 +type FileStream struct { + Mode WriteMode + Hidden bool + LastModified time.Time + Metadata map[string]string + File io.ReadCloser + Size uint64 + VirtualPath string + Name string + MIMEType string + SavePath string +} + +func (file *FileStream) Read(p []byte) (n int, err error) { + return file.File.Read(p) +} + +func (file *FileStream) GetMIMEType() string { + return file.MIMEType +} + +func (file *FileStream) GetSize() uint64 { + return file.Size +} + +func (file *FileStream) Close() error { + return file.File.Close() +} + +func (file *FileStream) GetFileName() string { + return file.Name +} + +func (file *FileStream) GetVirtualPath() string { + return file.VirtualPath +} + +func (file *FileStream) GetMode() WriteMode { + return file.Mode +} + +func (file *FileStream) GetMetadata() map[string]string { + return file.Metadata +} + +func (file *FileStream) GetLastModified() time.Time { + return file.LastModified +} + +func (file *FileStream) IsHidden() bool { + return file.Hidden +} + +func (file *FileStream) GetSavePath() string { + return file.SavePath +} + +// FileHeader 上传来的文件数据处理器 +type FileHeader interface { + io.Reader + io.Closer + GetSize() uint64 + GetMIMEType() string + GetFileName() string + GetVirtualPath() string + GetMode() WriteMode + GetMetadata() map[string]string + GetLastModified() time.Time + IsHidden() bool + GetSavePath() string +} diff --git a/pkg/filesystem/driver/local/file_test.go b/pkg/filesystem/fsctx/stream_test.go similarity index 98% rename from pkg/filesystem/driver/local/file_test.go rename to pkg/filesystem/fsctx/stream_test.go index b368764..8cc0c85 100644 --- a/pkg/filesystem/driver/local/file_test.go +++ b/pkg/filesystem/fsctx/stream_test.go @@ -1,4 +1,4 @@ -package local +package fsctx import ( "github.com/stretchr/testify/assert" diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index e9eec53..a6fbbfd 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -16,7 +16,7 @@ import ( ) // Hook 钩子函数 -type Hook func(ctx context.Context, fs *FileSystem) error +type Hook func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error // Use 注入钩子 func (fs *FileSystem) Use(name string, hook Hook) { @@ -41,10 +41,10 @@ func (fs *FileSystem) CleanHooks(name string) { // Trigger 触发钩子,遇到第一个错误时 // 返回错误,后续钩子不会继续执行 -func (fs *FileSystem) Trigger(ctx context.Context, name string) error { +func (fs *FileSystem) Trigger(ctx context.Context, name string, file fsctx.FileHeader) error { if hooks, ok := fs.Hooks[name]; ok { for _, hook := range hooks { - err := hook(ctx, fs) + err := hook(ctx, fs, file) if err != nil { util.Log().Warning("钩子执行失败:%s", err) return err @@ -54,18 +54,8 @@ func (fs *FileSystem) Trigger(ctx context.Context, name string) error { return nil } -// HookIsFileExist 检查虚拟路径文件是否存在 -func HookIsFileExist(ctx context.Context, fs *FileSystem) error { - filePath := ctx.Value(fsctx.PathCtx).(string) - if ok, _ := fs.IsFileExist(filePath); ok { - return nil - } - return ErrObjectNotExist -} - // HookSlaveUploadValidate Slave模式下对文件上传的一系列验证 -func HookSlaveUploadValidate(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) +func HookSlaveUploadValidate(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) // 验证单文件尺寸 @@ -89,9 +79,7 @@ func HookSlaveUploadValidate(ctx context.Context, fs *FileSystem) error { } // HookValidateFile 一系列对文件检验的集合 -func HookValidateFile(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) - +func HookValidateFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { // 验证单文件尺寸 if !fs.ValidateFileSize(ctx, file.GetSize()) { return ErrFileSizeTooBig @@ -112,7 +100,7 @@ func HookValidateFile(ctx context.Context, fs *FileSystem) error { } // HookResetPolicy 重设存储策略为上下文已有文件 -func HookResetPolicy(ctx context.Context, fs *FileSystem) error { +func HookResetPolicy(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) if !ok { return ErrObjectNotExist @@ -123,8 +111,7 @@ func HookResetPolicy(ctx context.Context, fs *FileSystem) error { } // HookValidateCapacity 验证并扣除用户容量,包含数据库操作 -func HookValidateCapacity(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) +func HookValidateCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { // 验证并扣除容量 if !fs.ValidateCapacity(ctx, file.GetSize()) { return ErrInsufficientCapacity @@ -133,8 +120,7 @@ func HookValidateCapacity(ctx context.Context, fs *FileSystem) error { } // HookValidateCapacityWithoutIncrease 验证用户容量,不扣除 -func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) +func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { // 验证并扣除容量 if fs.User.GetRemainingCapacity() < file.GetSize() { return ErrInsufficientCapacity @@ -143,8 +129,7 @@ func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem) er } // HookChangeCapacity 根据原有文件和新文件的大小更新用户容量 -func HookChangeCapacity(ctx context.Context, fs *FileSystem) error { - newFile := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) +func HookChangeCapacity(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { originFile := ctx.Value(fsctx.FileModelCtx).(model.File) if newFile.GetSize() > originFile.Size { @@ -159,10 +144,9 @@ func HookChangeCapacity(ctx context.Context, fs *FileSystem) error { } // HookDeleteTempFile 删除已保存的临时文件 -func HookDeleteTempFile(ctx context.Context, fs *FileSystem) error { - filePath := ctx.Value(fsctx.SavePathCtx).(string) +func HookDeleteTempFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { // 删除临时文件 - _, err := fs.Handler.Delete(ctx, []string{filePath}) + _, err := fs.Handler.Delete(ctx, []string{file.GetSavePath()}) if err != nil { util.Log().Warning("无法清理上传临时文件,%s", err) } @@ -171,14 +155,17 @@ func HookDeleteTempFile(ctx context.Context, fs *FileSystem) error { } // HookCleanFileContent 清空文件内容 -func HookCleanFileContent(ctx context.Context, fs *FileSystem) error { - filePath := ctx.Value(fsctx.SavePathCtx).(string) +func HookCleanFileContent(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { // 清空内容 - return fs.Handler.Put(ctx, ioutil.NopCloser(strings.NewReader("")), filePath, 0) + return fs.Handler.Put(ctx, &fsctx.FileStream{ + File: ioutil.NopCloser(strings.NewReader("")), + SavePath: file.GetSavePath(), + Size: 0, + }) } // HookClearFileSize 将原始文件的尺寸设为0 -func HookClearFileSize(ctx context.Context, fs *FileSystem) error { +func HookClearFileSize(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) if !ok { return ErrObjectNotExist @@ -187,7 +174,7 @@ func HookClearFileSize(ctx context.Context, fs *FileSystem) error { } // HookCancelContext 取消上下文 -func HookCancelContext(ctx context.Context, fs *FileSystem) error { +func HookCancelContext(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { cancelFunc, ok := ctx.Value(fsctx.CancelFuncCtx).(context.CancelFunc) if ok { cancelFunc() @@ -196,8 +183,7 @@ func HookCancelContext(ctx context.Context, fs *FileSystem) error { } // HookGiveBackCapacity 归还用户容量 -func HookGiveBackCapacity(ctx context.Context, fs *FileSystem) error { - file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) +func HookGiveBackCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { once, ok := ctx.Value(fsctx.ValidateCapacityOnceCtx).(*sync.Once) if !ok { once = &sync.Once{} @@ -217,7 +203,7 @@ func HookGiveBackCapacity(ctx context.Context, fs *FileSystem) error { // HookUpdateSourceName 更新文件SourceName // TODO:测试 -func HookUpdateSourceName(ctx context.Context, fs *FileSystem) error { +func HookUpdateSourceName(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) if !ok { return ErrObjectNotExist @@ -226,7 +212,7 @@ func HookUpdateSourceName(ctx context.Context, fs *FileSystem) error { } // GenericAfterUpdate 文件内容更新后 -func GenericAfterUpdate(ctx context.Context, fs *FileSystem) error { +func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { // 更新文件尺寸 originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) if !ok { @@ -235,10 +221,6 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem) error { fs.SetTargetFile(&[]model.File{originFile}) - newFile, ok := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) - if !ok { - return ErrObjectNotExist - } err := originFile.UpdateSize(newFile.GetSize()) if err != nil { return err @@ -260,14 +242,13 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem) error { } // SlaveAfterUpload Slave模式下上传完成钩子 -func SlaveAfterUpload(ctx context.Context, fs *FileSystem) error { - fileHeader := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) +func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) // 构造一个model.File,用于生成缩略图 file := model.File{ Name: fileHeader.GetFileName(), - SourceName: ctx.Value(fsctx.SavePathCtx).(string), + SourceName: fileHeader.GetSavePath(), } fs.GenerateThumbnail(ctx, &file) @@ -286,9 +267,9 @@ func SlaveAfterUpload(ctx context.Context, fs *FileSystem) error { } // GenericAfterUpload 文件上传完成后,包含数据库操作 -func GenericAfterUpload(ctx context.Context, fs *FileSystem) error { +func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { // 文件存放的虚拟路径 - virtualPath := ctx.Value(fsctx.FileHeaderCtx).(FileHeader).GetVirtualPath() + virtualPath := fileHeader.GetVirtualPath() // 检查路径是否存在,不存在就创建 isExist, folder := fs.IsPathExist(virtualPath) @@ -303,13 +284,13 @@ func GenericAfterUpload(ctx context.Context, fs *FileSystem) error { // 检查文件是否存在 if ok, _ := fs.IsChildFileExist( folder, - ctx.Value(fsctx.FileHeaderCtx).(FileHeader).GetFileName(), + ctx.Value(fsctx.FileHeaderCtx).(fsctx.FileHeader).GetFileName(), ); ok { return ErrFileExisted } // 向数据库中插入记录 - file, err := fs.AddFile(ctx, folder) + file, err := fs.AddFile(ctx, folder, fileHeader) if err != nil { return ErrInsertFileRecord } diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 03fe325..40f668b 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -26,7 +26,7 @@ import ( func TestGenericBeforeUpload(t *testing.T) { asserts := assert.New(t) - file := local.FileStream{ + file := fsctx.FileStream{ Size: 5, Name: "1.txt", } @@ -68,7 +68,7 @@ func TestGenericAfterUploadCanceled(t *testing.T) { f, err := os.Create("TestGenericAfterUploadCanceled") asserts.NoError(err) f.Close() - file := local.FileStream{ + file := fsctx.FileStream{ Size: 5, Name: "TestGenericAfterUploadCanceled", } @@ -110,7 +110,7 @@ func TestGenericAfterUpload(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, fsctx.FileStream{ VirtualPath: "/我的文件", Name: "test.txt", }) @@ -277,7 +277,7 @@ func TestHookValidateCapacity(t *testing.T) { MaxStorage: 11, }, }} - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{Size: 10}) + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, fsctx.FileStream{Size: 10}) { err := HookValidateCapacity(ctx, fs) asserts.NoError(err) @@ -325,7 +325,7 @@ func TestHookChangeCapacity(t *testing.T) { Model: gorm.Model{ID: 1}, }} - newFile := local.FileStream{Size: 10} + newFile := fsctx.FileStream{Size: 10} oldFile := model.File{Size: 9} ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, oldFile) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, newFile) @@ -340,7 +340,7 @@ func TestHookChangeCapacity(t *testing.T) { Group: model.Group{MaxStorage: 1}, }} - newFile := local.FileStream{Size: 10} + newFile := fsctx.FileStream{Size: 10} oldFile := model.File{Size: 9} ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, oldFile) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, newFile) @@ -359,7 +359,7 @@ func TestHookChangeCapacity(t *testing.T) { Storage: 1, }} - newFile := local.FileStream{Size: 9} + newFile := fsctx.FileStream{Size: 9} oldFile := model.File{Size: 10} ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, oldFile) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, newFile) @@ -457,7 +457,7 @@ func TestGenericAfterUpdate(t *testing.T) { Model: gorm.Model{ID: 1}, PicInfo: "1,1", } - newFile := local.FileStream{Size: 10} + newFile := fsctx.FileStream{Size: 10} ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, newFile) @@ -489,7 +489,7 @@ func TestGenericAfterUpdate(t *testing.T) { // 原始文件上下文不存在 { - newFile := local.FileStream{Size: 10} + newFile := fsctx.FileStream{Size: 10} ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, newFile) err := GenericAfterUpdate(ctx, fs) asserts.Error(err) @@ -502,7 +502,7 @@ func TestGenericAfterUpdate(t *testing.T) { Model: gorm.Model{ID: 1}, PicInfo: "1,1", } - newFile := local.FileStream{Size: 10} + newFile := fsctx.FileStream{Size: 10} ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, newFile) @@ -533,7 +533,7 @@ func TestHookSlaveUploadValidate(t *testing.T) { MaxSize: 10, AllowedExtension: nil, } - file := local.FileStream{Name: "1.txt", Size: 10} + file := fsctx.FileStream{Name: "1.txt", Size: 10} ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) asserts.NoError(HookSlaveUploadValidate(ctx, fs)) @@ -546,7 +546,7 @@ func TestHookSlaveUploadValidate(t *testing.T) { MaxSize: 10, AllowedExtension: nil, } - file := local.FileStream{Name: "1.txt", Size: 11} + file := fsctx.FileStream{Name: "1.txt", Size: 11} ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) asserts.Equal(ErrFileSizeTooBig, HookSlaveUploadValidate(ctx, fs)) @@ -559,7 +559,7 @@ func TestHookSlaveUploadValidate(t *testing.T) { MaxSize: 10, AllowedExtension: nil, } - file := local.FileStream{Name: "/1.txt", Size: 10} + file := fsctx.FileStream{Name: "/1.txt", Size: 10} ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) asserts.Equal(ErrIllegalObjectName, HookSlaveUploadValidate(ctx, fs)) @@ -572,7 +572,7 @@ func TestHookSlaveUploadValidate(t *testing.T) { MaxSize: 10, AllowedExtension: []string{"jpg"}, } - file := local.FileStream{Name: "1.txt", Size: 10} + file := fsctx.FileStream{Name: "1.txt", Size: 10} ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy) ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) asserts.Equal(ErrFileExtensionNotAllowed, HookSlaveUploadValidate(ctx, fs)) @@ -613,7 +613,7 @@ func TestSlaveAfterUpload(t *testing.T) { }, }) request.GeneralClient = clientMock - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, fsctx.FileStream{ Size: 10, VirtualPath: "/my", Name: "test.txt", @@ -689,7 +689,7 @@ func TestHookGiveBackCapacity(t *testing.T) { Storage: 10, }, } - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{Size: 1}) + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, fsctx.FileStream{Size: 1}) // without once limit { @@ -718,7 +718,7 @@ func TestHookValidateCapacityWithoutIncrease(t *testing.T) { Group: model.Group{}, }, } - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{Size: 1}) + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, fsctx.FileStream{Size: 1}) // not enough { diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 6fa348d..a4d83ee 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -2,13 +2,11 @@ package filesystem import ( "context" - "io" "os" "path" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -23,11 +21,11 @@ import ( */ // Upload 上传文件 -func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { +func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err error) { ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file) // 上传前的钩子 - err = fs.Trigger(ctx, "BeforeUpload") + err = fs.Trigger(ctx, "BeforeUpload", file) if err != nil { request.BlackHole(file) return err @@ -41,24 +39,24 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { } else { savePath = fs.GenerateSavePath(ctx, file) } - ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath) + file.SavePath = savePath // 处理客户端未完成上传时,关闭连接 go fs.CancelUpload(ctx, savePath, file) // 保存文件 - err = fs.Handler.Put(ctx, file, savePath, file.GetSize()) + err = fs.Handler.Put(ctx, file) if err != nil { - fs.Trigger(ctx, "AfterUploadFailed") + fs.Trigger(ctx, "AfterUploadFailed", file) return err } // 上传完成后的钩子 - err = fs.Trigger(ctx, "AfterUpload") + err = fs.Trigger(ctx, "AfterUpload", file) if err != nil { // 上传完成后续处理失败 - followUpErr := fs.Trigger(ctx, "AfterValidateFailed") + followUpErr := fs.Trigger(ctx, "AfterValidateFailed", file) // 失败后再失败... if followUpErr != nil { util.Log().Debug("AfterValidateFailed 钩子执行失败,%s", followUpErr) @@ -79,7 +77,7 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) { // GenerateSavePath 生成要存放文件的路径 // TODO 完善测试 -func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) string { +func (fs *FileSystem) GenerateSavePath(ctx context.Context, file fsctx.FileHeader) string { if fs.User.Model.ID != 0 { return path.Join( fs.Policy.GeneratePath( @@ -116,7 +114,7 @@ func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) str } // CancelUpload 监测客户端取消上传 -func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHeader) { +func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file fsctx.FileHeader) { var reqContext context.Context if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok { reqContext = ginCtx.Request.Context() @@ -137,8 +135,7 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHe if fs.Hooks["AfterUploadCanceled"] == nil { return } - ctx = context.WithValue(ctx, fsctx.SavePathCtx, path) - err := fs.Trigger(ctx, "AfterUploadCanceled") + err := fs.Trigger(ctx, "AfterUploadCanceled", file) if err != nil { util.Log().Debug("执行 AfterUploadCanceled 钩子出错,%s", err) } @@ -157,23 +154,22 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, path string, size // 进行文件上传预检查 - // 创建上下文环境 - ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, local.FileStream{ + file := &fsctx.FileStream{ Size: size, Name: name, - }) + } // 检查上传请求合法性 - if err := HookValidateFile(ctx, fs); err != nil { + if err := HookValidateFile(ctx, fs, file); err != nil { return nil, err } - if err := HookValidateCapacityWithoutIncrease(ctx, fs); err != nil { + if err := HookValidateCapacityWithoutIncrease(ctx, fs, file); err != nil { return nil, err } // 生成存储路径 - savePath := fs.GenerateSavePath(ctx, local.FileStream{Name: name, VirtualPath: path}) + savePath := fs.GenerateSavePath(ctx, &fsctx.FileStream{Name: name, VirtualPath: path}) callbackKey := uuid.Must(uuid.NewV4()).String() uploadSession := &serializer.UploadSession{ @@ -188,7 +184,7 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, path string, size } // 获取上传凭证 - credential, err := fs.Handler.Token(ctx, int64(credentialTTL), uploadSession) + credential, err := fs.Handler.Token(ctx, int64(credentialTTL), uploadSession, file) if err != nil { return nil, serializer.NewError(serializer.CodeEncryptError, "无法获取上传凭证", err) } @@ -207,17 +203,7 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, path string, size } // UploadFromStream 从文件流上传文件 -func (fs *FileSystem) UploadFromStream(ctx context.Context, src io.ReadCloser, dst string, size uint64) error { - // 构建文件头 - fileName := path.Base(dst) - filePath := path.Dir(dst) - fileData := local.FileStream{ - File: src, - Size: size, - Name: fileName, - VirtualPath: filePath, - } - +func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream) error { // 给文件系统分配钩子 fs.Lock.Lock() if fs.Hooks == nil { @@ -233,11 +219,11 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, src io.ReadCloser, d fs.Lock.Unlock() // 开始上传 - return fs.Upload(ctx, fileData) + return fs.Upload(ctx, file) } // UploadFromPath 将本机已有文件上传到用户的文件系统 -func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool) error { +func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool, mode fsctx.WriteMode) error { // 重设存储策略 if resetPolicy { fs.Policy = &fs.User.Policy @@ -261,5 +247,11 @@ func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, reset size := fi.Size() // 开始上传 - return fs.UploadFromStream(ctx, file, dst, uint64(size)) + return fs.UploadFromStream(ctx, &fsctx.FileStream{ + File: nil, + Size: uint64(size), + Name: path.Base(dst), + VirtualPath: path.Dir(dst), + Mode: mode, + }) } diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go index 0ead828..74f7abe 100644 --- a/pkg/filesystem/upload_test.go +++ b/pkg/filesystem/upload_test.go @@ -13,7 +13,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -85,7 +84,7 @@ func TestFileSystem_Upload(t *testing.T) { c.Request, _ = http.NewRequest("POST", "/", nil) ctx = context.WithValue(ctx, fsctx.GinCtx, c) cancel() - file := local.FileStream{ + file := fsctx.FileStream{ Size: 5, VirtualPath: "/", Name: "1.txt", @@ -114,7 +113,7 @@ func TestFileSystem_Upload(t *testing.T) { ctx = context.WithValue(ctx, fsctx.GinCtx, c) ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{SourceName: "123/123.txt"}) cancel() - file = local.FileStream{ + file = fsctx.FileStream{ Size: 5, VirtualPath: "/", Name: "1.txt", @@ -168,7 +167,7 @@ func TestFileSystem_GenerateSavePath_Anonymous(t *testing.T) { }, ) - savePath := fs.GenerateSavePath(ctx, local.FileStream{ + savePath := fs.GenerateSavePath(ctx, fsctx.FileStream{ Name: "test.test", }) asserts.Len(savePath, 26) diff --git a/pkg/task/compress.go b/pkg/task/compress.go index 95b06d5..cf45ace 100644 --- a/pkg/task/compress.go +++ b/pkg/task/compress.go @@ -7,6 +7,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -106,7 +107,7 @@ func (job *CompressTask) Do() { job.TaskModel.SetProgress(TransferringProgress) // 上传文件 - err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, true) + err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, true, fsctx.Create) if err != nil { job.SetErrorMsg(err.Error()) return diff --git a/pkg/task/decompress.go b/pkg/task/decompress.go index 2a06d6a..2c545e9 100644 --- a/pkg/task/decompress.go +++ b/pkg/task/decompress.go @@ -6,7 +6,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" ) // DecompressTask 文件压缩任务 @@ -83,11 +82,7 @@ func (job *DecompressTask) Do() { job.TaskModel.SetProgress(DecompressingProgress) - // 禁止重名覆盖 - ctx := context.Background() - ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true) - - err = fs.Decompress(ctx, job.TaskProps.Src, job.TaskProps.Dst) + err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst) if err != nil { job.SetErrorMsg("解压缩失败", err) return diff --git a/pkg/task/import.go b/pkg/task/import.go index 7a94d67..afdaaa2 100644 --- a/pkg/task/import.go +++ b/pkg/task/import.go @@ -7,7 +7,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -134,13 +133,12 @@ func (job *ImportTask) Do() { if !object.IsDir { // 创建文件信息 virtualPath := path.Dir(path.Join(job.TaskProps.Dst, object.RelativePath)) - fileHeader := local.FileStream{ + fileHeader := fsctx.FileStream{ Size: object.Size, VirtualPath: virtualPath, Name: object.Name, + SavePath: object.Source, } - addFileCtx := context.WithValue(ctx, fsctx.FileHeaderCtx, fileHeader) - addFileCtx = context.WithValue(addFileCtx, fsctx.SavePathCtx, object.Source) // 查找父目录 parentFolder := &model.Folder{} @@ -162,7 +160,7 @@ func (job *ImportTask) Do() { } // 插入文件记录 - _, err := fs.AddFile(addFileCtx, parentFolder) + _, err := fs.AddFile(context.Background(), parentFolder, &fileHeader) if err != nil { util.Log().Warning("导入任务无法创插入文件[%s], %s", object.RelativePath, err) diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 698cc11..c9653cf 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -100,7 +100,6 @@ func (job *TransferTask) Do() { } fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID) - ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) file, err := os.Open(util.RelativePath(job.Req.Src)) if err != nil { job.SetErrorMsg("无法读取源文件", err) @@ -118,7 +117,12 @@ func (job *TransferTask) Do() { size := fi.Size() - err = fs.Handler.Put(ctx, file, job.Req.Dst, uint64(size)) + err = fs.Handler.Put(context.Background(), &fsctx.FileStream{ + File: file, + Mode: fsctx.Create, + SavePath: job.Req.Dst, + Size: uint64(size), + }) if err != nil { job.SetErrorMsg("文件上传失败", err) return diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 5db638d..33ed4e2 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -107,8 +107,6 @@ func (job *TransferTask) Do() { dst = path.Join(job.TaskProps.Dst, strings.TrimPrefix(src, trim)) } - ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) - ctx = context.WithValue(ctx, fsctx.SlaveSrcPath, file) if job.TaskProps.NodeID > 1 { // 指定为从机中转 @@ -120,10 +118,16 @@ func (job *TransferTask) Do() { // 切换为从机节点处理上传 fs.SwitchToSlaveHandler(node) - err = fs.UploadFromStream(ctx, nil, dst, job.TaskProps.SrcSizes[file]) + err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{ + File: nil, + Size: job.TaskProps.SrcSizes[file], + Name: path.Base(dst), + VirtualPath: path.Dir(dst), + Mode: fsctx.Create, + }) } else { // 主机节点中转 - err = fs.UploadFromPath(ctx, file, dst, true) + err = fs.UploadFromPath(context.Background(), file, dst, true, fsctx.Create) } if err != nil { diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index ae4b5cb..350f18f 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -19,7 +19,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -325,7 +324,7 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst } fileName := path.Base(reqPath) filePath := path.Dir(reqPath) - fileData := local.FileStream{ + fileData := fsctx.FileStream{ MIMEType: r.Header.Get("Content-Type"), File: r.Body, Size: fileSize, @@ -342,7 +341,7 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fileList, err := model.RemoveFilesWithSoftLinks([]model.File{*originFile}) if err == nil && len(fileList) == 0 { // 如果包含软连接,应重新生成新文件副本,并更新source_name - originFile.SourceName = fs.GenerateSavePath(ctx, fileData) + originFile.SourceName = fs.GenerateSavePath(ctx, &fileData) fs.Use("AfterUpload", filesystem.HookUpdateSourceName) fs.Use("AfterUploadCanceled", filesystem.HookUpdateSourceName) fs.Use("AfterValidateFailed", filesystem.HookUpdateSourceName) @@ -373,11 +372,11 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity) // 禁止覆盖 - ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true) + fileData.Mode = fsctx.Create } // 执行上传 - err = fs.Upload(ctx, fileData) + err = fs.Upload(ctx, &fileData) if err != nil { return http.StatusMethodNotAllowed, err } diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 885a7e5..23d5f64 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -11,7 +11,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/service/explorer" @@ -307,12 +306,13 @@ func FileUploadStream(c *gin.Context) { return } - fileData := local.FileStream{ + fileData := fsctx.FileStream{ MIMEType: c.Request.Header.Get("Content-Type"), File: c.Request.Body, Size: fileSize, Name: fileName, VirtualPath: filePath, + Mode: fsctx.Create, } // 创建文件系统 @@ -341,9 +341,8 @@ func FileUploadStream(c *gin.Context) { // 执行上传 ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) - ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true) uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c) - err = fs.Upload(uploadCtx, fileData) + err = fs.Upload(uploadCtx, &fileData) if err != nil { c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err)) return diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 2748582..4c62ac6 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -60,7 +60,7 @@ func SlaveUpload(c *gin.Context) { return } - fileData := local.FileStream{ + fileData := fsctx.FileStream{ MIMEType: c.Request.Header.Get("Content-Type"), File: c.Request.Body, Name: fileName, @@ -75,11 +75,11 @@ func SlaveUpload(c *gin.Context) { // 是否允许覆盖 if c.Request.Header.Get("X-Cr-Overwrite") == "false" { - ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true) + fileData.Mode = fsctx.Create } // 执行上传 - err = fs.Upload(ctx, fileData) + err = fs.Upload(ctx, &fileData) if err != nil { c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err)) return diff --git a/service/callback/upload.go b/service/callback/upload.go index 5a74ba3..8900375 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -7,7 +7,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" @@ -151,16 +150,13 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer. } // 创建文件头 - fileHeader := local.FileStream{ + fileHeader := fsctx.FileStream{ Size: callbackBody.Size, VirtualPath: callbackSession.VirtualPath, Name: callbackSession.Name, + SavePath: callbackBody.SourceName, } - // 生成上下文 - ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, fileHeader) - ctx = context.WithValue(ctx, fsctx.SavePathCtx, callbackBody.SourceName) - // 添加钩子 fs.Use("BeforeAddFile", filesystem.HookValidateFile) fs.Use("BeforeAddFile", filesystem.HookValidateCapacity) @@ -169,7 +165,7 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer. fs.Use("BeforeAddFileFailed", filesystem.HookDeleteTempFile) // 向数据库中添加文件 - file, err := fs.AddFile(ctx, parentFolder) + file, err := fs.AddFile(context.Background(), parentFolder, &fileHeader) if err != nil { return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) } diff --git a/service/explorer/file.go b/service/explorer/file.go index 92961cb..fffd2d9 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -17,7 +17,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" @@ -55,18 +54,18 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response { // 上下文 ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true) // 给文件系统分配钩子 fs.Use("BeforeUpload", filesystem.HookValidateFile) fs.Use("AfterUpload", filesystem.GenericAfterUpload) // 上传空文件 - err = fs.Upload(ctx, local.FileStream{ + err = fs.Upload(ctx, &fsctx.FileStream{ File: ioutil.NopCloser(strings.NewReader("")), Size: 0, VirtualPath: path.Dir(service.Path), Name: path.Base(service.Path), + Mode: fsctx.Create, }) if err != nil { return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) @@ -372,7 +371,7 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se return serializer.ParamErr("无法解析文件尺寸", err) } - fileData := local.FileStream{ + fileData := fsctx.FileStream{ MIMEType: c.Request.Header.Get("Content-Type"), File: c.Request.Body, Size: fileSize, @@ -397,7 +396,7 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se fileList, err := model.RemoveFilesWithSoftLinks([]model.File{originFile[0]}) if err == nil && len(fileList) == 0 { // 如果包含软连接,应重新生成新文件副本,并更新source_name - originFile[0].SourceName = fs.GenerateSavePath(uploadCtx, fileData) + originFile[0].SourceName = fs.GenerateSavePath(uploadCtx, &fileData) fs.Use("AfterUpload", filesystem.HookUpdateSourceName) fs.Use("AfterUploadCanceled", filesystem.HookUpdateSourceName) fs.Use("AfterValidateFailed", filesystem.HookUpdateSourceName) @@ -417,7 +416,7 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se // 执行上传 uploadCtx = context.WithValue(uploadCtx, fsctx.FileModelCtx, originFile[0]) - err = fs.Upload(uploadCtx, fileData) + err = fs.Upload(uploadCtx, &fileData) if err != nil { return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) }