Enable overwrite for non-first chunk uploading request

pull/1107/head
HFO4 3 years ago
parent 050a68a359
commit 4925a356e3

@ -26,7 +26,7 @@ jobs:
- name: Get dependencies and build - name: Get dependencies and build
run: | run: |
go get github.com/rakyll/statik go install github.com/rakyll/statik
export PATH=$PATH:~/go/bin/ export PATH=$PATH:~/go/bin/
statik -src=models -f statik -src=models -f
sudo apt-get update sudo apt-get update

@ -1,11 +1,14 @@
package cluster package cluster
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
@ -408,3 +411,43 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
return strings.NewReader(string(reqBodyEncoded)), nil return strings.NewReader(string(reqBodyEncoded)), nil
} }
// TODO: move to slave pkg
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
}
resp := request.GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "从机无法发起回调请求", resp.Err)
}
// 解析回调服务端响应
resp = resp.CheckHTTPResponse(200)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "主机服务器返回异常响应", resp.Err)
}
response, err := resp.DecodeResponse()
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "从机无法解析主机返回的响应", err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

@ -441,3 +441,125 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a.NoError(err) a.NoError(err)
} }
} }
//func TestRemoteCallback(t *testing.T) {
// asserts := assert.New(t)
//
// // 回调成功
// {
// clientMock := request.ClientMock{}
// mockResp, _ := json.Marshal(serializer.Response{Code: 0})
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.NoError(resp)
// clientMock.AssertExpectations(t)
// }
//
// // 服务端返回业务错误
// {
// clientMock := request.ClientMock{}
// mockResp, _ := json.Marshal(serializer.Response{Code: 401})
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.EqualValues(401, resp.(serializer.AppError).Code)
// clientMock.AssertExpectations(t)
// }
//
// // 无法解析回调响应
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(strings.NewReader("mockResp")),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//
// // HTTP状态码非200
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 404,
// Body: ioutil.NopCloser(strings.NewReader("mockResp")),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//
// // 无法发起回调
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: errors.New("error"),
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//}

@ -20,7 +20,9 @@ const (
// Client to operate remote slave server // Client to operate remote slave server
type Client interface { type Client interface {
// CreateUploadSession creates remote upload session
CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error
// GetUploadURL signs an url for uploading file
GetUploadURL(ttl int64, sessionID string) (string, string, error) GetUploadURL(ttl int64, sessionID string) (string, string, error)
} }

@ -323,7 +323,12 @@ func (handler Driver) Source(
// Token 获取上传策略和认证Token // Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote" + uploadSession.Key + uploadSession.CallbackSecret))
apiURL := siteURL.ResolveReference(apiBaseURI)
// 在从机端创建上传会话 // 在从机端创建上传会话
uploadSession.Callback = apiURL.String()
if err := handler.client.CreateUploadSession(ctx, uploadSession, ttl); err != nil { if err := handler.client.CreateUploadSession(ctx, uploadSession, ttl); err != nil {
return nil, err return nil, err
} }
@ -331,7 +336,7 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria
// 获取上传地址 // 获取上传地址
uploadURL, sign, err := handler.client.GetUploadURL(ttl, uploadSession.Key) uploadURL, sign, err := handler.client.GetUploadURL(ttl, uploadSession.Key)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to sign upload url: %w", err)
} }
return &serializer.UploadCredential{ return &serializer.UploadCredential{

@ -2,13 +2,12 @@ package filesystem
import ( import (
"context" "context"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"io/ioutil" "io/ioutil"
@ -178,30 +177,30 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileH
} }
// SlaveAfterUpload Slave模式下上传完成钩子 // SlaveAfterUpload Slave模式下上传完成钩子
func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { func SlaveAfterUpload(session *serializer.UploadSession) Hook {
return errors.New("") return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) fileInfo := fileHeader.Info()
fileInfo := fileHeader.Info()
// 构造一个model.File用于生成缩略图 // 构造一个model.File用于生成缩略图
file := model.File{ file := model.File{
Name: fileInfo.FileName, Name: fileInfo.FileName,
SourceName: fileInfo.SavePath, SourceName: fileInfo.SavePath,
} }
fs.GenerateThumbnail(ctx, &file) fs.GenerateThumbnail(ctx, &file)
if policy.CallbackURL == "" { if session.Callback == "" {
return nil return nil
} }
// 发送回调请求
callbackBody := serializer.UploadCallback{
SourceName: file.SourceName,
PicInfo: file.PicInfo,
Size: fileInfo.Size,
}
// 发送回调请求 return cluster.RemoteCallback(session.Callback, callbackBody)
callbackBody := serializer.UploadCallback{
Name: file.Name,
SourceName: file.SourceName,
PicInfo: file.PicInfo,
Size: fileInfo.Size,
} }
return request.RemoteCallback(policy.CallbackURL, callbackBody)
} }
// GenericAfterUpload 文件上传完成后,包含数据库操作 // GenericAfterUpload 文件上传完成后,包含数据库操作

