diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9f4eb13..510e481 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -26,7 +26,7 @@ jobs: - name: Get dependencies and build run: | - go get github.com/rakyll/statik + go install github.com/rakyll/statik export PATH=$PATH:~/go/bin/ statik -src=models -f sudo apt-get update diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index ac8a46a..7f5bd20 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -1,11 +1,14 @@ package cluster import ( + "bytes" "encoding/json" + "errors" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -408,3 +411,43 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { return strings.NewReader(string(reqBodyEncoded)), nil } + +// TODO: move to slave pkg +// RemoteCallback 发送远程存储策略上传回调请求 +func RemoteCallback(url string, body serializer.UploadCallback) error { + callbackBody, err := json.Marshal(struct { + Data serializer.UploadCallback `json:"data"` + }{ + Data: body, + }) + if err != nil { + return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) + } + + resp := request.GeneralClient.Request( + "POST", + url, + bytes.NewReader(callbackBody), + request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), + request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), + ) + + if resp.Err != nil { + return serializer.NewError(serializer.CodeCallbackError, "从机无法发起回调请求", resp.Err) + } + + // 解析回调服务端响应 + resp = resp.CheckHTTPResponse(200) + if resp.Err != nil { + return serializer.NewError(serializer.CodeCallbackError, "主机服务器返回异常响应", resp.Err) + } + response, err := resp.DecodeResponse() + if err != nil { + return serializer.NewError(serializer.CodeCallbackError, "从机无法解析主机返回的响应", err) + } + if response.Code != 0 { + return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) + } + + return nil +} diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go index 0b70caa..2580936 100644 --- a/pkg/cluster/slave_test.go +++ b/pkg/cluster/slave_test.go @@ -441,3 +441,125 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) { a.NoError(err) } } + +//func TestRemoteCallback(t *testing.T) { +// asserts := assert.New(t) +// +// // 回调成功 +// { +// clientMock := request.ClientMock{} +// mockResp, _ := json.Marshal(serializer.Response{Code: 0}) +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 200, +// Body: ioutil.NopCloser(bytes.NewReader(mockResp)), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.NoError(resp) +// clientMock.AssertExpectations(t) +// } +// +// // 服务端返回业务错误 +// { +// clientMock := request.ClientMock{} +// mockResp, _ := json.Marshal(serializer.Response{Code: 401}) +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 200, +// Body: ioutil.NopCloser(bytes.NewReader(mockResp)), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.EqualValues(401, resp.(serializer.AppError).Code) +// clientMock.AssertExpectations(t) +// } +// +// // 无法解析回调响应 +// { +// clientMock := request.ClientMock{} +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 200, +// Body: ioutil.NopCloser(strings.NewReader("mockResp")), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.Error(resp) +// clientMock.AssertExpectations(t) +// } +// +// // HTTP状态码非200 +// { +// clientMock := request.ClientMock{} +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: nil, +// Response: &http.Response{ +// StatusCode: 404, +// Body: ioutil.NopCloser(strings.NewReader("mockResp")), +// }, +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.Error(resp) +// clientMock.AssertExpectations(t) +// } +// +// // 无法发起回调 +// { +// clientMock := request.ClientMock{} +// clientMock.On( +// "Request", +// "POST", +// "http://test/test/url", +// testMock.Anything, +// testMock.Anything, +// ).Return(&request.Response{ +// Err: errors.New("error"), +// }) +// request.GeneralClient = clientMock +// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ +// SourceName: "source", +// }) +// asserts.Error(resp) +// clientMock.AssertExpectations(t) +// } +//} diff --git a/pkg/filesystem/driver/remote/client.go b/pkg/filesystem/driver/remote/client.go index 632333d..d153ab3 100644 --- a/pkg/filesystem/driver/remote/client.go +++ b/pkg/filesystem/driver/remote/client.go @@ -20,7 +20,9 @@ const ( // Client to operate remote slave server type Client interface { + // CreateUploadSession creates remote upload session CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error + // GetUploadURL signs an url for uploading file GetUploadURL(ttl int64, sessionID string) (string, string, error) } diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go index 2793a70..ce869ac 100644 --- a/pkg/filesystem/driver/remote/handler.go +++ b/pkg/filesystem/driver/remote/handler.go @@ -323,7 +323,12 @@ func (handler Driver) Source( // Token 获取上传策略和认证Token func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { + siteURL := model.GetSiteURL() + apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote" + uploadSession.Key + uploadSession.CallbackSecret)) + apiURL := siteURL.ResolveReference(apiBaseURI) + // 在从机端创建上传会话 + uploadSession.Callback = apiURL.String() if err := handler.client.CreateUploadSession(ctx, uploadSession, ttl); err != nil { return nil, err } @@ -331,7 +336,7 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria // 获取上传地址 uploadURL, sign, err := handler.client.GetUploadURL(ttl, uploadSession.Key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to sign upload url: %w", err) } return &serializer.UploadCredential{ diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index d2536ed..f83e259 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -2,13 +2,12 @@ package filesystem import ( "context" - "errors" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "io/ioutil" @@ -178,30 +177,30 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileH } // SlaveAfterUpload Slave模式下上传完成钩子 -func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - return errors.New("") - policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) - fileInfo := fileHeader.Info() +func SlaveAfterUpload(session *serializer.UploadSession) Hook { + return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { + fileInfo := fileHeader.Info() - // 构造一个model.File,用于生成缩略图 - file := model.File{ - Name: fileInfo.FileName, - SourceName: fileInfo.SavePath, - } - fs.GenerateThumbnail(ctx, &file) + // 构造一个model.File,用于生成缩略图 + file := model.File{ + Name: fileInfo.FileName, + SourceName: fileInfo.SavePath, + } + fs.GenerateThumbnail(ctx, &file) - if policy.CallbackURL == "" { - return nil - } + if session.Callback == "" { + return nil + } + + // 发送回调请求 + callbackBody := serializer.UploadCallback{ + SourceName: file.SourceName, + PicInfo: file.PicInfo, + Size: fileInfo.Size, + } - // 发送回调请求 - callbackBody := serializer.UploadCallback{ - Name: file.Name, - SourceName: file.SourceName, - PicInfo: file.PicInfo, - Size: fileInfo.Size, + return cluster.RemoteCallback(session.Callback, callbackBody) } - return request.RemoteCallback(policy.CallbackURL, callbackBody) } // GenericAfterUpload 文件上传完成后,包含数据库操作 diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index f8bec75..0416ef3 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -182,14 +182,15 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS } uploadSession := &serializer.UploadSession{ - Key: callbackKey, - UID: fs.User.ID, - Policy: *fs.Policy, - VirtualPath: file.VirtualPath, - Name: file.Name, - Size: fileSize, - SavePath: file.SavePath, - LastModified: file.LastModified, + Key: callbackKey, + UID: fs.User.ID, + Policy: *fs.Policy, + VirtualPath: file.VirtualPath, + Name: file.Name, + Size: fileSize, + SavePath: file.SavePath, + LastModified: file.LastModified, + CallbackSecret: util.RandStringRunes(32), } // 获取上传凭证 diff --git a/pkg/request/request.go b/pkg/request/request.go index 9d4249b..7411ab8 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -51,7 +51,7 @@ func NewClient(opts ...Option) Client { } // Request 发送HTTP请求 -func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { +func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 c.mu.Lock() options := *c.options diff --git a/pkg/request/slave.go b/pkg/request/slave.go deleted file mode 100644 index 2948250..0000000 --- a/pkg/request/slave.go +++ /dev/null @@ -1,52 +0,0 @@ -package request - -import ( - "bytes" - "encoding/json" - "errors" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// TODO: move to slave pkg -// RemoteCallback 发送远程存储策略上传回调请求 -func RemoteCallback(url string, body serializer.UploadCallback) error { - callbackBody, err := json.Marshal(struct { - Data serializer.UploadCallback `json:"data"` - }{ - Data: body, - }) - if err != nil { - return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) - } - - resp := GeneralClient.Request( - "POST", - url, - bytes.NewReader(callbackBody), - WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), - WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), - ) - - if resp.Err != nil { - return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err) - } - - // 解析回调服务端响应 - resp = resp.CheckHTTPResponse(200) - if resp.Err != nil { - return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", resp.Err) - } - response, err := resp.DecodeResponse() - if err != nil { - return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err) - } - if response.Code != 0 { - return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) - } - - return nil -} diff --git a/pkg/request/slave_test.go b/pkg/request/slave_test.go deleted file mode 100644 index a7f9fcf..0000000 --- a/pkg/request/slave_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package request - -import ( - "bytes" - "encoding/json" - "errors" - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestRemoteCallback(t *testing.T) { - asserts := assert.New(t) - - // 回调成功 - { - clientMock := ClientMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 0}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.NoError(resp) - clientMock.AssertExpectations(t) - } - - // 服务端返回业务错误 - { - clientMock := ClientMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 401}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.EqualValues(401, resp.(serializer.AppError).Code) - clientMock.AssertExpectations(t) - } - - // 无法解析回调响应 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // HTTP状态码非200 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // 无法发起回调 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&Response{ - Err: errors.New("error"), - }) - GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ - SourceName: "source", - }) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } -} diff --git a/pkg/serializer/upload.go b/pkg/serializer/upload.go index b1f78a1..12dce42 100644 --- a/pkg/serializer/upload.go +++ b/pkg/serializer/upload.go @@ -37,19 +37,20 @@ type UploadCredential struct { // UploadSession 上传会话 type UploadSession struct { - Key string // 上传会话 GUID - UID uint // 发起者 - VirtualPath string // 用户文件路径,不含文件名 - Name string // 文件名 - Size uint64 // 文件大小 - SavePath string // 物理存储路径,包含物理文件名 - LastModified *time.Time // 可选的文件最后修改日期 - Policy model.Policy + Key string // 上传会话 GUID + UID uint // 发起者 + VirtualPath string // 用户文件路径,不含文件名 + Name string // 文件名 + Size uint64 // 文件大小 + SavePath string // 物理存储路径,包含物理文件名 + LastModified *time.Time // 可选的文件最后修改日期 + Policy model.Policy + Callback string // 回调 URL 地址 + CallbackSecret string // 回调 URL } // UploadCallback 上传回调正文 type UploadCallback struct { - Name string `json:"name"` SourceName string `json:"source_name"` PicInfo string `json:"pic_info"` Size uint64 `json:"size"` diff --git a/service/callback/upload.go b/service/callback/upload.go index 54bdf74..59f7a4f 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -70,7 +70,6 @@ type S3Callback struct { // GetBody 返回回调正文 func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { res := serializer.UploadCallback{ - Name: session.Name, SourceName: service.SourceName, Size: service.Size, } @@ -84,7 +83,6 @@ func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) s // GetBody 返回回调正文 func (service UploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { return serializer.UploadCallback{ - Name: service.Name, SourceName: service.SourceName, PicInfo: service.PicInfo, Size: service.Size, @@ -98,7 +96,6 @@ func (service OneDriveCallback) GetBody(session *serializer.UploadSession) seria picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height) } return serializer.UploadCallback{ - Name: session.Name, SourceName: session.SavePath, PicInfo: picInfo, Size: session.Size, @@ -108,7 +105,6 @@ func (service OneDriveCallback) GetBody(session *serializer.UploadSession) seria // GetBody 返回回调正文 func (service COSCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { return serializer.UploadCallback{ - Name: session.Name, SourceName: session.SavePath, PicInfo: "", Size: session.Size, @@ -118,7 +114,6 @@ func (service COSCallback) GetBody(session *serializer.UploadSession) serializer // GetBody 返回回调正文 func (service S3Callback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { return serializer.UploadCallback{ - Name: session.Name, SourceName: session.SavePath, PicInfo: "", Size: session.Size, diff --git a/service/explorer/upload.go b/service/explorer/upload.go index 6847e03..754d897 100644 --- a/service/explorer/upload.go +++ b/service/explorer/upload.go @@ -120,7 +120,7 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart) } - return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append|fsctx.Overwrite) + return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append) } // SlaveUpload 处理从机文件分片上传 @@ -165,6 +165,11 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File ) } + // 非首个分片时需要允许覆盖 + if index > 0 { + mode |= fsctx.Overwrite + } + fileData := fsctx.FileStream{ MIMEType: c.Request.Header.Get("Content-Type"), File: c.Request.Body, @@ -193,7 +198,7 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File } } else { if isLastChunk { - fs.Use("AfterUpload", filesystem.SlaveAfterUpload) + fs.Use("AfterUpload", filesystem.SlaveAfterUpload(session)) } }