Merge remote-tracking branch 'upstream/master'

pull/1203/head
vvisionnn 4 years ago committed by GitHub
commit 3a8071f3ba

@ -10,10 +10,10 @@ jobs:
runs-on: ubuntu-18.04 runs-on: ubuntu-18.04
steps: steps:
- name: Set up Go 1.13 - name: Set up Golang
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.17
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
@ -26,7 +26,7 @@ jobs:
- name: Get dependencies and build - name: Get dependencies and build
run: | run: |
go get github.com/rakyll/statik go install github.com/rakyll/statik
export PATH=$PATH:~/go/bin/ export PATH=$PATH:~/go/bin/
statik -src=models -f statik -src=models -f
sudo apt-get update sudo apt-get update

@ -14,10 +14,10 @@ jobs:
runs-on: ubuntu-18.04 runs-on: ubuntu-18.04
steps: steps:
- name: Set up Go 1.13 - name: Set up Golang
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.17
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory

@ -1,6 +1,6 @@
language: go language: go
go: go:
- 1.13.x - 1.17.x
node_js: "12.16.3" node_js: "12.16.3"
git: git:
depth: 1 depth: 1

@ -1,4 +1,4 @@
FROM golang:alpine as cloudreve_builder FROM golang:1.17.7-alpine as cloudreve_builder
# install dependencies and build tools # install dependencies and build tools

@ -1 +1 @@
Subproject commit eb3f32922ab9cd2f9fbef4860b93fec759a7054d Subproject commit e0da8f48856e3fb6e3e9cc920a32390ca132935e

@ -1,24 +1,19 @@
package middleware package middleware
import ( import (
"bytes" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"context"
"crypto/md5"
"fmt"
"io/ioutil"
"net/http" "net/http"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/qiniu/api.v7/v7/auth/qbox" )
const (
CallbackFailedStatusCode = http.StatusUnauthorized
) )
// SignRequired 验证请求签名 // SignRequired 验证请求签名
@ -117,48 +112,60 @@ func WebDAVAuth() gin.HandlerFunc {
} }
} }
// 对上传会话进行验证
func UseUploadSession(policyType string) gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp := uploadCallbackCheck(c, policyType)
if resp.Code != 0 {
c.JSON(CallbackFailedStatusCode, resp)
c.Abort()
return
}
c.Next()
}
}
// uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户 // uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户
func uploadCallbackCheck(c *gin.Context) (serializer.Response, *model.User) { func uploadCallbackCheck(c *gin.Context, policyType string) serializer.Response {
// 验证 Callback Key // 验证 Callback Key
callbackKey := c.Param("key") sessionID := c.Param("sessionID")
if callbackKey == "" { if sessionID == "" {
return serializer.ParamErr("Callback Key 不能为空", nil), nil return serializer.ParamErr("Session ID 不能为空", nil)
} }
callbackSessionRaw, exist := cache.Get("callback_" + callbackKey)
callbackSessionRaw, exist := cache.Get(filesystem.UploadSessionCachePrefix + sessionID)
if !exist { if !exist {
return serializer.ParamErr("回调会话不存在或已过期", nil), nil return serializer.ParamErr("上传会话不存在或已过期", nil)
} }
callbackSession := callbackSessionRaw.(serializer.UploadSession) callbackSession := callbackSessionRaw.(serializer.UploadSession)
c.Set("callbackSession", &callbackSession) c.Set(filesystem.UploadSessionCtx, &callbackSession)
if callbackSession.Policy.Type != policyType {
return serializer.Err(serializer.CodePolicyNotAllowed, "Policy not supported", nil)
}
// 清理回调会话 // 清理回调会话
_ = cache.Deletes([]string{callbackKey}, "callback_") _ = cache.Deletes([]string{sessionID}, filesystem.UploadSessionCachePrefix)
// 查找用户 // 查找用户
user, err := model.GetActiveUserByID(callbackSession.UID) user, err := model.GetActiveUserByID(callbackSession.UID)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeCheckLogin, "找不到用户", err), nil return serializer.Err(serializer.CodeCheckLogin, "找不到用户", err)
} }
c.Set("user", &user) c.Set(filesystem.UserCtx, &user)
return serializer.Response{}
return serializer.Response{}, &user
} }
// RemoteCallbackAuth 远程回调签名验证 // RemoteCallbackAuth 远程回调签名验证
func RemoteCallbackAuth() gin.HandlerFunc { func RemoteCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户
resp, user := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(200, resp)
c.Abort()
return
}
// 验证签名 // 验证签名
authInstance := auth.HMACAuth{SecretKey: []byte(user.Policy.SecretKey)} session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.SecretKey)}
if err := auth.CheckRequest(authInstance, c.Request); err != nil { if err := auth.CheckRequest(authInstance, c.Request); err != nil {
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err)) c.JSON(CallbackFailedStatusCode, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
c.Abort() c.Abort()
return return
} }
@ -171,28 +178,28 @@ func RemoteCallbackAuth() gin.HandlerFunc {
// QiniuCallbackAuth 七牛回调签名验证 // QiniuCallbackAuth 七牛回调签名验证
func QiniuCallbackAuth() gin.HandlerFunc { func QiniuCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户 //// 验证key并查找用户
resp, user := uploadCallbackCheck(c) //resp, user := uploadCallbackCheck(c)
if resp.Code != 0 { //if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort() // c.Abort()
return // return
} //}
//
// 验证回调是否来自qiniu //// 验证回调是否来自qiniu
mac := qbox.NewMac(user.Policy.AccessKey, user.Policy.SecretKey) //mac := qbox.NewMac(user.Policy.AccessKey, user.Policy.SecretKey)
ok, err := mac.VerifyCallback(c.Request) //ok, err := mac.VerifyCallback(c.Request)
if err != nil { //if err != nil {
util.Log().Debug("无法验证回调请求,%s", err) // util.Log().Debug("无法验证回调请求,%s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "无法验证回调请求"}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "无法验证回调请求"})
c.Abort() // c.Abort()
return // return
} //}
if !ok { //if !ok {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名无效"}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名无效"})
c.Abort() // c.Abort()
return // return
} //}
c.Next() c.Next()
} }
@ -201,21 +208,21 @@ func QiniuCallbackAuth() gin.HandlerFunc {
// OSSCallbackAuth 阿里云OSS回调签名验证 // OSSCallbackAuth 阿里云OSS回调签名验证
func OSSCallbackAuth() gin.HandlerFunc { func OSSCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户 //// 验证key并查找用户
resp, _ := uploadCallbackCheck(c) //resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 { //if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort() // c.Abort()
return // return
} //}
//
err := oss.VerifyCallbackSignature(c.Request) //err := oss.VerifyCallbackSignature(c.Request)
if err != nil { //if err != nil {
util.Log().Debug("回调签名验证失败,%s", err) // util.Log().Debug("回调签名验证失败,%s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名验证失败"}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名验证失败"})
c.Abort() // c.Abort()
return // return
} //}
c.Next() c.Next()
} }
@ -224,53 +231,53 @@ func OSSCallbackAuth() gin.HandlerFunc {
// UpyunCallbackAuth 又拍云回调签名验证 // UpyunCallbackAuth 又拍云回调签名验证
func UpyunCallbackAuth() gin.HandlerFunc { func UpyunCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户 //// 验证key并查找用户
resp, user := uploadCallbackCheck(c) //resp, user := uploadCallbackCheck(c)
if resp.Code != 0 { //if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort() // c.Abort()
return // return
} //}
//
// 获取请求正文 //// 获取请求正文
body, err := ioutil.ReadAll(c.Request.Body) //body, err := ioutil.ReadAll(c.Request.Body)
c.Request.Body.Close() //c.Request.Body.Close()
if err != nil { //if err != nil {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()})
c.Abort() // c.Abort()
return // return
} //}
//
c.Request.Body = ioutil.NopCloser(bytes.NewReader(body)) //c.Request.Body = ioutil.NopCloser(bytes.NewReader(body))
//
// 准备验证Upyun回调签名 //// 准备验证Upyun回调签名
handler := upyun.Driver{Policy: &user.Policy} //handler := upyun.Driver{Policy: &user.Policy}
contentMD5 := c.Request.Header.Get("Content-Md5") //contentMD5 := c.Request.Header.Get("Content-Md5")
date := c.Request.Header.Get("Date") //date := c.Request.Header.Get("Date")
actualSignature := c.Request.Header.Get("Authorization") //actualSignature := c.Request.Header.Get("Authorization")
//
// 计算正文MD5 //// 计算正文MD5
actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body)) //actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body))
if actualContentMD5 != contentMD5 { //if actualContentMD5 != contentMD5 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5不一致"}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5不一致"})
c.Abort() // c.Abort()
return // return
} //}
//
// 计算理论签名 //// 计算理论签名
signature := handler.Sign(context.Background(), []string{ //signature := handler.Sign(context.Background(), []string{
"POST", // "POST",
c.Request.URL.Path, // c.Request.URL.Path,
date, // date,
contentMD5, // contentMD5,
}) //})
//
// 对比签名 //// 对比签名
if signature != actualSignature { //if signature != actualSignature {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "鉴权失败"}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "鉴权失败"})
c.Abort() // c.Abort()
return // return
} //}
c.Next() c.Next()
} }
@ -280,16 +287,16 @@ func UpyunCallbackAuth() gin.HandlerFunc {
// TODO 解耦 // TODO 解耦
func OneDriveCallbackAuth() gin.HandlerFunc { func OneDriveCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户 //// 验证key并查找用户
resp, _ := uploadCallbackCheck(c) //resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 { //if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort() // c.Abort()
return // return
} //}
//
// 发送回调结束信号 //// 发送回调结束信号
onedrive.FinishCallback(c.Param("key")) //onedrive.FinishCallback(c.Param("key"))
c.Next() c.Next()
} }
@ -299,13 +306,13 @@ func OneDriveCallbackAuth() gin.HandlerFunc {
// TODO 解耦 测试 // TODO 解耦 测试
func COSCallbackAuth() gin.HandlerFunc { func COSCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户 //// 验证key并查找用户
resp, _ := uploadCallbackCheck(c) //resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 { //if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort() // c.Abort()
return // return
} //}
c.Next() c.Next()
} }
@ -314,13 +321,13 @@ func COSCallbackAuth() gin.HandlerFunc {
// S3CallbackAuth Amazon S3回调签名验证 // S3CallbackAuth Amazon S3回调签名验证
func S3CallbackAuth() gin.HandlerFunc { func S3CallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 验证key并查找用户 //// 验证key并查找用户
resp, _ := uploadCallbackCheck(c) //resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 { //if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg}) // c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort() // c.Abort()
return // return
} //}
c.Next() c.Next()
} }

