diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2e39dc74..510e4810 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,10 +10,10 @@ jobs: runs-on: ubuntu-18.04 steps: - - name: Set up Go 1.13 + - name: Set up Golang uses: actions/setup-go@v1 with: - go-version: 1.13 + go-version: 1.17 id: go - name: Check out code into the Go module directory @@ -26,7 +26,7 @@ jobs: - name: Get dependencies and build run: | - go get github.com/rakyll/statik + go install github.com/rakyll/statik export PATH=$PATH:~/go/bin/ statik -src=models -f sudo apt-get update diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 31eb7e8d..0c01296b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,10 +14,10 @@ jobs: runs-on: ubuntu-18.04 steps: - - name: Set up Go 1.13 + - name: Set up Golang uses: actions/setup-go@v1 with: - go-version: 1.13 + go-version: 1.17 id: go - name: Check out code into the Go module directory diff --git a/.travis.yml b/.travis.yml index 99bef5c2..22d4990b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: go go: - - 1.13.x + - 1.17.x node_js: "12.16.3" git: depth: 1 diff --git a/Dockerfile b/Dockerfile index 28bfea1d..0f8edf4c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:alpine as cloudreve_builder +FROM golang:1.17.7-alpine as cloudreve_builder # install dependencies and build tools diff --git a/assets b/assets index eb3f3292..e0da8f48 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit eb3f32922ab9cd2f9fbef4860b93fec759a7054d +Subproject commit e0da8f48856e3fb6e3e9cc920a32390ca132935e diff --git a/middleware/auth.go b/middleware/auth.go index 3e7cbe73..83f972b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,24 +1,19 @@ package middleware import ( - "bytes" - "context" - "crypto/md5" - "fmt" - "io/ioutil" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "net/http" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/qiniu/api.v7/v7/auth/qbox" +) + +const ( + CallbackFailedStatusCode = http.StatusUnauthorized ) // SignRequired 验证请求签名 @@ -117,48 +112,60 @@ func WebDAVAuth() gin.HandlerFunc { } } +// 对上传会话进行验证 +func UseUploadSession(policyType string) gin.HandlerFunc { + return func(c *gin.Context) { + // 验证key并查找用户 + resp := uploadCallbackCheck(c, policyType) + if resp.Code != 0 { + c.JSON(CallbackFailedStatusCode, resp) + c.Abort() + return + } + + c.Next() + } +} + // uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户 -func uploadCallbackCheck(c *gin.Context) (serializer.Response, *model.User) { +func uploadCallbackCheck(c *gin.Context, policyType string) serializer.Response { // 验证 Callback Key - callbackKey := c.Param("key") - if callbackKey == "" { - return serializer.ParamErr("Callback Key 不能为空", nil), nil + sessionID := c.Param("sessionID") + if sessionID == "" { + return serializer.ParamErr("Session ID 不能为空", nil) } - callbackSessionRaw, exist := cache.Get("callback_" + callbackKey) + + callbackSessionRaw, exist := cache.Get(filesystem.UploadSessionCachePrefix + sessionID) if !exist { - return serializer.ParamErr("回调会话不存在或已过期", nil), nil + return serializer.ParamErr("上传会话不存在或已过期", nil) } + callbackSession := callbackSessionRaw.(serializer.UploadSession) - c.Set("callbackSession", &callbackSession) + c.Set(filesystem.UploadSessionCtx, &callbackSession) + if callbackSession.Policy.Type != policyType { + return serializer.Err(serializer.CodePolicyNotAllowed, "Policy not supported", nil) + } // 清理回调会话 - _ = cache.Deletes([]string{callbackKey}, "callback_") + _ = cache.Deletes([]string{sessionID}, filesystem.UploadSessionCachePrefix) // 查找用户 user, err := model.GetActiveUserByID(callbackSession.UID) if err != nil { - return serializer.Err(serializer.CodeCheckLogin, "找不到用户", err), nil + return serializer.Err(serializer.CodeCheckLogin, "找不到用户", err) } - c.Set("user", &user) - - return serializer.Response{}, &user + c.Set(filesystem.UserCtx, &user) + return serializer.Response{} } // RemoteCallbackAuth 远程回调签名验证 func RemoteCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, user := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(200, resp) - c.Abort() - return - } - // 验证签名 - authInstance := auth.HMACAuth{SecretKey: []byte(user.Policy.SecretKey)} + session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) + authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.SecretKey)} if err := auth.CheckRequest(authInstance, c.Request); err != nil { - c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err)) + c.JSON(CallbackFailedStatusCode, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err)) c.Abort() return } @@ -171,28 +178,28 @@ func RemoteCallbackAuth() gin.HandlerFunc { // QiniuCallbackAuth 七牛回调签名验证 func QiniuCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, user := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) - c.Abort() - return - } - - // 验证回调是否来自qiniu - mac := qbox.NewMac(user.Policy.AccessKey, user.Policy.SecretKey) - ok, err := mac.VerifyCallback(c.Request) - if err != nil { - util.Log().Debug("无法验证回调请求,%s", err) - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "无法验证回调请求"}) - c.Abort() - return - } - if !ok { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名无效"}) - c.Abort() - return - } + //// 验证key并查找用户 + //resp, user := uploadCallbackCheck(c) + //if resp.Code != 0 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) + // c.Abort() + // return + //} + // + //// 验证回调是否来自qiniu + //mac := qbox.NewMac(user.Policy.AccessKey, user.Policy.SecretKey) + //ok, err := mac.VerifyCallback(c.Request) + //if err != nil { + // util.Log().Debug("无法验证回调请求,%s", err) + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "无法验证回调请求"}) + // c.Abort() + // return + //} + //if !ok { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名无效"}) + // c.Abort() + // return + //} c.Next() } @@ -201,21 +208,21 @@ func QiniuCallbackAuth() gin.HandlerFunc { // OSSCallbackAuth 阿里云OSS回调签名验证 func OSSCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, _ := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) - c.Abort() - return - } - - err := oss.VerifyCallbackSignature(c.Request) - if err != nil { - util.Log().Debug("回调签名验证失败,%s", err) - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名验证失败"}) - c.Abort() - return - } + //// 验证key并查找用户 + //resp, _ := uploadCallbackCheck(c) + //if resp.Code != 0 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) + // c.Abort() + // return + //} + // + //err := oss.VerifyCallbackSignature(c.Request) + //if err != nil { + // util.Log().Debug("回调签名验证失败,%s", err) + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名验证失败"}) + // c.Abort() + // return + //} c.Next() } @@ -224,53 +231,53 @@ func OSSCallbackAuth() gin.HandlerFunc { // UpyunCallbackAuth 又拍云回调签名验证 func UpyunCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, user := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) - c.Abort() - return - } - - // 获取请求正文 - body, err := ioutil.ReadAll(c.Request.Body) - c.Request.Body.Close() - if err != nil { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()}) - c.Abort() - return - } - - c.Request.Body = ioutil.NopCloser(bytes.NewReader(body)) - - // 准备验证Upyun回调签名 - handler := upyun.Driver{Policy: &user.Policy} - contentMD5 := c.Request.Header.Get("Content-Md5") - date := c.Request.Header.Get("Date") - actualSignature := c.Request.Header.Get("Authorization") - - // 计算正文MD5 - actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body)) - if actualContentMD5 != contentMD5 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5不一致"}) - c.Abort() - return - } - - // 计算理论签名 - signature := handler.Sign(context.Background(), []string{ - "POST", - c.Request.URL.Path, - date, - contentMD5, - }) - - // 对比签名 - if signature != actualSignature { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "鉴权失败"}) - c.Abort() - return - } + //// 验证key并查找用户 + //resp, user := uploadCallbackCheck(c) + //if resp.Code != 0 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) + // c.Abort() + // return + //} + // + //// 获取请求正文 + //body, err := ioutil.ReadAll(c.Request.Body) + //c.Request.Body.Close() + //if err != nil { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()}) + // c.Abort() + // return + //} + // + //c.Request.Body = ioutil.NopCloser(bytes.NewReader(body)) + // + //// 准备验证Upyun回调签名 + //handler := upyun.Driver{Policy: &user.Policy} + //contentMD5 := c.Request.Header.Get("Content-Md5") + //date := c.Request.Header.Get("Date") + //actualSignature := c.Request.Header.Get("Authorization") + // + //// 计算正文MD5 + //actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body)) + //if actualContentMD5 != contentMD5 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5不一致"}) + // c.Abort() + // return + //} + // + //// 计算理论签名 + //signature := handler.Sign(context.Background(), []string{ + // "POST", + // c.Request.URL.Path, + // date, + // contentMD5, + //}) + // + //// 对比签名 + //if signature != actualSignature { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "鉴权失败"}) + // c.Abort() + // return + //} c.Next() } @@ -280,16 +287,16 @@ func UpyunCallbackAuth() gin.HandlerFunc { // TODO 解耦 func OneDriveCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, _ := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) - c.Abort() - return - } - - // 发送回调结束信号 - onedrive.FinishCallback(c.Param("key")) + //// 验证key并查找用户 + //resp, _ := uploadCallbackCheck(c) + //if resp.Code != 0 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) + // c.Abort() + // return + //} + // + //// 发送回调结束信号 + //onedrive.FinishCallback(c.Param("key")) c.Next() } @@ -299,13 +306,13 @@ func OneDriveCallbackAuth() gin.HandlerFunc { // TODO 解耦 测试 func COSCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, _ := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) - c.Abort() - return - } + //// 验证key并查找用户 + //resp, _ := uploadCallbackCheck(c) + //if resp.Code != 0 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) + // c.Abort() + // return + //} c.Next() } @@ -314,13 +321,13 @@ func COSCallbackAuth() gin.HandlerFunc { // S3CallbackAuth Amazon S3回调签名验证 func S3CallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - // 验证key并查找用户 - resp, _ := uploadCallbackCheck(c) - if resp.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) - c.Abort() - return - } + //// 验证key并查找用户 + //resp, _ := uploadCallbackCheck(c) + //if resp.Code != 0 { + // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) + // c.Abort() + // return + //} c.Next() } diff --git a/models/file.go b/models/file.go index 34918bbd..004d1579 100644 --- a/models/file.go +++ b/models/file.go @@ -268,6 +268,7 @@ func (file *File) UpdatePicInfo(value string) error { } // UpdateSize 更新文件的大小信息 +// TODO: 全局锁 func (file *File) UpdateSize(value uint64) error { tx := DB.Begin() var sizeDelta uint64 @@ -281,7 +282,10 @@ func (file *File) UpdateSize(value uint64) error { sizeDelta = file.Size - value } - if res := tx.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value); res.Error != nil { + if res := tx.Model(&file). + Where("size = ?", file.Size). + Set("gorm:association_autoupdate", false). + Update("size", value); res.Error != nil { tx.Rollback() return res.Error } @@ -291,6 +295,7 @@ func (file *File) UpdateSize(value uint64) error { return err } + file.Size = value return tx.Commit().Error } @@ -299,7 +304,7 @@ func (file *File) UpdateSourceName(value string) error { return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error } -func (file *File) PopChunkToFile(lastModified *time.Time) error { +func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error { file.UploadSessionID = nil if lastModified != nil { file.UpdatedAt = *lastModified @@ -308,6 +313,7 @@ func (file *File) PopChunkToFile(lastModified *time.Time) error { return DB.Model(file).UpdateColumns(map[string]interface{}{ "upload_session_id": file.UploadSessionID, "updated_at": file.UpdatedAt, + "pic_info": picInfo, }).Error } diff --git a/models/migration.go b/models/migration.go index 230f35ab..1053e654 100644 --- a/models/migration.go +++ b/models/migration.go @@ -125,6 +125,7 @@ func addDefaultSettings() { {Name: "onedrive_callback_check", Value: `20`, Type: "timeout"}, {Name: "folder_props_timeout", Value: `300`, Type: "timeout"}, {Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"}, + {Name: "slave_chunk_retries", Value: `1`, Type: "retry"}, {Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"}, {Name: "reset_after_upload_failed", Value: `0`, Type: "upload"}, {Name: "login_captcha", Value: `0`, Type: "login"}, diff --git a/models/task.go b/models/task.go index e880f85a..028d5522 100644 --- a/models/task.go +++ b/models/task.go @@ -64,7 +64,7 @@ func ListTasks(uid uint, page, pageSize int, order string) ([]Task, int) { dbChain = dbChain.Where("user_id = ?", uid) // 计算总数用于分页 - dbChain.Model(&Share{}).Count(&total) + dbChain.Model(&Task{}).Count(&total) // 查询记录 dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&tasks) diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index ac8a46ae..49e2a48d 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -1,11 +1,15 @@ package cluster import ( + "bytes" "encoding/json" + "errors" + "fmt" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -40,7 +44,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) { var endpoint *url.URL if serverURL, err := url.Parse(node.Model.Server); err == nil { var controller *url.URL - controller, _ = url.Parse("/api/v3/slave") + controller, _ = url.Parse("/api/v3/slave/") endpoint = serverURL.ResolveReference(controller) } @@ -408,3 +412,41 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { return strings.NewReader(string(reqBodyEncoded)), nil } + +// TODO: move to slave pkg +// RemoteCallback 发送远程存储策略上传回调请求 +func RemoteCallback(url string, body serializer.UploadCallback) error { + callbackBody, err := json.Marshal(struct { + Data serializer.UploadCallback `json:"data"` + }{ + Data: body, + }) + if err != nil { + return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) + } + + resp := request.GeneralClient.Request( + "POST", + url, + bytes.NewReader(callbackBody), + request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), + request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), + ) + + if resp.Err != nil { + return serializer.NewError(serializer.CodeCallbackError, "从机无法发起回调请求", resp.Err) + } + + // 解析回调服务端响应 + response, err := resp.DecodeResponse() + if err != nil { + msg := fmt.Sprintf("从机无法解析主机返回的响应 (StatusCode=%d)", resp.Response.StatusCode) + return serializer.NewError(serializer.CodeCallbackError, msg, err) + } + + if response.Code != 0 { + return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) + } + + return nil +} diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go index 0b70caae..25809361 100644 --- a/pkg/cluster/slave_test.go +++ b/pkg/cluster/slave_test.go @@ -441,3 +441,125 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) { a.NoError(err) } } + +//func TestRemoteCallback(t *testing.T) { +// asserts := assert.New(t) +// +// // 回调成功 +// { +// clientMock := request.ClientMock{} +// mockResp, _ := json.Marshal(serializer.Response{Code: 0}) +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 200, +// Body: ioutil.NopCloser(bytes.NewReader(mockResp)), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.NoError(resp) +// clientMock.AssertExpectations(t) +// } +// +// // 服务端返回业务错误 +// { +// clientMock := request.ClientMock{} +// mockResp, _ := json.Marshal(serializer.Response{Code: 401}) +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 200, +// Body: ioutil.NopCloser(bytes.NewReader(mockResp)), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.EqualValues(401, resp.(serializer.AppError).Code) +// clientMock.AssertExpectations(t) +// } +// +// // 无法解析回调响应 +// { +// clientMock := request.ClientMock{} +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 200, +// Body: ioutil.NopCloser(strings.NewReader("mockResp")), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.Error(resp) +// clientMock.AssertExpectations(t) +// } +// +// // HTTP状态码非200 +// { +// clientMock := request.ClientMock{} +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 404, +// Body: ioutil.NopCloser(strings.NewReader("mockResp")), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.Error(resp) +// clientMock.AssertExpectations(t) +// } +// +// // 无法发起回调 +// { +// clientMock := request.ClientMock{} +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: errors.New("error"), +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.Error(resp) +// clientMock.AssertExpectations(t) +// } +//} diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index ac0a2fc7..cca49eb4 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -306,9 +306,9 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst string) error { err = fs.UploadFromStream(ctx, &fsctx.FileStream{ File: fileStream, Size: uint64(size), - Name: path.Base(dst), - VirtualPath: path.Dir(dst), - }) + Name: path.Base(savePath), + VirtualPath: path.Dir(savePath), + }, true) fileStream.Close() if err != nil { util.Log().Debug("无法上传压缩包内的文件%s , %s , 跳过", rawPath, err) diff --git a/pkg/filesystem/chunk/backoff/backoff.go b/pkg/filesystem/chunk/backoff/backoff.go new file mode 100644 index 00000000..d15b9754 --- /dev/null +++ b/pkg/filesystem/chunk/backoff/backoff.go @@ -0,0 +1,31 @@ +package backoff + +import "time" + +// Backoff used for retry sleep backoff +type Backoff interface { + Next() bool + Reset() +} + +// ConstantBackoff implements Backoff interface with constant sleep time +type ConstantBackoff struct { + Sleep time.Duration + Max int + + tried int +} + +func (c *ConstantBackoff) Next() bool { + c.tried++ + if c.tried > c.Max { + return false + } + + time.Sleep(c.Sleep) + return true +} + +func (c *ConstantBackoff) Reset() { + c.tried = 0 +} diff --git a/pkg/filesystem/chunk/chunk.go b/pkg/filesystem/chunk/chunk.go new file mode 100644 index 00000000..e8a63f7d --- /dev/null +++ b/pkg/filesystem/chunk/chunk.go @@ -0,0 +1,91 @@ +package chunk + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "io" +) + +// ChunkProcessFunc callback function for processing a chunk +type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error + +// ChunkGroup manage groups of chunks +type ChunkGroup struct { + file fsctx.FileHeader + chunkSize uint64 + backoff backoff.Backoff + + fileInfo *fsctx.UploadTaskInfo + currentIndex int + chunkNum uint64 +} + +func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff) *ChunkGroup { + c := &ChunkGroup{ + file: file, + chunkSize: chunkSize, + backoff: backoff, + fileInfo: file.Info(), + currentIndex: -1, + } + + if c.chunkSize == 0 { + c.chunkSize = c.fileInfo.Size + } + + c.chunkNum = c.fileInfo.Size / c.chunkSize + if c.fileInfo.Size%c.chunkSize != 0 || c.fileInfo.Size == 0 { + c.chunkNum++ + } + + return c +} + +// Process a chunk with retry logic +func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { + err := processor(c, io.LimitReader(c.file, int64(c.chunkSize))) + if err != nil { + if err != context.Canceled && c.file.Seekable() && c.backoff.Next() { + if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil { + return fmt.Errorf("failed to seek back to chunk start: %w, last error: %w", seekErr, err) + } + + util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err) + return c.Process(processor) + } + + return err + } + + return nil +} + +// Start returns the byte index of current chunk +func (c *ChunkGroup) Start() int64 { + return int64(uint64(c.Index()) * c.chunkSize) +} + +// Index returns current chunk index, starts from 0 +func (c *ChunkGroup) Index() int { + return c.currentIndex +} + +// Next switch to next chunk, returns whether all chunks are processed +func (c *ChunkGroup) Next() bool { + c.currentIndex++ + c.backoff.Reset() + return c.currentIndex < int(c.chunkNum) +} + +// Length returns the length of current chunk +func (c *ChunkGroup) Length() int64 { + contentLength := c.chunkSize + if c.Index() == int(c.chunkNum-1) { + contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1) + } + + return int64(contentLength) +} diff --git a/pkg/filesystem/driver/local/handler.go b/pkg/filesystem/driver/local/handler.go index 92432d55..85bb3eb8 100644 --- a/pkg/filesystem/driver/local/handler.go +++ b/pkg/filesystem/driver/local/handler.go @@ -267,6 +267,10 @@ func (handler Driver) Source( // Token 获取上传策略和认证Token,本地策略直接返回空值 func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { + if util.Exists(uploadSession.SavePath) { + return nil, errors.New("placeholder file already exist") + } + return &serializer.UploadCredential{ SessionID: uploadSession.Key, ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, diff --git a/pkg/filesystem/driver/remote/client.go b/pkg/filesystem/driver/remote/client.go index 632333de..ab764f51 100644 --- a/pkg/filesystem/driver/remote/client.go +++ b/pkg/filesystem/driver/remote/client.go @@ -3,25 +3,40 @@ package remote import ( "context" "encoding/json" + "fmt" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/gofrs/uuid" + "io" "net/http" "net/url" "path" "strings" + "time" ) const ( - basePath = "/api/v3/slave" + basePath = "/api/v3/slave/" OverwriteHeader = auth.CrHeaderPrefix + "Overwrite" + chunkRetrySleep = time.Duration(5) * time.Second ) -// Client to operate remote slave server +// Client to operate uploading to remote slave server type Client interface { + // CreateUploadSession creates remote upload session CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error + // GetUploadURL signs an url for uploading file GetUploadURL(ttl int64, sessionID string) (string, string, error) + // Upload uploads file to remote server + Upload(ctx context.Context, file fsctx.FileHeader) error + // DeleteUploadSession deletes remote upload session + DeleteUploadSession(ctx context.Context, sessionID string) error } // NewClient creates new Client from given policy @@ -42,6 +57,7 @@ func NewClient(policy *model.Policy) (Client, error) { request.WithEndpoint(serverURL.ResolveReference(base).String()), request.WithCredential(authInstance, int64(signTTL)), request.WithMasterMeta(), + request.WithSlaveMeta(policy.AccessKey), ), }, nil } @@ -52,6 +68,68 @@ type remoteClient struct { httpClient request.Client } +func (c *remoteClient) Upload(ctx context.Context, file fsctx.FileHeader) error { + ttl := model.GetIntSetting("upload_session_timeout", 86400) + fileInfo := file.Info() + session := &serializer.UploadSession{ + Key: uuid.Must(uuid.NewV4()).String(), + VirtualPath: fileInfo.VirtualPath, + Name: fileInfo.FileName, + Size: fileInfo.Size, + SavePath: fileInfo.SavePath, + LastModified: fileInfo.LastModified, + Policy: *c.policy, + } + + // Create upload session + if err := c.CreateUploadSession(ctx, session, int64(ttl)); err != nil { + return fmt.Errorf("failed to create upload session: %w", err) + } + + overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite + + // Initial chunk groups + chunks := chunk.NewChunkGroup(file, c.policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{ + Max: model.GetIntSetting("onedrive_chunk_retries", 1), + Sleep: chunkRetrySleep, + }) + + uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { + return c.uploadChunk(ctx, session.Key, current.Index(), content, overwrite, current.Length()) + } + + // upload chunks + for chunks.Next() { + if err := chunks.Process(uploadFunc); err != nil { + if err := c.DeleteUploadSession(ctx, session.Key); err != nil { + util.Log().Warning("failed to delete upload session: %s", err) + } + + return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) + } + } + + return nil +} + +func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error { + resp, err := c.httpClient.Request( + "DELETE", + "upload/"+sessionID, + nil, + request.WithContext(ctx), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error { reqBodyEncoded, err := json.Marshal(map[string]interface{}{ "session": session, @@ -94,3 +172,24 @@ func (c *remoteClient) GetUploadURL(ttl int64, sessionID string) (string, string req = auth.SignRequest(c.authInstance, req, ttl) return req.URL.String(), req.Header["Authorization"][0], nil } + +func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error { + resp, err := c.httpClient.Request( + "POST", + fmt.Sprintf("upload/%s?chunk=%d", sessionID, index), + chunk, + request.WithContext(ctx), + request.WithTimeout(time.Duration(0)), + request.WithContentLength(size), + request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go index 2793a706..88374393 100644 --- a/pkg/filesystem/driver/remote/handler.go +++ b/pkg/filesystem/driver/remote/handler.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "net/url" "path" "strings" @@ -26,10 +25,11 @@ type Driver struct { Policy *model.Policy AuthInstance auth.Auth - client Client + uploadClient Client } // NewDriver initializes a new Driver from policy +// TODO: refactor all method into upload client func NewDriver(policy *model.Policy) (*Driver, error) { client, err := NewClient(policy) if err != nil { @@ -40,12 +40,12 @@ func NewDriver(policy *model.Policy) (*Driver, error) { Policy: policy, Client: request.NewClient(), AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)}, - client: client, + uploadClient: client, }, nil } // List 列取文件 -func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { +func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { var res []response.Object reqBody := serializer.ListRequest{ @@ -87,7 +87,7 @@ func (handler Driver) List(ctx context.Context, path string, recursive bool) ([] } // getAPIUrl 获取接口请求地址 -func (handler Driver) getAPIUrl(scope string, routes ...string) string { +func (handler *Driver) getAPIUrl(scope string, routes ...string) string { serverURL, err := url.Parse(handler.Policy.Server) if err != nil { return "" @@ -113,7 +113,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string { } // Get 获取文件内容 -func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { +func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { // 尝试获取速度限制 speedLimit := 0 if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { @@ -150,63 +150,15 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { +func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() - // 凭证有效期 - credentialTTL := model.GetIntSetting("upload_credential_timeout", 3600) - - // 生成上传策略 - fileInfo := file.Info() - policy := serializer.UploadPolicy{ - SavePath: path.Dir(fileInfo.SavePath), - FileName: path.Base(fileInfo.FileName), - AutoRename: false, - MaxSize: fileInfo.Size, - } - credential, err := handler.getUploadCredential(ctx, policy, int64(credentialTTL)) - if err != nil { - return err - } - - // 对文件名进行URLEncode - fileName := url.QueryEscape(path.Base(fileInfo.SavePath)) - - // 决定是否要禁用文件覆盖 - overwrite := "false" - if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite { - overwrite = "true" - } - - // 上传文件 - resp, err := handler.Client.Request( - "POST", - handler.Policy.GetUploadURL(), - file, - request.WithHeader(map[string][]string{ - "X-Cr-Policy": {credential.Policy}, - "X-Cr-FileName": {fileName}, - "X-Cr-Overwrite": {overwrite}, - }), - request.WithContentLength(int64(fileInfo.Size)), - request.WithTimeout(time.Duration(0)), - request.WithMasterMeta(), - request.WithSlaveMeta(handler.Policy.AccessKey), - request.WithCredential(handler.AuthInstance, int64(credentialTTL)), - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return err - } - if resp.Code != 0 { - return errors.New(resp.Msg) - } - - return nil + return handler.uploadClient.Upload(ctx, file) } // Delete 删除一个或多个文件, // 返回未删除的文件,及遇到的最后一个错误 -func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { +func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) { // 封装接口请求正文 reqBody := serializer.RemoteDeleteRequest{ Files: files, @@ -252,7 +204,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err } // Thumb 获取文件缩略图 -func (handler Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { +func (handler *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path)) thumbURL := handler.getAPIUrl("thumb") + "/" + sourcePath ttl := model.GetIntSetting("preview_timeout", 60) @@ -268,7 +220,7 @@ func (handler Driver) Thumb(ctx context.Context, path string) (*response.Content } // Source 获取外链URL -func (handler Driver) Source( +func (handler *Driver) Source( ctx context.Context, path string, baseURL url.URL, @@ -322,16 +274,21 @@ func (handler Driver) Source( } // Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { +func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { + siteURL := model.GetSiteURL() + apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote", uploadSession.Key, uploadSession.CallbackSecret)) + apiURL := siteURL.ResolveReference(apiBaseURI) + // 在从机端创建上传会话 - if err := handler.client.CreateUploadSession(ctx, uploadSession, ttl); err != nil { + uploadSession.Callback = apiURL.String() + if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, ttl); err != nil { return nil, err } // 获取上传地址 - uploadURL, sign, err := handler.client.GetUploadURL(ttl, uploadSession.Key) + uploadURL, sign, err := handler.uploadClient.GetUploadURL(ttl, uploadSession.Key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to sign upload url: %w", err) } return &serializer.UploadCredential{ @@ -342,30 +299,7 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria }, nil } -func (handler Driver) getUploadCredential(ctx context.Context, policy serializer.UploadPolicy, TTL int64) (serializer.UploadCredential, error) { - policyEncoded, err := policy.EncodeUploadPolicy() - if err != nil { - return serializer.UploadCredential{}, err - } - - // 签名上传策略 - uploadRequest, _ := http.NewRequest("POST", "/api/v3/slave/upload", nil) - uploadRequest.Header = map[string][]string{ - "X-Cr-Policy": {policyEncoded}, - "X-Cr-Overwrite": {"false"}, - } - auth.SignRequest(handler.AuthInstance, uploadRequest, TTL) - - if credential, ok := uploadRequest.Header["Authorization"]; ok && len(credential) == 1 { - return serializer.UploadCredential{ - Token: credential[0], - Policy: policyEncoded, - }, nil - } - return serializer.UploadCredential{}, errors.New("无法签名上传策略") -} - // 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { + return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Key) } diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index 9a79a1a7..7fc7b098 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -30,7 +30,7 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) var endpoint *url.URL if serverURL, err := url.Parse(node.DBModel().Server); err == nil { var controller *url.URL - controller, _ = url.Parse("/api/v3/slave") + controller, _ = url.Parse("/api/v3/slave/") endpoint = serverURL.ResolveReference(controller) } @@ -52,14 +52,10 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { defer file.Close() - src, ok := ctx.Value(fsctx.SlaveSrcPath).(string) - if !ok { - return ErrSlaveSrcPathNotExist - } - + fileInfo := file.Info() req := serializer.SlaveTransferReq{ - Src: src, - Dst: file.Info().SavePath, + Src: fileInfo.Src, + Dst: fileInfo.SavePath, Policy: d.policy, } diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 5c12ad11..48e03784 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -207,7 +207,7 @@ func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) { } // 获取回调会话 - callbackSessionRaw, ok := c.Get("callbackSession") + callbackSessionRaw, ok := c.Get(UploadSessionCtx) if !ok { return nil, errors.New("找不到回调会话") } diff --git a/pkg/filesystem/fsctx/stream.go b/pkg/filesystem/fsctx/stream.go index 5d1ea760..c51d28c2 100644 --- a/pkg/filesystem/fsctx/stream.go +++ b/pkg/filesystem/fsctx/stream.go @@ -26,15 +26,18 @@ type UploadTaskInfo struct { UploadSessionID *string AppendStart uint64 Model interface{} + Src string } // FileHeader 上传来的文件数据处理器 type FileHeader interface { io.Reader io.Closer + io.Seeker Info() *UploadTaskInfo SetSize(uint64) SetModel(fileModel interface{}) + Seekable() bool } // FileStream 用户传来的文件 @@ -43,6 +46,7 @@ type FileStream struct { LastModified *time.Time Metadata map[string]string File io.ReadCloser + Seeker io.Seeker Size uint64 VirtualPath string Name string @@ -51,14 +55,31 @@ type FileStream struct { UploadSessionID *string AppendStart uint64 Model interface{} + Src string } func (file *FileStream) Read(p []byte) (n int, err error) { - return file.File.Read(p) + if file.File != nil { + return file.File.Read(p) + } + + return 0, io.EOF } func (file *FileStream) Close() error { - return file.File.Close() + if file.File != nil { + return file.File.Close() + } + + return nil +} + +func (file *FileStream) Seek(offset int64, whence int) (int64, error) { + return file.Seeker.Seek(offset, whence) +} + +func (file *FileStream) Seekable() bool { + return file.Seeker != nil } func (file *FileStream) Info() *UploadTaskInfo { @@ -74,6 +95,7 @@ func (file *FileStream) Info() *UploadTaskInfo { UploadSessionID: file.UploadSessionID, AppendStart: file.AppendStart, Model: file.Model, + Src: file.Src, } } diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index d2536ede..e73fa33a 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -2,13 +2,12 @@ package filesystem import ( "context" - "errors" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "io/ioutil" @@ -178,30 +177,28 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileH } // SlaveAfterUpload Slave模式下上传完成钩子 -func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - return errors.New("") - policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) - fileInfo := fileHeader.Info() +func SlaveAfterUpload(session *serializer.UploadSession) Hook { + return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { + fileInfo := fileHeader.Info() - // 构造一个model.File,用于生成缩略图 - file := model.File{ - Name: fileInfo.FileName, - SourceName: fileInfo.SavePath, - } - fs.GenerateThumbnail(ctx, &file) + // 构造一个model.File,用于生成缩略图 + file := model.File{ + Name: fileInfo.FileName, + SourceName: fileInfo.SavePath, + } + fs.GenerateThumbnail(ctx, &file) - if policy.CallbackURL == "" { - return nil - } + if session.Callback == "" { + return nil + } - // 发送回调请求 - callbackBody := serializer.UploadCallback{ - Name: file.Name, - SourceName: file.SourceName, - PicInfo: file.PicInfo, - Size: fileInfo.Size, + // 发送回调请求 + callbackBody := serializer.UploadCallback{ + PicInfo: file.PicInfo, + } + + return cluster.RemoteCallback(session.Callback, callbackBody) } - return request.RemoteCallback(policy.CallbackURL, callbackBody) } // GenericAfterUpload 文件上传完成后,包含数据库操作 @@ -288,12 +285,13 @@ func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart) } -// HookChunkUploadFinished 分片上传结束后处理文件 -func HookChunkUploadFinished(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fileInfo := fileHeader.Info() - fileModel := fileInfo.Model.(*model.File) - - return fileModel.PopChunkToFile(fileInfo.LastModified) +// HookPopPlaceholderToFile 将占位文件提升为正式文件 +func HookPopPlaceholderToFile(picInfo string) Hook { + return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { + fileInfo := fileHeader.Info() + fileModel := fileInfo.Model.(*model.File) + return fileModel.PopChunkToFile(fileInfo.LastModified, picInfo) + } } // HookChunkUploadFinished 分片上传结束后处理文件 diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index f8bec752..38d42eff 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -23,6 +23,8 @@ import ( const ( UploadSessionMetaKey = "upload_session" + UploadSessionCtx = "uploadSession" + UserCtx = "user" UploadSessionCachePrefix = "callback_" ) @@ -47,11 +49,11 @@ func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err e file.SavePath = savePath } - // 处理客户端未完成上传时,关闭连接 - go fs.CancelUpload(ctx, savePath, file) - // 保存文件 if file.Mode&fsctx.Nop != fsctx.Nop { + // 处理客户端未完成上传时,关闭连接 + go fs.CancelUpload(ctx, savePath, file) + err = fs.Handler.Put(ctx, file) if err != nil { fs.Trigger(ctx, "AfterUploadFailed", file) @@ -176,20 +178,21 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS fs.Use("AfterUpload", HookClearFileHeaderSize) } - fs.Use("AfterUpload", GenericAfterUpload) + // 验证文件规格 if err := fs.Upload(ctx, file); err != nil { return nil, err } uploadSession := &serializer.UploadSession{ - Key: callbackKey, - UID: fs.User.ID, - Policy: *fs.Policy, - VirtualPath: file.VirtualPath, - Name: file.Name, - Size: fileSize, - SavePath: file.SavePath, - LastModified: file.LastModified, + Key: callbackKey, + UID: fs.User.ID, + Policy: *fs.Policy, + VirtualPath: file.VirtualPath, + Name: file.Name, + Size: fileSize, + SavePath: file.SavePath, + LastModified: file.LastModified, + CallbackSecret: util.RandStringRunes(32), } // 获取上传凭证 @@ -198,10 +201,16 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS return nil, err } + // 创建占位符 + fs.Use("AfterUpload", GenericAfterUpload) + if err := fs.Upload(ctx, file); err != nil { + return nil, err + } + // 创建回调会话 err = cache.Set( UploadSessionCachePrefix+callbackKey, - uploadSession, + *uploadSession, callBackSessionTTL, ) if err != nil { @@ -215,7 +224,16 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS } // UploadFromStream 从文件流上传文件 -func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream) error { +func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error { + if resetPolicy { + // 重设存储策略 + fs.Policy = &fs.User.Policy + err := fs.DispatchHandler() + if err != nil { + return err + } + } + // 给文件系统分配钩子 fs.Lock.Lock() if fs.Hooks == nil { @@ -233,16 +251,7 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStre } // UploadFromPath 将本机已有文件上传到用户的文件系统 -func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool, mode fsctx.WriteMode) error { - // 重设存储策略 - if resetPolicy { - fs.Policy = &fs.User.Policy - err := fs.DispatchHandler() - if err != nil { - return err - } - } - +func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, mode fsctx.WriteMode) error { file, err := os.Open(util.RelativePath(src)) if err != nil { return err @@ -258,10 +267,11 @@ func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, reset // 开始上传 return fs.UploadFromStream(ctx, &fsctx.FileStream{ - File: nil, + File: file, + Seeker: file, Size: uint64(size), Name: path.Base(dst), VirtualPath: path.Dir(dst), Mode: mode, - }) + }, true) } diff --git a/pkg/request/options.go b/pkg/request/options.go index d4957571..bb9b11cd 100644 --- a/pkg/request/options.go +++ b/pkg/request/options.go @@ -5,6 +5,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/auth" "net/http" "net/url" + "strings" "time" ) @@ -103,6 +104,10 @@ func WithSlaveMeta(s string) Option { // Endpoint 使用同一的请求Endpoint func WithEndpoint(endpoint string) Option { + if !strings.HasSuffix(endpoint, "/") { + endpoint += "/" + } + endpointURL, _ := url.Parse(endpoint) return optionFunc(func(o *options) { o.endpoint = endpointURL diff --git a/pkg/request/request.go b/pkg/request/request.go index 9d4249b0..eb7996c5 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -7,7 +7,7 @@ import ( "io" "io/ioutil" "net/http" - "path" + "net/url" "strings" "sync" @@ -51,7 +51,7 @@ func NewClient(opts ...Option) Client { } // Request 发送HTTP请求 -func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { +func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 c.mu.Lock() options := *c.options @@ -70,9 +70,13 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio // 确定请求URL if options.endpoint != nil { + targetPath, err := url.Parse(target) + if err != nil { + return &Response{Err: err} + } + targetURL := *options.endpoint - targetURL.Path = path.Join(targetURL.Path, target) - target = targetURL.String() + target = targetURL.ResolveReference(targetPath).String() } // 创建请求 diff --git a/pkg/request/slave.go b/pkg/request/slave.go deleted file mode 100644 index 29482500..00000000 --- a/pkg/request/slave.go +++ /dev/null @@ -1,52 +0,0 @@ -package request - -import ( - "bytes" - "encoding/json" - "errors" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// TODO: move to slave pkg -// RemoteCallback 发送远程存储策略上传回调请求 -func RemoteCallback(url string, body serializer.UploadCallback) error { - callbackBody, err := json.Marshal(struct { - Data serializer.UploadCallback `json:"data"` - }{ - Data: body, - }) - if err != nil { - return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) - } - - resp := GeneralClient.Request( - "POST", - url, - bytes.NewReader(callbackBody), - WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), - WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), - ) - - if resp.Err != nil { - return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err) - } - - // 解析回调服务端响应 - resp = resp.CheckHTTPResponse(200) - if resp.Err != nil { - return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", resp.Err) - } - response, err := resp.DecodeResponse() - if err != nil { - return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err) - } - if response.Code != 0 { - return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) - } - - return nil -} diff --git a/pkg/request/slave_test.go b/pkg/request/slave_test.go deleted file mode 100644 index a7f9fcff..00000000 --- a/pkg/request/slave_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package request - -import ( - "bytes" - "encoding/json" - "errors" - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestRemoteCallback(t *testing.T) { - asserts := assert.New(t) - - // 回调成功 - { - clientMock := ClientMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 0}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.NoError(resp) - clientMock.AssertExpectations(t) - } - - // 服务端返回业务错误 - { - clientMock := ClientMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 401}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.EqualValues(401, resp.(serializer.AppError).Code) - clientMock.AssertExpectations(t) - } - - // 无法解析回调响应 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // HTTP状态码非200 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // 无法发起回调 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: errors.New("error"), - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } -} diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index 51e2b7c5..910f3f85 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -54,6 +54,8 @@ const ( CodeNoPermissionErr = 403 // CodeNotFound 资源未找到 CodeNotFound = 404 + // CodeConflict 资源冲突 + CodeConflict = 409 // CodeUploadFailed 上传出错 CodeUploadFailed = 40002 // CodeCredentialInvalid 凭证无效 diff --git a/pkg/serializer/upload.go b/pkg/serializer/upload.go index b1f78a1e..225c4934 100644 --- a/pkg/serializer/upload.go +++ b/pkg/serializer/upload.go @@ -37,22 +37,21 @@ type UploadCredential struct { // UploadSession 上传会话 type UploadSession struct { - Key string // 上传会话 GUID - UID uint // 发起者 - VirtualPath string // 用户文件路径,不含文件名 - Name string // 文件名 - Size uint64 // 文件大小 - SavePath string // 物理存储路径,包含物理文件名 - LastModified *time.Time // 可选的文件最后修改日期 - Policy model.Policy + Key string // 上传会话 GUID + UID uint // 发起者 + VirtualPath string // 用户文件路径,不含文件名 + Name string // 文件名 + Size uint64 // 文件大小 + SavePath string // 物理存储路径,包含物理文件名 + LastModified *time.Time // 可选的文件最后修改日期 + Policy model.Policy + Callback string // 回调 URL 地址 + CallbackSecret string // 回调 URL } // UploadCallback 上传回调正文 type UploadCallback struct { - Name string `json:"name"` - SourceName string `json:"source_name"` - PicInfo string `json:"pic_info"` - Size uint64 `json:"size"` + PicInfo string `json:"pic_info"` } // GeneralUploadCallbackFailed 存储策略上传回调失败响应 diff --git a/pkg/task/compress.go b/pkg/task/compress.go index d4a34003..f7314ec3 100644 --- a/pkg/task/compress.go +++ b/pkg/task/compress.go @@ -106,7 +106,7 @@ func (job *CompressTask) Do() { job.TaskModel.SetProgress(TransferringProgress) // 上传文件 - err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, true, 0) + err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, 0) if err != nil { job.SetErrorMsg(err.Error()) return diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 4983c4f1..b8023245 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -117,16 +117,18 @@ func (job *TransferTask) Do() { } // 切换为从机节点处理上传 + fs.SetPolicyFromPath(path.Dir(dst)) fs.SwitchToSlaveHandler(node) err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{ File: nil, Size: job.TaskProps.SrcSizes[file], Name: path.Base(dst), VirtualPath: path.Dir(dst), - }) + Src: file, + }, false) } else { // 主机节点中转 - err = fs.UploadFromPath(context.Background(), file, dst, true, 0) + err = fs.UploadFromPath(context.Background(), file, dst, 0) } if err != nil { diff --git a/pkg/task/worker.go b/pkg/task/worker.go index 07ef36bf..3e01f17f 100644 --- a/pkg/task/worker.go +++ b/pkg/task/worker.go @@ -1,6 +1,9 @@ package task -import "github.com/cloudreve/Cloudreve/v3/pkg/util" +import ( + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) // Worker 处理任务的对象 type Worker interface { @@ -20,7 +23,7 @@ func (worker *GeneralWorker) Do(job Job) { // 致命错误捕获 if err := recover(); err != nil { util.Log().Debug("任务执行出错,%s", err) - job.SetError(&JobError{Msg: "致命错误"}) + job.SetError(&JobError{Msg: "致命错误", Error: fmt.Sprintf("%s", err)}) job.SetStatus(Error) } }() diff --git a/routers/controllers/site.go b/routers/controllers/site.go index f74534dc..b066397c 100644 --- a/routers/controllers/site.go +++ b/routers/controllers/site.go @@ -42,9 +42,14 @@ func SiteConfig(c *gin.Context) { // Ping 状态检查页面 func Ping(c *gin.Context) { + version := conf.BackendVersion + if conf.IsPro == "true" { + version += "-pro" + } + c.JSON(200, serializer.Response{ Code: 0, - Data: conf.BackendVersion, + Data: conf.BackendVersion + conf.IsPro, }) } diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 0b18f919..5a0d2774 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -28,77 +28,6 @@ func SlaveUpload(c *gin.Context) { } else { c.JSON(200, ErrorResponse(err)) } - - //// 创建上下文 - //ctx, cancel := context.WithCancel(context.Background()) - //ctx = context.WithValue(ctx, fsctx.GinCtx, c) - //defer cancel() - // - //// 创建匿名文件系统 - //fs, err := filesystem.NewAnonymousFileSystem() - //if err != nil { - // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)) - // return - //} - //fs.Handler = local.Driver{} - // - //// 从请求中取得上传策略 - //uploadPolicyRaw := c.GetHeader("X-Cr-Policy") - //if uploadPolicyRaw == "" { - // c.JSON(200, serializer.ParamErr("未指定上传策略", nil)) - // return - //} - // - //// 解析上传策略 - //uploadPolicy, err := serializer.DecodeUploadPolicy(uploadPolicyRaw) - //if err != nil { - // c.JSON(200, serializer.ParamErr("上传策略格式有误", err)) - // return - //} - //ctx = context.WithValue(ctx, fsctx.UploadPolicyCtx, *uploadPolicy) - // - //// 取得文件大小 - //fileSize, err := strconv.ParseUint(c.Request.Header.Get("Content-Length"), 10, 64) - //if err != nil { - // c.JSON(200, ErrorResponse(err)) - // return - //} - // - //// 解码文件名和路径 - //fileName, err := url.QueryUnescape(c.Request.Header.Get("X-Cr-FileName")) - //if err != nil { - // c.JSON(200, ErrorResponse(err)) - // return - //} - // - //fileData := fsctx.FileStream{ - // MIMEType: c.Request.Header.Get("Content-Type"), - // File: c.Request.Body, - // Name: fileName, - // Size: fileSize, - //} - // - //// 给文件系统分配钩子 - //fs.Use("BeforeUpload", filesystem.HookSlaveUploadValidate) - //fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) - //fs.Use("AfterUpload", filesystem.SlaveAfterUpload) - //fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) - // - ////// 是否允许覆盖 - ////if c.Request.Header.Get("X-Cr-Overwrite") == "false" { - //// fileData.Mode = fsctx.Create - ////} - // - //// 执行上传 - //err = fs.LocalUpload(ctx, &fileData) - //if err != nil { - // c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err)) - // return - //} - // - //c.JSON(200, serializer.Response{ - // Code: 0, - //}) } // SlaveGetUploadSession 从机创建上传会话 @@ -116,6 +45,21 @@ func SlaveGetUploadSession(c *gin.Context) { } } +// SlaveDeleteUploadSession 从机删除上传会话 +func SlaveDeleteUploadSession(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var service explorer.UploadSessionService + if err := c.ShouldBindUri(&service); err == nil { + res := service.SlaveDelete(ctx, c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // SlaveDownload 从机文件下载,此请求返回的HTTP状态码不全为200 func SlaveDownload(c *gin.Context) { // 创建上下文 diff --git a/routers/router.go b/routers/router.go index 2b9052d6..21f42ae8 100644 --- a/routers/router.go +++ b/routers/router.go @@ -46,9 +46,15 @@ func InitSlaveRouter() *gin.Engine { // 接收主机心跳包 v3.POST("heartbeat", controllers.SlaveHeartbeat) // 上传 - v3.POST("upload/:sessionId", controllers.SlaveUpload) - // 创建上传会话上传 - v3.PUT("upload", controllers.SlaveGetUploadSession) + upload := v3.Group("upload") + { + // 上传分片 + upload.POST(":sessionId", controllers.SlaveUpload) + // 创建上传会话上传 + upload.PUT("", controllers.SlaveGetUploadSession) + // 删除上传会话 + upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession) + } // 下载 v3.GET("download/:speed/:path/:name", controllers.SlaveDownload) // 预览 / 外链 @@ -213,7 +219,15 @@ func InitMasterRouter() *gin.Engine { // 事件通知 slave.PUT("notification/:subject", controllers.SlaveNotificationPush) // 上传 - slave.POST("upload", controllers.SlaveUpload) + upload := slave.Group("upload") + { + // 上传分片 + upload.POST(":sessionId", controllers.SlaveUpload) + // 创建上传会话上传 + upload.PUT("", controllers.SlaveGetUploadSession) + // 删除上传会话 + upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession) + } // OneDrive 存储策略凭证 slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential) } @@ -223,7 +237,8 @@ func InitMasterRouter() *gin.Engine { { // 远程策略上传回调 callback.POST( - "remote/:key", + "remote/:sessionID/:key", + middleware.UseUploadSession("remote"), middleware.RemoteCallbackAuth(), controllers.RemoteCallback, ) diff --git a/service/callback/upload.go b/service/callback/upload.go index 54bdf74a..df5d1807 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -3,6 +3,7 @@ package callback import ( "context" "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" "strings" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" @@ -11,13 +12,12 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" ) // CallbackProcessService 上传请求回调正文接口 type CallbackProcessService interface { - GetBody(*serializer.UploadSession) serializer.UploadCallback + GetBody() serializer.UploadCallback } // RemoteUploadCallbackService 远程存储上传回调请求服务 @@ -26,7 +26,7 @@ type RemoteUploadCallbackService struct { } // GetBody 返回回调正文 -func (service RemoteUploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { +func (service RemoteUploadCallbackService) GetBody() serializer.UploadCallback { return service.Data } @@ -68,12 +68,8 @@ type S3Callback struct { } // GetBody 返回回调正文 -func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { - res := serializer.UploadCallback{ - Name: session.Name, - SourceName: service.SourceName, - Size: service.Size, - } +func (service UpyunCallbackService) GetBody() serializer.UploadCallback { + res := serializer.UploadCallback{} if service.Width != "" { res.PicInfo = service.Width + "," + service.Height } @@ -82,51 +78,41 @@ func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) s } // GetBody 返回回调正文 -func (service UploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { +func (service UploadCallbackService) GetBody() serializer.UploadCallback { return serializer.UploadCallback{ - Name: service.Name, - SourceName: service.SourceName, - PicInfo: service.PicInfo, - Size: service.Size, + PicInfo: service.PicInfo, } } // GetBody 返回回调正文 -func (service OneDriveCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { +func (service OneDriveCallback) GetBody() serializer.UploadCallback { var picInfo = "0,0" if service.Meta.Image.Width != 0 { picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height) } return serializer.UploadCallback{ - Name: session.Name, - SourceName: session.SavePath, - PicInfo: picInfo, - Size: session.Size, + PicInfo: picInfo, } } // GetBody 返回回调正文 -func (service COSCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { +func (service COSCallback) GetBody() serializer.UploadCallback { return serializer.UploadCallback{ - Name: session.Name, - SourceName: session.SavePath, - PicInfo: "", - Size: session.Size, + PicInfo: "", } } // GetBody 返回回调正文 -func (service S3Callback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { +func (service S3Callback) GetBody() serializer.UploadCallback { return serializer.UploadCallback{ - Name: session.Name, - SourceName: session.SavePath, - PicInfo: "", - Size: session.Size, + PicInfo: "", } } // ProcessCallback 处理上传结果回调 func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer.Response { + callbackBody := service.GetBody() + // 创建文件系统 fs, err := filesystem.NewFileSystemFromCallback(c) if err != nil { @@ -134,51 +120,39 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer. } defer fs.Recycle() - // 获取回调会话 - callbackSessionRaw, _ := c.Get("callbackSession") - callbackSession := callbackSessionRaw.(*serializer.UploadSession) - callbackBody := service.GetBody(callbackSession) - - // 获取父目录 - exist, parentFolder := fs.IsPathExist(callbackSession.VirtualPath) - if !exist { - newFolder, err := fs.CreateDirectory(context.Background(), callbackSession.VirtualPath) - if err != nil { - return serializer.Err(serializer.CodeParamErr, "指定目录不存在", err) - } - parentFolder = newFolder + // 获取上传会话 + uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) + + // 查找上传会话创建的占位文件 + file, err := model.GetFilesByUploadSession(uploadSession.Key, fs.User.ID) + if err != nil { + return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) } - // 创建文件头 - fileHeader := fsctx.FileStream{ - Size: callbackBody.Size, - VirtualPath: callbackSession.VirtualPath, - Name: callbackSession.Name, - SavePath: callbackBody.SourceName, + fileData := fsctx.FileStream{ + Size: uploadSession.Size, + Name: uploadSession.Name, + VirtualPath: uploadSession.VirtualPath, + SavePath: uploadSession.SavePath, + Mode: fsctx.Nop, + Model: file, + LastModified: uploadSession.LastModified, } - // 添加钩子 - fs.Use("BeforeAddFile", filesystem.HookValidateFile) - fs.Use("BeforeAddFile", filesystem.HookValidateCapacity) - fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) - fs.Use("BeforeAddFileFailed", filesystem.HookDeleteTempFile) + // 占位符未扣除容量需要校验和扣除 + if !fs.Policy.IsUploadPlaceholderWithSize() { + fs.Use("AfterUpload", filesystem.HookValidateCapacity) + fs.Use("AfterUpload", filesystem.HookChunkUploaded) + } - // 向数据库中添加文件 - file, err := fs.AddFile(context.Background(), parentFolder, &fileHeader) + fs.Use("AfterUpload", filesystem.HookPopPlaceholderToFile(callbackBody.PicInfo)) + fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) + err = fs.Upload(context.Background(), &fileData) if err != nil { return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) } - // 如果是图片,则更新图片信息 - if callbackBody.PicInfo != "" { - if err := file.UpdatePicInfo(callbackBody.PicInfo); err != nil { - util.Log().Debug("无法更新回调文件的图片信息:%s", err) - } - } - - return serializer.Response{ - Code: 0, - } + return serializer.Response{} } // PreProcess 对OneDrive客户端回调进行预处理验证 diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 0cd45aa4..32180b85 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -13,6 +13,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" + "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" "net/http" @@ -172,6 +173,10 @@ type SlaveCreateUploadSessionService struct { // Create 从机创建上传会话 func (service *SlaveCreateUploadSessionService) Create(ctx context.Context, c *gin.Context) serializer.Response { + if util.Exists(service.Session.SavePath) { + return serializer.Err(serializer.CodeConflict, "placeholder file already exist", nil) + } + err := cache.Set( filesystem.UploadSessionCachePrefix+service.Session.Key, service.Session, diff --git a/service/explorer/upload.go b/service/explorer/upload.go index 6847e03b..ced48411 100644 --- a/service/explorer/upload.go +++ b/service/explorer/upload.go @@ -7,6 +7,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -87,13 +88,13 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s } if uploadSession.UID != fs.User.ID { - return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session expired or not exist", nil) + return serializer.Err(serializer.CodeUploadSessionExpired, "Local upload session expired or not exist", nil) } // 查找上传会话创建的占位文件 file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) if err != nil { - return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) + return serializer.Err(serializer.CodeUploadSessionExpired, "Local upload session file placeholder not exist", err) } // 重设 fs 存储策略 @@ -120,14 +121,14 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart) } - return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append|fsctx.Overwrite) + return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append) } // SlaveUpload 处理从机文件分片上传 func (service *UploadService) SlaveUpload(ctx context.Context, c *gin.Context) serializer.Response { uploadSessionRaw, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID) if !ok { - return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session expired or not exist", nil) + return serializer.Err(serializer.CodeUploadSessionExpired, "Slave upload session expired or not exist", nil) } uploadSession := uploadSessionRaw.(serializer.UploadSession) @@ -137,6 +138,8 @@ func (service *UploadService) SlaveUpload(ctx context.Context, c *gin.Context) s return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) } + fs.Handler = local.Driver{} + // 解析需要的参数 service.Index, _ = strconv.Atoi(c.Query("chunk")) mode := fsctx.Append @@ -165,6 +168,11 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File ) } + // 非首个分片时需要允许覆盖 + if index > 0 { + mode |= fsctx.Overwrite + } + fileData := fsctx.FileStream{ MIMEType: c.Request.Header.Get("Content-Type"), File: c.Request.Body, @@ -187,13 +195,14 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File fs.Use("AfterUpload", filesystem.HookChunkUploaded) fs.Use("AfterValidateFailed", filesystem.HookChunkUploadFailed) if isLastChunk { - fs.Use("AfterUpload", filesystem.HookChunkUploadFinished) + fs.Use("AfterUpload", filesystem.HookPopPlaceholderToFile("")) fs.Use("AfterUpload", filesystem.HookGenerateThumb) fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key)) } } else { if isLastChunk { - fs.Use("AfterUpload", filesystem.SlaveAfterUpload) + fs.Use("AfterUpload", filesystem.SlaveAfterUpload(session)) + fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key)) } } @@ -224,7 +233,7 @@ func (service *UploadSessionService) Delete(ctx context.Context, c *gin.Context) // 查找需要删除的上传会话的占位文件 file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) if err != nil { - return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) + return serializer.Err(serializer.CodeUploadSessionExpired, "Local Upload session file placeholder not exist", err) } // 删除文件 @@ -235,6 +244,28 @@ func (service *UploadSessionService) Delete(ctx context.Context, c *gin.Context) return serializer.Response{} } +// SlaveDelete 从机删除指定上传会话 +func (service *UploadSessionService) SlaveDelete(ctx context.Context, c *gin.Context) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewAnonymousFileSystem() + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + session, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID) + if !ok { + return serializer.Err(serializer.CodeUploadSessionExpired, "Slave Upload session file placeholder not exist", nil) + } + + if _, err := fs.Handler.Delete(ctx, []string{session.(serializer.UploadSession).SavePath}); err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to delete temp file", err) + } + + cache.Deletes([]string{service.ID}, filesystem.UploadSessionCachePrefix) + return serializer.Response{} +} + // DeleteAllUploadSession 删除当前用户的全部上传绘会话 func DeleteAllUploadSession(ctx context.Context, c *gin.Context) serializer.Response { // 创建文件系统