From 7eb81731018fa5dd1d20261ad538dc960c4fb36d Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sun, 20 Mar 2022 11:29:50 +0800 Subject: [PATCH] Feat: adapt new uploader for s3 like policy This commit also fix #730, #713, #756, #5 --- models/policy.go | 2 +- pkg/filesystem/driver/oss/handler.go | 10 +- pkg/filesystem/driver/s3/handler.go | 193 +++++++++++++-------------- pkg/filesystem/filesystem.go | 7 +- pkg/filesystem/fsctx/stream.go | 7 +- pkg/serializer/upload.go | 25 ++-- routers/controllers/callback.go | 1 - routers/router.go | 4 +- service/admin/policy.go | 6 +- service/callback/upload.go | 10 +- 10 files changed, 128 insertions(+), 137 deletions(-) diff --git a/models/policy.go b/models/policy.go index c182eba..0dc0601 100644 --- a/models/policy.go +++ b/models/policy.go @@ -229,7 +229,7 @@ func (policy *Policy) IsUploadPlaceholderWithSize() bool { return true } - if util.ContainsString([]string{"onedrive", "oss", "qiniu", "cos"}, policy.Type) { + if util.ContainsString([]string{"onedrive", "oss", "qiniu", "cos", "s3"}, policy.Type) { return policy.OptionsSerialized.PlaceholderWithSize } diff --git a/pkg/filesystem/driver/oss/handler.go b/pkg/filesystem/driver/oss/handler.go index b04784b..2e15674 100644 --- a/pkg/filesystem/driver/oss/handler.go +++ b/pkg/filesystem/driver/oss/handler.go @@ -467,11 +467,11 @@ func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *seri } return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, - UploadID: imur.UploadID, - UploadURLs: urls, - Callback: completeURL, + SessionID: uploadSession.Key, + ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, + UploadID: imur.UploadID, + UploadURLs: urls, + CompleteURL: completeURL, }, nil } diff --git a/pkg/filesystem/driver/s3/handler.go b/pkg/filesystem/driver/s3/handler.go index 120bcf1..a84178f 100644 --- a/pkg/filesystem/driver/s3/handler.go +++ b/pkg/filesystem/driver/s3/handler.go @@ -2,13 +2,12 @@ package s3 import ( "context" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/hex" - "encoding/json" "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" "github.com/cloudreve/Cloudreve/v3/pkg/util" + "io" "net/http" "net/url" "path" @@ -47,6 +46,14 @@ type MetaData struct { Etag string } +func NewDriver(policy *model.Policy) (*Driver, error) { + driver := &Driver{ + Policy: policy, + } + + return driver, driver.InitS3Client() +} + // InitS3Client 初始化S3会话 func (handler *Driver) InitS3Client() error { if handler.Policy == nil { @@ -72,13 +79,7 @@ func (handler *Driver) InitS3Client() error { } // List 列出给定路径下的文件 -func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - - // 初始化客户端 - if err := handler.InitS3Client(); err != nil { - return nil, err - } - +func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { // 初始化列目录参数 base = strings.TrimPrefix(base, "/") if base != "" { @@ -155,8 +156,7 @@ func (handler Driver) List(ctx context.Context, base string, recursive bool) ([] } // Get 获取文件 -func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - +func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { // 获取文件源地址 downloadURL, err := handler.Source( ctx, @@ -197,7 +197,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { +func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() // 初始化客户端 @@ -205,13 +205,15 @@ func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { return err } - uploader := s3manager.NewUploader(handler.sess) + uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) { + u.PartSize = int64(handler.Policy.OptionsSerialized.ChunkSize) + }) dst := file.Info().SavePath _, err := uploader.Upload(&s3manager.UploadInput{ Bucket: &handler.Policy.BucketName, Key: &dst, - Body: file, + Body: io.LimitReader(file, int64(file.Info().Size)), }) if err != nil { @@ -223,13 +225,7 @@ func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { // Delete 删除一个或多个文件, // 返回未删除的文件,及遇到的最后一个错误 -func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { - - // 初始化客户端 - if err := handler.InitS3Client(); err != nil { - return files, err - } - +func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) { failed := make([]string, 0, len(files)) deleted := make([]string, 0, len(files)) @@ -263,12 +259,12 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err } // Thumb 获取文件缩略图 -func (handler Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { +func (handler *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { return nil, errors.New("未实现") } // Source 获取外链URL -func (handler Driver) Source( +func (handler *Driver) Source( ctx context.Context, path string, baseURL url.URL, @@ -325,42 +321,75 @@ func (handler Driver) Source( } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - // 生成回调地址 - siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/callback/s3/" + uploadSession.Key) - apiURL := siteURL.ResolveReference(apiBaseURI) - - // 上传策略 - savePath := file.Info().SavePath - putPolicy := UploadPolicy{ - Expiration: time.Now().UTC().Add(time.Duration(ttl) * time.Second).Format(time.RFC3339), - Conditions: []interface{}{ - map[string]string{"bucket": handler.Policy.BucketName}, - []string{"starts-with", "$key", savePath}, - []string{"starts-with", "$success_action_redirect", apiURL.String()}, - []string{"starts-with", "$name", ""}, - []string{"starts-with", "$Content-Type", ""}, - map[string]string{"x-amz-algorithm": "AWS4-HMAC-SHA256"}, - }, +func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { + // 检查文件是否存在 + fileInfo := file.Info() + if _, err := handler.Meta(ctx, fileInfo.SavePath); err == nil { + return nil, fmt.Errorf("file already exist") } - if handler.Policy.MaxSize > 0 { - putPolicy.Conditions = append(putPolicy.Conditions, - []interface{}{"content-length-range", 0, handler.Policy.MaxSize}) + // 创建分片上传 + expires := time.Now().Add(time.Duration(ttl) * time.Second) + res, err := handler.svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{ + Bucket: &handler.Policy.BucketName, + Key: &fileInfo.SavePath, + Expires: &expires, + }) + if err != nil { + return nil, fmt.Errorf("failed to create multipart upload: %w", err) } - // 生成上传凭证 - return handler.getUploadCredential(ctx, putPolicy, apiURL, savePath) -} + uploadSession.UploadID = *res.UploadId + + // 为每个分片签名上传 URL + chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}) + urls := make([]string, chunks.Num()) + for chunks.Next() { + err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { + signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{ + Bucket: &handler.Policy.BucketName, + Key: &fileInfo.SavePath, + PartNumber: aws.Int64(int64(c.Index() + 1)), + UploadId: res.UploadId, + }) + + signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second) + if err != nil { + return err + } + + urls[c.Index()] = signedURL + return nil + }) + if err != nil { + return nil, err + } + } -// Meta 获取文件信息 -func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) { - // 初始化客户端 - if err := handler.InitS3Client(); err != nil { + // 签名完成分片上传的请求URL + signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{ + Bucket: &handler.Policy.BucketName, + Key: &fileInfo.SavePath, + UploadId: res.UploadId, + }) + + signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second) + if err != nil { return nil, err } + // 生成上传凭证 + return &serializer.UploadCredential{ + SessionID: uploadSession.Key, + ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, + UploadID: *res.UploadId, + UploadURLs: urls, + CompleteURL: signedURL, + }, nil +} + +// Meta 获取文件信息 +func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) { res, err := handler.svc.GetObject( &s3.GetObjectInput{ Bucket: &handler.Policy.BucketName, @@ -378,52 +407,8 @@ func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) } -func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, callback *url.URL, savePath string) (*serializer.UploadCredential, error) { - - longDate := time.Now().UTC().Format("20060102T150405Z") - shortDate := time.Now().UTC().Format("20060102") - - credential := handler.Policy.AccessKey + "/" + shortDate + "/" + handler.Policy.OptionsSerialized.Region + "/s3/aws4_request" - policy.Conditions = append(policy.Conditions, map[string]string{"x-amz-credential": credential}) - policy.Conditions = append(policy.Conditions, map[string]string{"x-amz-date": longDate}) - - // 编码上传策略 - policyJSON, err := json.Marshal(policy) - if err != nil { - return nil, err - } - policyEncoded := base64.StdEncoding.EncodeToString(policyJSON) - - //签名 - signature := getHMAC([]byte("AWS4"+handler.Policy.SecretKey), []byte(shortDate)) - signature = getHMAC(signature, []byte(handler.Policy.OptionsSerialized.Region)) - signature = getHMAC(signature, []byte("s3")) - signature = getHMAC(signature, []byte("aws4_request")) - signature = getHMAC(signature, []byte(policyEncoded)) - - return &serializer.UploadCredential{ - Policy: policyEncoded, - Callback: callback.String(), - Token: hex.EncodeToString(signature), - AccessKey: credential, - Path: savePath, - KeyTime: longDate, - }, nil -} - -func getHMAC(key []byte, data []byte) []byte { - hash := hmac.New(sha256.New, key) - hash.Write(data) - return hash.Sum(nil) -} - // CORS 创建跨域策略 -func (handler Driver) CORS() error { - // 初始化客户端 - if err := handler.InitS3Client(); err != nil { - return err - } - +func (handler *Driver) CORS() error { rule := s3.CORSRule{ AllowedMethods: aws.StringSlice([]string{ "GET", @@ -434,6 +419,7 @@ func (handler Driver) CORS() error { }), AllowedOrigins: aws.StringSlice([]string{"*"}), AllowedHeaders: aws.StringSlice([]string{"*"}), + ExposeHeaders: aws.StringSlice([]string{"ETag"}), MaxAgeSeconds: aws.Int64(3600), } @@ -448,6 +434,11 @@ func (handler Driver) CORS() error { } // 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { + _, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + UploadId: &uploadSession.UploadID, + Bucket: &handler.Policy.BucketName, + Key: &uploadSession.SavePath, + }) + return err } diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 19dd01c..51a13a8 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -174,10 +174,9 @@ func (fs *FileSystem) DispatchHandler() error { } return nil case "s3": - fs.Handler = s3.Driver{ - Policy: currentPolicy, - } - return nil + handler, err := s3.NewDriver(currentPolicy) + fs.Handler = handler + return err default: return ErrUnknownPolicyType } diff --git a/pkg/filesystem/fsctx/stream.go b/pkg/filesystem/fsctx/stream.go index c51d28c..4cf48e3 100644 --- a/pkg/filesystem/fsctx/stream.go +++ b/pkg/filesystem/fsctx/stream.go @@ -1,6 +1,7 @@ package fsctx import ( + "errors" "io" "time" ) @@ -75,7 +76,11 @@ func (file *FileStream) Close() error { } func (file *FileStream) Seek(offset int64, whence int) (int64, error) { - return file.Seeker.Seek(offset, whence) + if file.Seekable() { + return file.Seeker.Seek(offset, whence) + } + + return 0, errors.New("no seeker") } func (file *FileStream) Seekable() bool { diff --git a/pkg/serializer/upload.go b/pkg/serializer/upload.go index da61c7d..b9e9029 100644 --- a/pkg/serializer/upload.go +++ b/pkg/serializer/upload.go @@ -20,19 +20,18 @@ type UploadPolicy struct { // UploadCredential 返回给客户端的上传凭证 type UploadCredential struct { - SessionID string `json:"sessionID"` - ChunkSize uint64 `json:"chunkSize"` // 分块大小,0 为部分快 - Expires int64 `json:"expires"` // 上传凭证过期时间, Unix 时间戳 - UploadURLs []string `json:"uploadURLs,omitempty"` - Credential string `json:"credential,omitempty"` - UploadID string `json:"uploadID,omitempty"` - Callback string `json:"callback,omitempty"` // 回调地址 - Path string `json:"path,omitempty"` // 存储路径 - AccessKey string `json:"ak,omitempty"` - KeyTime string `json:"keyTime,omitempty"` // COS用有效期 - Policy string `json:"policy,omitempty"` - - Token string `json:"token,omitempty"` + SessionID string `json:"sessionID"` + ChunkSize uint64 `json:"chunkSize"` // 分块大小,0 为部分快 + Expires int64 `json:"expires"` // 上传凭证过期时间, Unix 时间戳 + UploadURLs []string `json:"uploadURLs,omitempty"` + Credential string `json:"credential,omitempty"` + UploadID string `json:"uploadID,omitempty"` + Callback string `json:"callback,omitempty"` // 回调地址 + Path string `json:"path,omitempty"` // 存储路径 + AccessKey string `json:"ak,omitempty"` + KeyTime string `json:"keyTime,omitempty"` // COS用有效期 + Policy string `json:"policy,omitempty"` + CompleteURL string `json:"completeURL,omitempty"` } // UploadSession 上传会话 diff --git a/routers/controllers/callback.go b/routers/controllers/callback.go index 9179ab5..dade566 100644 --- a/routers/controllers/callback.go +++ b/routers/controllers/callback.go @@ -112,7 +112,6 @@ func COSCallback(c *gin.Context) { // S3Callback S3上传完成客户端回调 func S3Callback(c *gin.Context) { - c.Header("Access-Control-Allow-Origin", "*") var callbackBody callback.S3Callback if err := c.ShouldBindQuery(&callbackBody); err == nil { res := callbackBody.PreProcess(c) diff --git a/routers/router.go b/routers/router.go index bfaa37e..962179f 100644 --- a/routers/router.go +++ b/routers/router.go @@ -286,8 +286,8 @@ func InitMasterRouter() *gin.Engine { ) // AWS S3策略上传回调 callback.GET( - "s3/:key", - middleware.S3CallbackAuth(), + "s3/:sessionID", + middleware.UseUploadSession("s3"), controllers.S3Callback, ) } diff --git a/service/admin/policy.go b/service/admin/policy.go index c5b771e..a9151d5 100644 --- a/service/admin/policy.go +++ b/service/admin/policy.go @@ -174,9 +174,11 @@ func (service *PolicyService) AddCORS() serializer.Response { return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) } case "s3": - handler := s3.Driver{ - Policy: &policy, + handler, err := s3.NewDriver(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) } + if err := handler.CORS(); err != nil { return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) } diff --git a/service/callback/upload.go b/service/callback/upload.go index 411b571..fb94675 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -61,9 +61,6 @@ type COSCallback struct { // S3Callback S3 客户端回调正文 type S3Callback struct { - Bucket string `form:"bucket"` - Etag string `form:"etag"` - Key string `form:"key"` } // GetBody 返回回调正文 @@ -226,17 +223,16 @@ func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response { defer fs.Recycle() // 获取回调会话 - callbackSessionRaw, _ := c.Get("callbackSession") - callbackSession := callbackSessionRaw.(*serializer.UploadSession) + uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) // 获取文件信息 - info, err := fs.Handler.(s3.Driver).Meta(context.Background(), callbackSession.SavePath) + info, err := fs.Handler.(*s3.Driver).Meta(context.Background(), uploadSession.SavePath) if err != nil { return serializer.Err(serializer.CodeUploadFailed, "文件信息不一致", err) } // 验证实际文件信息与回调会话中是否一致 - if callbackSession.Size != info.Size || service.Etag != info.Etag { + if uploadSession.Size != info.Size { return serializer.Err(serializer.CodeUploadFailed, "文件信息不一致", err) }