@ -268,6 +268,7 @@ func (file *File) UpdatePicInfo(value string) error {
} }
// UpdateSize 更新文件的大小信息 // UpdateSize 更新文件的大小信息
// TODO: 全局锁
func (file *File) UpdateSize(value uint64) error { func (file *File) UpdateSize(value uint64) error {
tx := DB.Begin() tx := DB.Begin()
var sizeDelta uint64 var sizeDelta uint64
@ -281,7 +282,10 @@ func (file *File) UpdateSize(value uint64) error {
sizeDelta = file.Size - value sizeDelta = file.Size - value
} }
if res := tx.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value); res.Error != nil { if res := tx.Model(&file).
Where("size = ?", file.Size).
Set("gorm:association_autoupdate", false).
Update("size", value); res.Error != nil {
tx.Rollback() tx.Rollback()
return res.Error return res.Error
} }
@ -291,6 +295,7 @@ func (file *File) UpdateSize(value uint64) error {
return err return err
} }
file.Size = value
return tx.Commit().Error return tx.Commit().Error
} }
@ -299,7 +304,7 @@ func (file *File) UpdateSourceName(value string) error {
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error
} }
func (file *File) PopChunkToFile(lastModified *time.Time) error { func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error {
file.UploadSessionID = nil file.UploadSessionID = nil
if lastModified != nil { if lastModified != nil {
file.UpdatedAt = *lastModified file.UpdatedAt = *lastModified
@ -308,6 +313,7 @@ func (file *File) PopChunkToFile(lastModified *time.Time) error {
return DB.Model(file).UpdateColumns(map[string]interface{}{ return DB.Model(file).UpdateColumns(map[string]interface{}{
"upload_session_id": file.UploadSessionID, "upload_session_id": file.UploadSessionID,
"updated_at": file.UpdatedAt, "updated_at": file.UpdatedAt,
"pic_info": picInfo,
}).Error }).Error
} }

@ -125,6 +125,7 @@ func addDefaultSettings() {
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"}, {Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"}, {Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"}, {Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
{Name: "slave_chunk_retries", Value: `1`, Type: "retry"},
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"}, {Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"}, {Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
{Name: "login_captcha", Value: `0`, Type: "login"}, {Name: "login_captcha", Value: `0`, Type: "login"},

@ -64,7 +64,7 @@ func ListTasks(uid uint, page, pageSize int, order string) ([]Task, int) {
dbChain = dbChain.Where("user_id = ?", uid) dbChain = dbChain.Where("user_id = ?", uid)
// 计算总数用于分页 // 计算总数用于分页
dbChain.Model(&Share{}).Count(&total) dbChain.Model(&Task{}).Count(&total)
// 查询记录 // 查询记录
dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&tasks) dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&tasks)

@ -1,11 +1,15 @@
package cluster package cluster
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
@ -40,7 +44,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
var endpoint *url.URL var endpoint *url.URL
if serverURL, err := url.Parse(node.Model.Server); err == nil { if serverURL, err := url.Parse(node.Model.Server); err == nil {
var controller *url.URL var controller *url.URL
controller, _ = url.Parse("/api/v3/slave") controller, _ = url.Parse("/api/v3/slave/")
endpoint = serverURL.ResolveReference(controller) endpoint = serverURL.ResolveReference(controller)
} }
@ -408,3 +412,41 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
return strings.NewReader(string(reqBodyEncoded)), nil return strings.NewReader(string(reqBodyEncoded)), nil
} }
// TODO: move to slave pkg
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
}
resp := request.GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "从机无法发起回调请求", resp.Err)
}
// 解析回调服务端响应
response, err := resp.DecodeResponse()
if err != nil {
msg := fmt.Sprintf("从机无法解析主机返回的响应 (StatusCode=%d)", resp.Response.StatusCode)
return serializer.NewError(serializer.CodeCallbackError, msg, err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

@ -441,3 +441,125 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a.NoError(err) a.NoError(err)
} }
} }
//func TestRemoteCallback(t *testing.T) {
// asserts := assert.New(t)
//
// // 回调成功
// {
// clientMock := request.ClientMock{}
// mockResp, _ := json.Marshal(serializer.Response{Code: 0})
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.NoError(resp)
// clientMock.AssertExpectations(t)
// }
//
// // 服务端返回业务错误
// {
// clientMock := request.ClientMock{}
// mockResp, _ := json.Marshal(serializer.Response{Code: 401})
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.EqualValues(401, resp.(serializer.AppError).Code)
// clientMock.AssertExpectations(t)
// }
//
// // 无法解析回调响应
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 200,
// Body: ioutil.NopCloser(strings.NewReader("mockResp")),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//
// // HTTP状态码非200
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: nil,
// Response: &http.Response{
// StatusCode: 404,
// Body: ioutil.NopCloser(strings.NewReader("mockResp")),
// },
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//
// // 无法发起回调
// {
// clientMock := request.ClientMock{}
// clientMock.On(
// "Request",
// "POST",
// "http://test/test/url",
// testMock.Anything,
// testMock.Anything,
// ).Return(&request.Response{
// Err: errors.New("error"),
// })
// request.GeneralClient = clientMock
// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
// SourceName: "source",
// })
// asserts.Error(resp)
// clientMock.AssertExpectations(t)
// }
//}

