Enable overwrite for non-first chunk uploading request

pull/1107/head
HFO4 3 years ago
parent 050a68a359
commit 4925a356e3

@ -26,7 +26,7 @@ jobs:
- name: Get dependencies and build
run: |
go get github.com/rakyll/statik
go install github.com/rakyll/statik
export PATH=$PATH:~/go/bin/
statik -src=models -f
sudo apt-get update

@ -1,11 +1,14 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
@ -408,3 +411,43 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
return strings.NewReader(string(reqBodyEncoded)), nil
}
// TODO: move to slave pkg
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
}
resp := request.GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "从机无法发起回调请求", resp.Err)
}
// 解析回调服务端响应
resp = resp.CheckHTTPResponse(200)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "主机服务器返回异常响应", resp.Err)
}
response, err := resp.DecodeResponse()
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "从机无法解析主机返回的响应", err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

@ -441,3 +441,125 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a.NoError(err)
}
}
//func TestRemoteCallback(t *testing.T) {
// asserts := assert.New(t)
//
// // 回调成功
// {
// clientMock := request.ClientMock{}
// mockResp, _ := json.Marshal(serializer.Response{Code: 0})
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.NoError(resp)
// clientMock.AssertExpectations(t)
// }
//
// // 服务端返回业务错误
// {
// clientMock := request.ClientMock{}
// mockResp, _ := json.Marshal(serializer.Response{Code: 401})
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.EqualValues(401, resp.(serializer.AppError).Code)
// clientMock.AssertExpectations(t)
// }
//
// // 无法解析回调响应
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(strings.NewReader("mockResp")),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//
// // HTTP状态码非200
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 404,
// Body: ioutil.NopCloser(strings.NewReader("mockResp")),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//
// // 无法发起回调
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: errors.New("error"),
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//}

@ -20,7 +20,9 @@ const (
// Client to operate remote slave server
type Client interface {
// CreateUploadSession creates remote upload session
CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error
// GetUploadURL signs an url for uploading file
GetUploadURL(ttl int64, sessionID string) (string, string, error)
}

@ -323,7 +323,12 @@ func (handler Driver) Source(
// Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote" + uploadSession.Key + uploadSession.CallbackSecret))
apiURL := siteURL.ResolveReference(apiBaseURI)
// 在从机端创建上传会话
uploadSession.Callback = apiURL.String()
if err := handler.client.CreateUploadSession(ctx, uploadSession, ttl); err != nil {
return nil, err
}
@ -331,7 +336,7 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria
// 获取上传地址
uploadURL, sign, err := handler.client.GetUploadURL(ttl, uploadSession.Key)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to sign upload url: %w", err)
}
return &serializer.UploadCredential{

@ -2,13 +2,12 @@ package filesystem
import (
"context"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io/ioutil"
@ -178,9 +177,8 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileH
}
// SlaveAfterUpload Slave模式下上传完成钩子
func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
return errors.New("")
policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy)
func SlaveAfterUpload(session *serializer.UploadSession) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 构造一个model.File用于生成缩略图
@ -190,18 +188,19 @@ func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.File
}
fs.GenerateThumbnail(ctx, &file)
if policy.CallbackURL == "" {
if session.Callback == "" {
return nil
}
// 发送回调请求
callbackBody := serializer.UploadCallback{
Name: file.Name,
SourceName: file.SourceName,
PicInfo: file.PicInfo,
Size: fileInfo.Size,
}
return request.RemoteCallback(policy.CallbackURL, callbackBody)
return cluster.RemoteCallback(session.Callback, callbackBody)
}
}
// GenericAfterUpload 文件上传完成后,包含数据库操作

@ -190,6 +190,7 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
Size: fileSize,
SavePath: file.SavePath,
LastModified: file.LastModified,
CallbackSecret: util.RandStringRunes(32),
}
// 获取上传凭证

@ -51,7 +51,7 @@ func NewClient(opts ...Option) Client {
}
// Request 发送HTTP请求
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置
c.mu.Lock()
options := *c.options

@ -1,52 +0,0 @@
package request
import (
"bytes"
"encoding/json"
"errors"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// TODO: move to slave pkg
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
}
resp := GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err)
}
// 解析回调服务端响应
resp = resp.CheckHTTPResponse(200)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", resp.Err)
}
response, err := resp.DecodeResponse()
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

@ -1,137 +0,0 @@
package request
import (
"bytes"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestRemoteCallback(t *testing.T) {
asserts := assert.New(t)
// 回调成功
{
clientMock := ClientMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.NoError(resp)
clientMock.AssertExpectations(t)
}
// 服务端返回业务错误
{
clientMock := ClientMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.EqualValues(401, resp.(serializer.AppError).Code)
clientMock.AssertExpectations(t)
}
// 无法解析回调响应
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// HTTP状态码非200
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 404,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// 无法发起回调
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: errors.New("error"),
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
}

@ -45,11 +45,12 @@ type UploadSession struct {
SavePath string // 物理存储路径,包含物理文件名
LastModified *time.Time // 可选的文件最后修改日期
Policy model.Policy
Callback string // 回调 URL 地址
CallbackSecret string // 回调 URL
}
// UploadCallback 上传回调正文
type UploadCallback struct {
Name string `json:"name"`
SourceName string `json:"source_name"`
PicInfo string `json:"pic_info"`
Size uint64 `json:"size"`

@ -70,7 +70,6 @@ type S3Callback struct {
// GetBody 返回回调正文
func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
res := serializer.UploadCallback{
Name: session.Name,
SourceName: service.SourceName,
Size: service.Size,
}
@ -84,7 +83,6 @@ func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) s
// GetBody 返回回调正文
func (service UploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
return serializer.UploadCallback{
Name: service.Name,
SourceName: service.SourceName,
PicInfo: service.PicInfo,
Size: service.Size,
@ -98,7 +96,6 @@ func (service OneDriveCallback) GetBody(session *serializer.UploadSession) seria
picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height)
}
return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath,
PicInfo: picInfo,
Size: session.Size,
@ -108,7 +105,6 @@ func (service OneDriveCallback) GetBody(session *serializer.UploadSession) seria
// GetBody 返回回调正文
func (service COSCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath,
PicInfo: "",
Size: session.Size,
@ -118,7 +114,6 @@ func (service COSCallback) GetBody(session *serializer.UploadSession) serializer
// GetBody 返回回调正文
func (service S3Callback) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath,
PicInfo: "",
Size: session.Size,

@ -120,7 +120,7 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s
util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart)
}
return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append|fsctx.Overwrite)
return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append)
}
// SlaveUpload 处理从机文件分片上传
@ -165,6 +165,11 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File
)
}
// 非首个分片时需要允许覆盖
if index > 0 {
mode |= fsctx.Overwrite
}
fileData := fsctx.FileStream{
MIMEType: c.Request.Header.Get("Content-Type"),
File: c.Request.Body,
@ -193,7 +198,7 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File
}
} else {
if isLastChunk {
fs.Use("AfterUpload", filesystem.SlaveAfterUpload)
fs.Use("AfterUpload", filesystem.SlaveAfterUpload(session))
}
}

Loading…
Cancel
Save