@ -182,14 +182,15 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
} }
uploadSession := &serializer.UploadSession{ uploadSession := &serializer.UploadSession{
Key: callbackKey, Key: callbackKey,
UID: fs.User.ID, UID: fs.User.ID,
Policy: *fs.Policy, Policy: *fs.Policy,
VirtualPath: file.VirtualPath, VirtualPath: file.VirtualPath,
Name: file.Name, Name: file.Name,
Size: fileSize, Size: fileSize,
SavePath: file.SavePath, SavePath: file.SavePath,
LastModified: file.LastModified, LastModified: file.LastModified,
CallbackSecret: util.RandStringRunes(32),
} }
// 获取上传凭证 // 获取上传凭证

@ -51,7 +51,7 @@ func NewClient(opts ...Option) Client {
} }
// Request 发送HTTP请求 // Request 发送HTTP请求
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置 // 应用额外设置
c.mu.Lock() c.mu.Lock()
options := *c.options options := *c.options

@ -1,52 +0,0 @@
package request
import (
"bytes"
"encoding/json"
"errors"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// TODO: move to slave pkg
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
}
resp := GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err)
}
// 解析回调服务端响应
resp = resp.CheckHTTPResponse(200)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", resp.Err)
}
response, err := resp.DecodeResponse()
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

@ -1,137 +0,0 @@
package request
import (
"bytes"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestRemoteCallback(t *testing.T) {
asserts := assert.New(t)
// 回调成功
{
clientMock := ClientMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.NoError(resp)
clientMock.AssertExpectations(t)
}
// 服务端返回业务错误
{
clientMock := ClientMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.EqualValues(401, resp.(serializer.AppError).Code)
clientMock.AssertExpectations(t)
}
// 无法解析回调响应
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// HTTP状态码非200
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 404,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// 无法发起回调
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: errors.New("error"),
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
}

@ -37,19 +37,20 @@ type UploadCredential struct {
// UploadSession 上传会话 // UploadSession 上传会话
type UploadSession struct { type UploadSession struct {
Key string // 上传会话 GUID Key string // 上传会话 GUID
UID uint // 发起者 UID uint // 发起者
VirtualPath string // 用户文件路径,不含文件名 VirtualPath string // 用户文件路径,不含文件名
Name string // 文件名 Name string // 文件名
Size uint64 // 文件大小 Size uint64 // 文件大小
SavePath string // 物理存储路径,包含物理文件名 SavePath string // 物理存储路径,包含物理文件名
LastModified *time.Time // 可选的文件最后修改日期 LastModified *time.Time // 可选的文件最后修改日期
Policy model.Policy Policy model.Policy
Callback string // 回调 URL 地址
CallbackSecret string // 回调 URL
} }
// UploadCallback 上传回调正文 // UploadCallback 上传回调正文
type UploadCallback struct { type UploadCallback struct {
Name string `json:"name"`
SourceName string `json:"source_name"` SourceName string `json:"source_name"`
PicInfo string `json:"pic_info"` PicInfo string `json:"pic_info"`
Size uint64 `json:"size"` Size uint64 `json:"size"`

@ -70,7 +70,6 @@ type S3Callback struct {
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
res := serializer.UploadCallback{ res := serializer.UploadCallback{
Name: session.Name,
SourceName: service.SourceName, SourceName: service.SourceName,
Size: service.Size, Size: service.Size,
} }
@ -84,7 +83,6 @@ func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) s
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service UploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service UploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: service.Name,
SourceName: service.SourceName, SourceName: service.SourceName,
PicInfo: service.PicInfo, PicInfo: service.PicInfo,
Size: service.Size, Size: service.Size,
@ -98,7 +96,6 @@ func (service OneDriveCallback) GetBody(session *serializer.UploadSession) seria
picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height) picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height)
} }
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath, SourceName: session.SavePath,
PicInfo: picInfo, PicInfo: picInfo,
Size: session.Size, Size: session.Size,
@ -108,7 +105,6 @@ func (service OneDriveCallback) GetBody(session *serializer.UploadSession) seria
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service COSCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service COSCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath, SourceName: session.SavePath,
PicInfo: "", PicInfo: "",
Size: session.Size, Size: session.Size,
@ -118,7 +114,6 @@ func (service COSCallback) GetBody(session *serializer.UploadSession) serializer
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service S3Callback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service S3Callback) GetBody(session *serializer.UploadSession) serializer.UploadCallback {
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath, SourceName: session.SavePath,
PicInfo: "", PicInfo: "",
Size: session.Size, Size: session.Size,

@ -120,7 +120,7 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s
util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart) util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart)
} }
return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append|fsctx.Overwrite) return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append)
} }
// SlaveUpload 处理从机文件分片上传 // SlaveUpload 处理从机文件分片上传
@ -165,6 +165,11 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File
) )
} }
// 非首个分片时需要允许覆盖
if index > 0 {
mode |= fsctx.Overwrite
}
fileData := fsctx.FileStream{ fileData := fsctx.FileStream{
MIMEType: c.Request.Header.Get("Content-Type"), MIMEType: c.Request.Header.Get("Content-Type"),
File: c.Request.Body, File: c.Request.Body,
@ -193,7 +198,7 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File
} }
} else { } else {
if isLastChunk { if isLastChunk {
fs.Use("AfterUpload", filesystem.SlaveAfterUpload) fs.Use("AfterUpload", filesystem.SlaveAfterUpload(session))
} }
} }

Loading…
Cancel
Save