@ -306,9 +306,9 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst string) error {
err = fs.UploadFromStream(ctx, &fsctx.FileStream{ err = fs.UploadFromStream(ctx, &fsctx.FileStream{
File: fileStream, File: fileStream,
Size: uint64(size), Size: uint64(size),
Name: path.Base(dst), Name: path.Base(savePath),
VirtualPath: path.Dir(dst), VirtualPath: path.Dir(savePath),
}) }, true)
fileStream.Close() fileStream.Close()
if err != nil { if err != nil {
util.Log().Debug("无法上传压缩包内的文件%s , %s , 跳过", rawPath, err) util.Log().Debug("无法上传压缩包内的文件%s , %s , 跳过", rawPath, err)

@ -0,0 +1,31 @@
package backoff
import "time"
// Backoff used for retry sleep backoff
type Backoff interface {
Next() bool
Reset()
}
// ConstantBackoff implements Backoff interface with constant sleep time
type ConstantBackoff struct {
Sleep time.Duration
Max int
tried int
}
func (c *ConstantBackoff) Next() bool {
c.tried++
if c.tried > c.Max {
return false
}
time.Sleep(c.Sleep)
return true
}
func (c *ConstantBackoff) Reset() {
c.tried = 0
}

@ -0,0 +1,91 @@
package chunk
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
)
// ChunkProcessFunc callback function for processing a chunk
type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error
// ChunkGroup manage groups of chunks
type ChunkGroup struct {
file fsctx.FileHeader
chunkSize uint64
backoff backoff.Backoff
fileInfo *fsctx.UploadTaskInfo
currentIndex int
chunkNum uint64
}
func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff) *ChunkGroup {
c := &ChunkGroup{
file: file,
chunkSize: chunkSize,
backoff: backoff,
fileInfo: file.Info(),
currentIndex: -1,
}
if c.chunkSize == 0 {
c.chunkSize = c.fileInfo.Size
}
c.chunkNum = c.fileInfo.Size / c.chunkSize
if c.fileInfo.Size%c.chunkSize != 0 || c.fileInfo.Size == 0 {
c.chunkNum++
}
return c
}
// Process a chunk with retry logic
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
err := processor(c, io.LimitReader(c.file, int64(c.chunkSize)))
if err != nil {
if err != context.Canceled && c.file.Seekable() && c.backoff.Next() {
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %w", seekErr, err)
}
util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
return c.Process(processor)
}
return err
}
return nil
}
// Start returns the byte index of current chunk
func (c *ChunkGroup) Start() int64 {
return int64(uint64(c.Index()) * c.chunkSize)
}
// Index returns current chunk index, starts from 0
func (c *ChunkGroup) Index() int {
return c.currentIndex
}
// Next switch to next chunk, returns whether all chunks are processed
func (c *ChunkGroup) Next() bool {
c.currentIndex++
c.backoff.Reset()
return c.currentIndex < int(c.chunkNum)
}
// Length returns the length of current chunk
func (c *ChunkGroup) Length() int64 {
contentLength := c.chunkSize
if c.Index() == int(c.chunkNum-1) {
contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1)
}
return int64(contentLength)
}

@ -267,6 +267,10 @@ func (handler Driver) Source(
// Token 获取上传策略和认证Token本地策略直接返回空值 // Token 获取上传策略和认证Token本地策略直接返回空值
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
if util.Exists(uploadSession.SavePath) {
return nil, errors.New("placeholder file already exist")
}
return &serializer.UploadCredential{ return &serializer.UploadCredential{
SessionID: uploadSession.Key, SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,

@ -3,25 +3,40 @@ package remote
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
"io"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strings" "strings"
"time"
) )
const ( const (
basePath = "/api/v3/slave" basePath = "/api/v3/slave/"
OverwriteHeader = auth.CrHeaderPrefix + "Overwrite" OverwriteHeader = auth.CrHeaderPrefix + "Overwrite"
chunkRetrySleep = time.Duration(5) * time.Second
) )
// Client to operate remote slave server // Client to operate uploading to remote slave server
type Client interface { type Client interface {
// CreateUploadSession creates remote upload session
CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error
// GetUploadURL signs an url for uploading file
GetUploadURL(ttl int64, sessionID string) (string, string, error) GetUploadURL(ttl int64, sessionID string) (string, string, error)
// Upload uploads file to remote server
Upload(ctx context.Context, file fsctx.FileHeader) error
// DeleteUploadSession deletes remote upload session
DeleteUploadSession(ctx context.Context, sessionID string) error
} }
// NewClient creates new Client from given policy // NewClient creates new Client from given policy
@ -42,6 +57,7 @@ func NewClient(policy *model.Policy) (Client, error) {
request.WithEndpoint(serverURL.ResolveReference(base).String()), request.WithEndpoint(serverURL.ResolveReference(base).String()),
request.WithCredential(authInstance, int64(signTTL)), request.WithCredential(authInstance, int64(signTTL)),
request.WithMasterMeta(), request.WithMasterMeta(),
request.WithSlaveMeta(policy.AccessKey),
), ),
}, nil }, nil
} }
@ -52,6 +68,68 @@ type remoteClient struct {
httpClient request.Client httpClient request.Client
} }
func (c *remoteClient) Upload(ctx context.Context, file fsctx.FileHeader) error {
ttl := model.GetIntSetting("upload_session_timeout", 86400)
fileInfo := file.Info()
session := &serializer.UploadSession{
Key: uuid.Must(uuid.NewV4()).String(),
VirtualPath: fileInfo.VirtualPath,
Name: fileInfo.FileName,
Size: fileInfo.Size,
SavePath: fileInfo.SavePath,
LastModified: fileInfo.LastModified,
Policy: *c.policy,
}
// Create upload session
if err := c.CreateUploadSession(ctx, session, int64(ttl)); err != nil {
return fmt.Errorf("failed to create upload session: %w", err)
}
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
// Initial chunk groups
chunks := chunk.NewChunkGroup(file, c.policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("onedrive_chunk_retries", 1),
Sleep: chunkRetrySleep,
})
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
return c.uploadChunk(ctx, session.Key, current.Index(), content, overwrite, current.Length())
}
// upload chunks
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
if err := c.DeleteUploadSession(ctx, session.Key); err != nil {
util.Log().Warning("failed to delete upload session: %s", err)
}
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
return nil
}
func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error {
resp, err := c.httpClient.Request(
"DELETE",
"upload/"+sessionID,
nil,
request.WithContext(ctx),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error { func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64) error {
reqBodyEncoded, err := json.Marshal(map[string]interface{}{ reqBodyEncoded, err := json.Marshal(map[string]interface{}{
"session": session, "session": session,
@ -94,3 +172,24 @@ func (c *remoteClient) GetUploadURL(ttl int64, sessionID string) (string, string
req = auth.SignRequest(c.authInstance, req, ttl) req = auth.SignRequest(c.authInstance, req, ttl)
return req.URL.String(), req.Header["Authorization"][0], nil return req.URL.String(), req.Header["Authorization"][0], nil
} }
func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error {
resp, err := c.httpClient.Request(
"POST",
fmt.Sprintf("upload/%s?chunk=%d", sessionID, index),
chunk,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
request.WithContentLength(size),
request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"path" "path"
"strings" "strings"
@ -26,10 +25,11 @@ type Driver struct {
Policy *model.Policy Policy *model.Policy
AuthInstance auth.Auth AuthInstance auth.Auth
client Client uploadClient Client
} }
// NewDriver initializes a new Driver from policy // NewDriver initializes a new Driver from policy
// TODO: refactor all method into upload client
func NewDriver(policy *model.Policy) (*Driver, error) { func NewDriver(policy *model.Policy) (*Driver, error) {
client, err := NewClient(policy) client, err := NewClient(policy)
if err != nil { if err != nil {
@ -40,12 +40,12 @@ func NewDriver(policy *model.Policy) (*Driver, error) {
Policy: policy, Policy: policy,
Client: request.NewClient(), Client: request.NewClient(),
AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)}, AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)},
client: client, uploadClient: client,
}, nil }, nil
} }
// List 列取文件 // List 列取文件
func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
var res []response.Object var res []response.Object
reqBody := serializer.ListRequest{ reqBody := serializer.ListRequest{
@ -87,7 +87,7 @@ func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]
} }
// getAPIUrl 获取接口请求地址 // getAPIUrl 获取接口请求地址
func (handler Driver) getAPIUrl(scope string, routes ...string) string { func (handler *Driver) getAPIUrl(scope string, routes ...string) string {
serverURL, err := url.Parse(handler.Policy.Server) serverURL, err := url.Parse(handler.Policy.Server)
if err != nil { if err != nil {
return "" return ""
@ -113,7 +113,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string {
} }
// Get 获取文件内容 // Get 获取文件内容
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 尝试获取速度限制 // 尝试获取速度限制
speedLimit := 0 speedLimit := 0
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
@ -150,63 +150,15 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
} }
// Put 将文件流保存到指定目录 // Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close() defer file.Close()
// 凭证有效期 return handler.uploadClient.Upload(ctx, file)
credentialTTL := model.GetIntSetting("upload_credential_timeout", 3600)
// 生成上传策略
fileInfo := file.Info()
policy := serializer.UploadPolicy{
SavePath: path.Dir(fileInfo.SavePath),
FileName: path.Base(fileInfo.FileName),
AutoRename: false,
MaxSize: fileInfo.Size,
}
credential, err := handler.getUploadCredential(ctx, policy, int64(credentialTTL))
if err != nil {
return err
}
// 对文件名进行URLEncode
fileName := url.QueryEscape(path.Base(fileInfo.SavePath))
// 决定是否要禁用文件覆盖
overwrite := "false"
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
overwrite = "true"
}
// 上传文件
resp, err := handler.Client.Request(
"POST",
handler.Policy.GetUploadURL(),
file,
request.WithHeader(map[string][]string{
"X-Cr-Policy": {credential.Policy},
"X-Cr-FileName": {fileName},
"X-Cr-Overwrite": {overwrite},
}),
request.WithContentLength(int64(fileInfo.Size)),
request.WithTimeout(time.Duration(0)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
request.WithCredential(handler.AuthInstance, int64(credentialTTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return errors.New(resp.Msg)
}
return nil
} }
// Delete 删除一个或多个文件, // Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误 // 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// 封装接口请求正文 // 封装接口请求正文
reqBody := serializer.RemoteDeleteRequest{ reqBody := serializer.RemoteDeleteRequest{
Files: files, Files: files,
@ -252,7 +204,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
} }
// Thumb 获取文件缩略图 // Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { func (handler *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) {
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path)) sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path))
thumbURL := handler.getAPIUrl("thumb") + "/" + sourcePath thumbURL := handler.getAPIUrl("thumb") + "/" + sourcePath
ttl := model.GetIntSetting("preview_timeout", 60) ttl := model.GetIntSetting("preview_timeout", 60)
@ -268,7 +220,7 @@ func (handler Driver) Thumb(ctx context.Context, path string) (*response.Content
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler *Driver) Source(
ctx context.Context, ctx context.Context,
path string, path string,
baseURL url.URL, baseURL url.URL,
@ -322,16 +274,21 @@ func (handler Driver) Source(
} }
// Token 获取上传策略和认证Token // Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote", uploadSession.Key, uploadSession.CallbackSecret))
apiURL := siteURL.ResolveReference(apiBaseURI)
// 在从机端创建上传会话 // 在从机端创建上传会话
if err := handler.client.CreateUploadSession(ctx, uploadSession, ttl); err != nil { uploadSession.Callback = apiURL.String()
if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, ttl); err != nil {
return nil, err return nil, err
} }
// 获取上传地址 // 获取上传地址
uploadURL, sign, err := handler.client.GetUploadURL(ttl, uploadSession.Key) uploadURL, sign, err := handler.uploadClient.GetUploadURL(ttl, uploadSession.Key)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to sign upload url: %w", err)
} }
return &serializer.UploadCredential{ return &serializer.UploadCredential{
@ -342,30 +299,7 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria
}, nil }, nil
} }
func (handler Driver) getUploadCredential(ctx context.Context, policy serializer.UploadPolicy, TTL int64) (serializer.UploadCredential, error) {
policyEncoded, err := policy.EncodeUploadPolicy()
if err != nil {
return serializer.UploadCredential{}, err
}
// 签名上传策略
uploadRequest, _ := http.NewRequest("POST", "/api/v3/slave/upload", nil)
uploadRequest.Header = map[string][]string{
"X-Cr-Policy": {policyEncoded},
"X-Cr-Overwrite": {"false"},
}
auth.SignRequest(handler.AuthInstance, uploadRequest, TTL)
if credential, ok := uploadRequest.Header["Authorization"]; ok && len(credential) == 1 {
return serializer.UploadCredential{
Token: credential[0],
Policy: policyEncoded,
}, nil
}
return serializer.UploadCredential{}, errors.New("无法签名上传策略")
}
// 取消上传凭证 // 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Key)
} }

@ -30,7 +30,7 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy)
var endpoint *url.URL var endpoint *url.URL
if serverURL, err := url.Parse(node.DBModel().Server); err == nil { if serverURL, err := url.Parse(node.DBModel().Server); err == nil {
var controller *url.URL var controller *url.URL
controller, _ = url.Parse("/api/v3/slave") controller, _ = url.Parse("/api/v3/slave/")
endpoint = serverURL.ResolveReference(controller) endpoint = serverURL.ResolveReference(controller)
} }
@ -52,14 +52,10 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy)
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close() defer file.Close()
src, ok := ctx.Value(fsctx.SlaveSrcPath).(string) fileInfo := file.Info()
if !ok {
return ErrSlaveSrcPathNotExist
}
req := serializer.SlaveTransferReq{ req := serializer.SlaveTransferReq{
Src: src, Src: fileInfo.Src,
Dst: file.Info().SavePath, Dst: fileInfo.SavePath,
Policy: d.policy, Policy: d.policy,
} }

@ -207,7 +207,7 @@ func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) {
} }
// 获取回调会话 // 获取回调会话
callbackSessionRaw, ok := c.Get("callbackSession") callbackSessionRaw, ok := c.Get(UploadSessionCtx)
if !ok { if !ok {
return nil, errors.New("找不到回调会话") return nil, errors.New("找不到回调会话")
} }

@ -26,15 +26,18 @@ type UploadTaskInfo struct {
UploadSessionID *string UploadSessionID *string
AppendStart uint64 AppendStart uint64
Model interface{} Model interface{}
Src string
} }
// FileHeader 上传来的文件数据处理器 // FileHeader 上传来的文件数据处理器
type FileHeader interface { type FileHeader interface {
io.Reader io.Reader
io.Closer io.Closer
io.Seeker
Info() *UploadTaskInfo Info() *UploadTaskInfo
SetSize(uint64) SetSize(uint64)
SetModel(fileModel interface{}) SetModel(fileModel interface{})
Seekable() bool
} }
// FileStream 用户传来的文件 // FileStream 用户传来的文件
@ -43,6 +46,7 @@ type FileStream struct {
LastModified *time.Time LastModified *time.Time
Metadata map[string]string Metadata map[string]string
File io.ReadCloser File io.ReadCloser
Seeker io.Seeker
Size uint64 Size uint64
VirtualPath string VirtualPath string
Name string Name string
@ -51,14 +55,31 @@ type FileStream struct {
UploadSessionID *string UploadSessionID *string
AppendStart uint64 AppendStart uint64
Model interface{} Model interface{}
Src string
} }
func (file *FileStream) Read(p []byte) (n int, err error) { func (file *FileStream) Read(p []byte) (n int, err error) {
if file.File != nil {
return file.File.Read(p) return file.File.Read(p)
}
return 0, io.EOF
} }
func (file *FileStream) Close() error { func (file *FileStream) Close() error {
if file.File != nil {
return file.File.Close() return file.File.Close()
}
return nil
}
func (file *FileStream) Seek(offset int64, whence int) (int64, error) {
return file.Seeker.Seek(offset, whence)
}
func (file *FileStream) Seekable() bool {
return file.Seeker != nil
} }
func (file *FileStream) Info() *UploadTaskInfo { func (file *FileStream) Info() *UploadTaskInfo {
@ -74,6 +95,7 @@ func (file *FileStream) Info() *UploadTaskInfo {
UploadSessionID: file.UploadSessionID, UploadSessionID: file.UploadSessionID,
AppendStart: file.AppendStart, AppendStart: file.AppendStart,
Model: file.Model, Model: file.Model,
Src: file.Src,
} }
} }

@ -2,13 +2,12 @@ package filesystem
import ( import (
"context" "context"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"io/ioutil" "io/ioutil"
@ -178,9 +177,8 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileH
} }
// SlaveAfterUpload Slave模式下上传完成钩子 // SlaveAfterUpload Slave模式下上传完成钩子
func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { func SlaveAfterUpload(session *serializer.UploadSession) Hook {
return errors.New("") return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy)
fileInfo := fileHeader.Info() fileInfo := fileHeader.Info()
// 构造一个model.File用于生成缩略图 // 构造一个model.File用于生成缩略图
@ -190,18 +188,17 @@ func SlaveAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.File
} }
fs.GenerateThumbnail(ctx, &file) fs.GenerateThumbnail(ctx, &file)
if policy.CallbackURL == "" { if session.Callback == "" {
return nil return nil
} }
// 发送回调请求 // 发送回调请求
callbackBody := serializer.UploadCallback{ callbackBody := serializer.UploadCallback{
Name: file.Name,
SourceName: file.SourceName,
PicInfo: file.PicInfo, PicInfo: file.PicInfo,
Size: fileInfo.Size,
} }
return request.RemoteCallback(policy.CallbackURL, callbackBody)
return cluster.RemoteCallback(session.Callback, callbackBody)
}
} }
// GenericAfterUpload 文件上传完成后,包含数据库操作 // GenericAfterUpload 文件上传完成后,包含数据库操作
@ -288,12 +285,13 @@ func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart) return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart)
} }
// HookChunkUploadFinished 分片上传结束后处理文件 // HookPopPlaceholderToFile 将占位文件提升为正式文件
func HookChunkUploadFinished(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { func HookPopPlaceholderToFile(picInfo string) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info() fileInfo := fileHeader.Info()
fileModel := fileInfo.Model.(*model.File) fileModel := fileInfo.Model.(*model.File)
return fileModel.PopChunkToFile(fileInfo.LastModified, picInfo)
return fileModel.PopChunkToFile(fileInfo.LastModified) }
} }
// HookChunkUploadFinished 分片上传结束后处理文件 // HookChunkUploadFinished 分片上传结束后处理文件

@ -23,6 +23,8 @@ import (
const ( const (
UploadSessionMetaKey = "upload_session" UploadSessionMetaKey = "upload_session"
UploadSessionCtx = "uploadSession"
UserCtx = "user"
UploadSessionCachePrefix = "callback_" UploadSessionCachePrefix = "callback_"
) )
@ -47,11 +49,11 @@ func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err e
file.SavePath = savePath file.SavePath = savePath
} }
// 保存文件
if file.Mode&fsctx.Nop != fsctx.Nop {
// 处理客户端未完成上传时,关闭连接 // 处理客户端未完成上传时,关闭连接
go fs.CancelUpload(ctx, savePath, file) go fs.CancelUpload(ctx, savePath, file)
// 保存文件
if file.Mode&fsctx.Nop != fsctx.Nop {
err = fs.Handler.Put(ctx, file) err = fs.Handler.Put(ctx, file)
if err != nil { if err != nil {
fs.Trigger(ctx, "AfterUploadFailed", file) fs.Trigger(ctx, "AfterUploadFailed", file)
@ -176,7 +178,7 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
fs.Use("AfterUpload", HookClearFileHeaderSize) fs.Use("AfterUpload", HookClearFileHeaderSize)
} }
fs.Use("AfterUpload", GenericAfterUpload) // 验证文件规格
if err := fs.Upload(ctx, file); err != nil { if err := fs.Upload(ctx, file); err != nil {
return nil, err return nil, err
} }
@ -190,6 +192,7 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
Size: fileSize, Size: fileSize,
SavePath: file.SavePath, SavePath: file.SavePath,
LastModified: file.LastModified, LastModified: file.LastModified,
CallbackSecret: util.RandStringRunes(32),
} }
// 获取上传凭证 // 获取上传凭证
@ -198,10 +201,16 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
return nil, err return nil, err
} }
// 创建占位符
fs.Use("AfterUpload", GenericAfterUpload)
if err := fs.Upload(ctx, file); err != nil {
return nil, err
}
// 创建回调会话 // 创建回调会话
err = cache.Set( err = cache.Set(
UploadSessionCachePrefix+callbackKey, UploadSessionCachePrefix+callbackKey,
uploadSession, *uploadSession,
callBackSessionTTL, callBackSessionTTL,
) )
if err != nil { if err != nil {
@ -215,7 +224,16 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
} }
// UploadFromStream 从文件流上传文件 // UploadFromStream 从文件流上传文件
func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream) error { func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error {
if resetPolicy {
// 重设存储策略
fs.Policy = &fs.User.Policy
err := fs.DispatchHandler()
if err != nil {
return err
}
}
// 给文件系统分配钩子 // 给文件系统分配钩子
fs.Lock.Lock() fs.Lock.Lock()
if fs.Hooks == nil { if fs.Hooks == nil {
@ -233,16 +251,7 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStre
} }
// UploadFromPath 将本机已有文件上传到用户的文件系统 // UploadFromPath 将本机已有文件上传到用户的文件系统
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool, mode fsctx.WriteMode) error { func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, mode fsctx.WriteMode) error {
// 重设存储策略
if resetPolicy {
fs.Policy = &fs.User.Policy
err := fs.DispatchHandler()
if err != nil {
return err
}
}
file, err := os.Open(util.RelativePath(src)) file, err := os.Open(util.RelativePath(src))
if err != nil { if err != nil {
return err return err
@ -258,10 +267,11 @@ func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, reset
// 开始上传 // 开始上传
return fs.UploadFromStream(ctx, &fsctx.FileStream{ return fs.UploadFromStream(ctx, &fsctx.FileStream{
File: nil, File: file,
Seeker: file,
Size: uint64(size), Size: uint64(size),
Name: path.Base(dst), Name: path.Base(dst),
VirtualPath: path.Dir(dst), VirtualPath: path.Dir(dst),
Mode: mode, Mode: mode,
}) }, true)
} }

@ -5,6 +5,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
) )
@ -103,6 +104,10 @@ func WithSlaveMeta(s string) Option {
// Endpoint 使用同一的请求Endpoint // Endpoint 使用同一的请求Endpoint
func WithEndpoint(endpoint string) Option { func WithEndpoint(endpoint string) Option {
if !strings.HasSuffix(endpoint, "/") {
endpoint += "/"
}
endpointURL, _ := url.Parse(endpoint) endpointURL, _ := url.Parse(endpoint)
return optionFunc(func(o *options) { return optionFunc(func(o *options) {
o.endpoint = endpointURL o.endpoint = endpointURL

@ -7,7 +7,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"path" "net/url"
"strings" "strings"
"sync" "sync"
@ -51,7 +51,7 @@ func NewClient(opts ...Option) Client {
} }
// Request 发送HTTP请求 // Request 发送HTTP请求
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置 // 应用额外设置
c.mu.Lock() c.mu.Lock()
options := *c.options options := *c.options
@ -70,9 +70,13 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
// 确定请求URL // 确定请求URL
if options.endpoint != nil { if options.endpoint != nil {
targetPath, err := url.Parse(target)
if err != nil {
return &Response{Err: err}
}
targetURL := *options.endpoint targetURL := *options.endpoint
targetURL.Path = path.Join(targetURL.Path, target) target = targetURL.ResolveReference(targetPath).String()
target = targetURL.String()
} }
// 创建请求 // 创建请求

@ -1,52 +0,0 @@
package request
import (
"bytes"
"encoding/json"
"errors"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// TODO: move to slave pkg
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
}
resp := GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err)
}
// 解析回调服务端响应
resp = resp.CheckHTTPResponse(200)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", resp.Err)
}
response, err := resp.DecodeResponse()
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

@ -1,137 +0,0 @@
package request
import (
"bytes"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestRemoteCallback(t *testing.T) {
asserts := assert.New(t)
// 回调成功
{
clientMock := ClientMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.NoError(resp)
clientMock.AssertExpectations(t)
}
// 服务端返回业务错误
{
clientMock := ClientMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.EqualValues(401, resp.(serializer.AppError).Code)
clientMock.AssertExpectations(t)
}
// 无法解析回调响应
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// HTTP状态码非200
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: nil,
Response: &http.Response{
StatusCode: 404,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// 无法发起回调
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&Response{
Err: errors.New("error"),
})
GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
SourceName: "source",
})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
}

@ -54,6 +54,8 @@ const (
CodeNoPermissionErr = 403 CodeNoPermissionErr = 403
// CodeNotFound 资源未找到 // CodeNotFound 资源未找到
CodeNotFound = 404 CodeNotFound = 404
// CodeConflict 资源冲突
CodeConflict = 409
// CodeUploadFailed 上传出错 // CodeUploadFailed 上传出错
CodeUploadFailed = 40002 CodeUploadFailed = 40002
// CodeCredentialInvalid 凭证无效 // CodeCredentialInvalid 凭证无效

@ -45,14 +45,13 @@ type UploadSession struct {
SavePath string // 物理存储路径,包含物理文件名 SavePath string // 物理存储路径,包含物理文件名
LastModified *time.Time // 可选的文件最后修改日期 LastModified *time.Time // 可选的文件最后修改日期
Policy model.Policy Policy model.Policy
Callback string // 回调 URL 地址
CallbackSecret string // 回调 URL
} }
// UploadCallback 上传回调正文 // UploadCallback 上传回调正文
type UploadCallback struct { type UploadCallback struct {
Name string `json:"name"`
SourceName string `json:"source_name"`
PicInfo string `json:"pic_info"` PicInfo string `json:"pic_info"`
Size uint64 `json:"size"`
} }
// GeneralUploadCallbackFailed 存储策略上传回调失败响应 // GeneralUploadCallbackFailed 存储策略上传回调失败响应

@ -106,7 +106,7 @@ func (job *CompressTask) Do() {
job.TaskModel.SetProgress(TransferringProgress) job.TaskModel.SetProgress(TransferringProgress)
// 上传文件 // 上传文件
err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, true, 0) err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, 0)
if err != nil { if err != nil {
job.SetErrorMsg(err.Error()) job.SetErrorMsg(err.Error())
return return

@ -117,16 +117,18 @@ func (job *TransferTask) Do() {
} }
// 切换为从机节点处理上传 // 切换为从机节点处理上传
fs.SetPolicyFromPath(path.Dir(dst))
fs.SwitchToSlaveHandler(node) fs.SwitchToSlaveHandler(node)
err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{ err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{
File: nil, File: nil,
Size: job.TaskProps.SrcSizes[file], Size: job.TaskProps.SrcSizes[file],
Name: path.Base(dst), Name: path.Base(dst),
VirtualPath: path.Dir(dst), VirtualPath: path.Dir(dst),
}) Src: file,
}, false)
} else { } else {
// 主机节点中转 // 主机节点中转
err = fs.UploadFromPath(context.Background(), file, dst, true, 0) err = fs.UploadFromPath(context.Background(), file, dst, 0)
} }
if err != nil { if err != nil {

@ -1,6 +1,9 @@
package task package task
import "github.com/cloudreve/Cloudreve/v3/pkg/util" import (
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Worker 处理任务的对象 // Worker 处理任务的对象
type Worker interface { type Worker interface {
@ -20,7 +23,7 @@ func (worker *GeneralWorker) Do(job Job) {
// 致命错误捕获 // 致命错误捕获
if err := recover(); err != nil { if err := recover(); err != nil {
util.Log().Debug("任务执行出错,%s", err) util.Log().Debug("任务执行出错,%s", err)
job.SetError(&JobError{Msg: "致命错误"}) job.SetError(&JobError{Msg: "致命错误", Error: fmt.Sprintf("%s", err)})
job.SetStatus(Error) job.SetStatus(Error)
} }
}() }()

@ -42,9 +42,14 @@ func SiteConfig(c *gin.Context) {
// Ping 状态检查页面 // Ping 状态检查页面
func Ping(c *gin.Context) { func Ping(c *gin.Context) {
version := conf.BackendVersion
if conf.IsPro == "true" {
version += "-pro"
}
c.JSON(200, serializer.Response{ c.JSON(200, serializer.Response{
Code: 0, Code: 0,
Data: conf.BackendVersion, Data: conf.BackendVersion + conf.IsPro,
}) })
} }

@ -28,77 +28,6 @@ func SlaveUpload(c *gin.Context) {
} else { } else {
c.JSON(200, ErrorResponse(err)) c.JSON(200, ErrorResponse(err))
} }
//// 创建上下文
//ctx, cancel := context.WithCancel(context.Background())
//ctx = context.WithValue(ctx, fsctx.GinCtx, c)
//defer cancel()
//
//// 创建匿名文件系统
//fs, err := filesystem.NewAnonymousFileSystem()
//if err != nil {
// c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err))
// return
//}
//fs.Handler = local.Driver{}
//
//// 从请求中取得上传策略
//uploadPolicyRaw := c.GetHeader("X-Cr-Policy")
//if uploadPolicyRaw == "" {
// c.JSON(200, serializer.ParamErr("未指定上传策略", nil))
// return
//}
//
//// 解析上传策略
//uploadPolicy, err := serializer.DecodeUploadPolicy(uploadPolicyRaw)
//if err != nil {
// c.JSON(200, serializer.ParamErr("上传策略格式有误", err))
// return
//}
//ctx = context.WithValue(ctx, fsctx.UploadPolicyCtx, *uploadPolicy)
//
//// 取得文件大小
//fileSize, err := strconv.ParseUint(c.Request.Header.Get("Content-Length"), 10, 64)
//if err != nil {
// c.JSON(200, ErrorResponse(err))
// return
//}
//
//// 解码文件名和路径
//fileName, err := url.QueryUnescape(c.Request.Header.Get("X-Cr-FileName"))
//if err != nil {
// c.JSON(200, ErrorResponse(err))
// return
//}
//
//fileData := fsctx.FileStream{
// MIMEType: c.Request.Header.Get("Content-Type"),
// File: c.Request.Body,
// Name: fileName,
// Size: fileSize,
//}
//
//// 给文件系统分配钩子
//fs.Use("BeforeUpload", filesystem.HookSlaveUploadValidate)
//fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile)
//fs.Use("AfterUpload", filesystem.SlaveAfterUpload)
//fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
//
////// 是否允许覆盖
////if c.Request.Header.Get("X-Cr-Overwrite") == "false" {
//// fileData.Mode = fsctx.Create
////}
//
//// 执行上传
//err = fs.LocalUpload(ctx, &fileData)
//if err != nil {
// c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err))
// return
//}
//
//c.JSON(200, serializer.Response{
// Code: 0,
//})
} }
// SlaveGetUploadSession 从机创建上传会话 // SlaveGetUploadSession 从机创建上传会话
@ -116,6 +45,21 @@ func SlaveGetUploadSession(c *gin.Context) {
} }
} }
// SlaveDeleteUploadSession 从机删除上传会话
func SlaveDeleteUploadSession(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var service explorer.UploadSessionService
if err := c.ShouldBindUri(&service); err == nil {
res := service.SlaveDelete(ctx, c)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// SlaveDownload 从机文件下载,此请求返回的HTTP状态码不全为200 // SlaveDownload 从机文件下载,此请求返回的HTTP状态码不全为200
func SlaveDownload(c *gin.Context) { func SlaveDownload(c *gin.Context) {
// 创建上下文 // 创建上下文

@ -46,9 +46,15 @@ func InitSlaveRouter() *gin.Engine {
// 接收主机心跳包 // 接收主机心跳包
v3.POST("heartbeat", controllers.SlaveHeartbeat) v3.POST("heartbeat", controllers.SlaveHeartbeat)
// 上传 // 上传
v3.POST("upload/:sessionId", controllers.SlaveUpload) upload := v3.Group("upload")
{
// 上传分片
upload.POST(":sessionId", controllers.SlaveUpload)
// 创建上传会话上传 // 创建上传会话上传
v3.PUT("upload", controllers.SlaveGetUploadSession) upload.PUT("", controllers.SlaveGetUploadSession)
// 删除上传会话
upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession)
}
// 下载 // 下载
v3.GET("download/:speed/:path/:name", controllers.SlaveDownload) v3.GET("download/:speed/:path/:name", controllers.SlaveDownload)
// 预览 / 外链 // 预览 / 外链
@ -213,7 +219,15 @@ func InitMasterRouter() *gin.Engine {
// 事件通知 // 事件通知
slave.PUT("notification/:subject", controllers.SlaveNotificationPush) slave.PUT("notification/:subject", controllers.SlaveNotificationPush)
// 上传 // 上传
slave.POST("upload", controllers.SlaveUpload) upload := slave.Group("upload")
{
// 上传分片
upload.POST(":sessionId", controllers.SlaveUpload)
// 创建上传会话上传
upload.PUT("", controllers.SlaveGetUploadSession)
// 删除上传会话
upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession)
}
// OneDrive 存储策略凭证 // OneDrive 存储策略凭证
slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential) slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential)
} }
@ -223,7 +237,8 @@ func InitMasterRouter() *gin.Engine {
{ {
// 远程策略上传回调 // 远程策略上传回调
callback.POST( callback.POST(
"remote/:key", "remote/:sessionID/:key",
middleware.UseUploadSession("remote"),
middleware.RemoteCallbackAuth(), middleware.RemoteCallbackAuth(),
controllers.RemoteCallback, controllers.RemoteCallback,
) )

@ -3,6 +3,7 @@ package callback
import ( import (
"context" "context"
"fmt" "fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"strings" "strings"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
@ -11,13 +12,12 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// CallbackProcessService 上传请求回调正文接口 // CallbackProcessService 上传请求回调正文接口
type CallbackProcessService interface { type CallbackProcessService interface {
GetBody(*serializer.UploadSession) serializer.UploadCallback GetBody() serializer.UploadCallback
} }
// RemoteUploadCallbackService 远程存储上传回调请求服务 // RemoteUploadCallbackService 远程存储上传回调请求服务
@ -26,7 +26,7 @@ type RemoteUploadCallbackService struct {
} }
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service RemoteUploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service RemoteUploadCallbackService) GetBody() serializer.UploadCallback {
return service.Data return service.Data
} }
@ -68,12 +68,8 @@ type S3Callback struct {
} }
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service UpyunCallbackService) GetBody() serializer.UploadCallback {
res := serializer.UploadCallback{ res := serializer.UploadCallback{}
Name: session.Name,
SourceName: service.SourceName,
Size: service.Size,
}
if service.Width != "" { if service.Width != "" {
res.PicInfo = service.Width + "," + service.Height res.PicInfo = service.Width + "," + service.Height
} }
@ -82,51 +78,41 @@ func (service UpyunCallbackService) GetBody(session *serializer.UploadSession) s
} }
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service UploadCallbackService) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service UploadCallbackService) GetBody() serializer.UploadCallback {
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: service.Name,
SourceName: service.SourceName,
PicInfo: service.PicInfo, PicInfo: service.PicInfo,
Size: service.Size,
} }
} }
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service OneDriveCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service OneDriveCallback) GetBody() serializer.UploadCallback {
var picInfo = "0,0" var picInfo = "0,0"
if service.Meta.Image.Width != 0 { if service.Meta.Image.Width != 0 {
picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height) picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height)
} }
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath,
PicInfo: picInfo, PicInfo: picInfo,
Size: session.Size,
} }
} }
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service COSCallback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service COSCallback) GetBody() serializer.UploadCallback {
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath,
PicInfo: "", PicInfo: "",
Size: session.Size,
} }
} }
// GetBody 返回回调正文 // GetBody 返回回调正文
func (service S3Callback) GetBody(session *serializer.UploadSession) serializer.UploadCallback { func (service S3Callback) GetBody() serializer.UploadCallback {
return serializer.UploadCallback{ return serializer.UploadCallback{
Name: session.Name,
SourceName: session.SavePath,
PicInfo: "", PicInfo: "",
Size: session.Size,
} }
} }
// ProcessCallback 处理上传结果回调 // ProcessCallback 处理上传结果回调
func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer.Response { func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer.Response {
callbackBody := service.GetBody()
// 创建文件系统 // 创建文件系统
fs, err := filesystem.NewFileSystemFromCallback(c) fs, err := filesystem.NewFileSystemFromCallback(c)
if err != nil { if err != nil {
@ -134,51 +120,39 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer.
} }
defer fs.Recycle() defer fs.Recycle()
// 获取回调会话 // 获取上传会话
callbackSessionRaw, _ := c.Get("callbackSession") uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
callbackSession := callbackSessionRaw.(*serializer.UploadSession)
callbackBody := service.GetBody(callbackSession)
// 获取父目录 // 查找上传会话创建的占位文件
exist, parentFolder := fs.IsPathExist(callbackSession.VirtualPath) file, err := model.GetFilesByUploadSession(uploadSession.Key, fs.User.ID)
if !exist {
newFolder, err := fs.CreateDirectory(context.Background(), callbackSession.VirtualPath)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeParamErr, "指定目录不存在", err) return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err)
} }
parentFolder = newFolder
fileData := fsctx.FileStream{
Size: uploadSession.Size,
Name: uploadSession.Name,
VirtualPath: uploadSession.VirtualPath,
SavePath: uploadSession.SavePath,
Mode: fsctx.Nop,
Model: file,
LastModified: uploadSession.LastModified,
} }
// 创建文件头 // 占位符未扣除容量需要校验和扣除
fileHeader := fsctx.FileStream{ if !fs.Policy.IsUploadPlaceholderWithSize() {
Size: callbackBody.Size, fs.Use("AfterUpload", filesystem.HookValidateCapacity)
VirtualPath: callbackSession.VirtualPath, fs.Use("AfterUpload", filesystem.HookChunkUploaded)
Name: callbackSession.Name,
SavePath: callbackBody.SourceName,
} }
// 添加钩子 fs.Use("AfterUpload", filesystem.HookPopPlaceholderToFile(callbackBody.PicInfo))
fs.Use("BeforeAddFile", filesystem.HookValidateFile)
fs.Use("BeforeAddFile", filesystem.HookValidateCapacity)
fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
fs.Use("BeforeAddFileFailed", filesystem.HookDeleteTempFile) err = fs.Upload(context.Background(), &fileData)
// 向数据库中添加文件
file, err := fs.AddFile(context.Background(), parentFolder, &fileHeader)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) return serializer.Err(serializer.CodeUploadFailed, err.Error(), err)
} }
// 如果是图片,则更新图片信息 return serializer.Response{}
if callbackBody.PicInfo != "" {
if err := file.UpdatePicInfo(callbackBody.PicInfo); err != nil {
util.Log().Debug("无法更新回调文件的图片信息:%s", err)
}
}
return serializer.Response{
Code: 0,
}
} }
// PreProcess 对OneDrive客户端回调进行预处理验证 // PreProcess 对OneDrive客户端回调进行预处理验证

@ -13,6 +13,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"net/http" "net/http"
@ -172,6 +173,10 @@ type SlaveCreateUploadSessionService struct {
// Create 从机创建上传会话 // Create 从机创建上传会话
func (service *SlaveCreateUploadSessionService) Create(ctx context.Context, c *gin.Context) serializer.Response { func (service *SlaveCreateUploadSessionService) Create(ctx context.Context, c *gin.Context) serializer.Response {
if util.Exists(service.Session.SavePath) {
return serializer.Err(serializer.CodeConflict, "placeholder file already exist", nil)
}
err := cache.Set( err := cache.Set(
filesystem.UploadSessionCachePrefix+service.Session.Key, filesystem.UploadSessionCachePrefix+service.Session.Key,
service.Session, service.Session,

@ -7,6 +7,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -87,13 +88,13 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s
} }
if uploadSession.UID != fs.User.ID { if uploadSession.UID != fs.User.ID {
return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session expired or not exist", nil) return serializer.Err(serializer.CodeUploadSessionExpired, "Local upload session expired or not exist", nil)
} }
// 查找上传会话创建的占位文件 // 查找上传会话创建的占位文件
file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) return serializer.Err(serializer.CodeUploadSessionExpired, "Local upload session file placeholder not exist", err)
} }
// 重设 fs 存储策略 // 重设 fs 存储策略
@ -120,14 +121,14 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s
util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart) util.Log().Info("尝试上传覆盖分片[%d] Start=%d", service.Index, actualSizeStart)
} }
return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append|fsctx.Overwrite) return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append)
} }
// SlaveUpload 处理从机文件分片上传 // SlaveUpload 处理从机文件分片上传
func (service *UploadService) SlaveUpload(ctx context.Context, c *gin.Context) serializer.Response { func (service *UploadService) SlaveUpload(ctx context.Context, c *gin.Context) serializer.Response {
uploadSessionRaw, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID) uploadSessionRaw, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID)
if !ok { if !ok {
return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session expired or not exist", nil) return serializer.Err(serializer.CodeUploadSessionExpired, "Slave upload session expired or not exist", nil)
} }
uploadSession := uploadSessionRaw.(serializer.UploadSession) uploadSession := uploadSessionRaw.(serializer.UploadSession)
@ -137,6 +138,8 @@ func (service *UploadService) SlaveUpload(ctx context.Context, c *gin.Context) s
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
} }
fs.Handler = local.Driver{}
// 解析需要的参数 // 解析需要的参数
service.Index, _ = strconv.Atoi(c.Query("chunk")) service.Index, _ = strconv.Atoi(c.Query("chunk"))
mode := fsctx.Append mode := fsctx.Append
@ -165,6 +168,11 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File
) )
} }
// 非首个分片时需要允许覆盖
if index > 0 {
mode |= fsctx.Overwrite
}
fileData := fsctx.FileStream{ fileData := fsctx.FileStream{
MIMEType: c.Request.Header.Get("Content-Type"), MIMEType: c.Request.Header.Get("Content-Type"),
File: c.Request.Body, File: c.Request.Body,
@ -187,13 +195,14 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File
fs.Use("AfterUpload", filesystem.HookChunkUploaded) fs.Use("AfterUpload", filesystem.HookChunkUploaded)
fs.Use("AfterValidateFailed", filesystem.HookChunkUploadFailed) fs.Use("AfterValidateFailed", filesystem.HookChunkUploadFailed)
if isLastChunk { if isLastChunk {
fs.Use("AfterUpload", filesystem.HookChunkUploadFinished) fs.Use("AfterUpload", filesystem.HookPopPlaceholderToFile(""))
fs.Use("AfterUpload", filesystem.HookGenerateThumb) fs.Use("AfterUpload", filesystem.HookGenerateThumb)
fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key)) fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key))
} }
} else { } else {
if isLastChunk { if isLastChunk {
fs.Use("AfterUpload", filesystem.SlaveAfterUpload) fs.Use("AfterUpload", filesystem.SlaveAfterUpload(session))
fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key))
} }
} }
@ -224,7 +233,7 @@ func (service *UploadSessionService) Delete(ctx context.Context, c *gin.Context)
// 查找需要删除的上传会话的占位文件 // 查找需要删除的上传会话的占位文件
file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) return serializer.Err(serializer.CodeUploadSessionExpired, "Local Upload session file placeholder not exist", err)
} }
// 删除文件 // 删除文件
@ -235,6 +244,28 @@ func (service *UploadSessionService) Delete(ctx context.Context, c *gin.Context)
return serializer.Response{} return serializer.Response{}
} }
// SlaveDelete 从机删除指定上传会话
func (service *UploadSessionService) SlaveDelete(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewAnonymousFileSystem()
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
defer fs.Recycle()
session, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID)
if !ok {
return serializer.Err(serializer.CodeUploadSessionExpired, "Slave Upload session file placeholder not exist", nil)
}
if _, err := fs.Handler.Delete(ctx, []string{session.(serializer.UploadSession).SavePath}); err != nil {
return serializer.Err(serializer.CodeInternalSetting, "Failed to delete temp file", err)
}
cache.Deletes([]string{service.ID}, filesystem.UploadSessionCachePrefix)
return serializer.Response{}
}
// DeleteAllUploadSession 删除当前用户的全部上传绘会话 // DeleteAllUploadSession 删除当前用户的全部上传绘会话
func DeleteAllUploadSession(ctx context.Context, c *gin.Context) serializer.Response { func DeleteAllUploadSession(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统 // 创建文件系统

Loading…
Cancel
Save